Compare commits

..

23 Commits

Author SHA1 Message Date
Zero Rains
bd30b08521 get org_vocab_size from args (#3981) 2025-09-09 15:08:47 +08:00
Divano
1aa16146ba Update requirements.txt (#3915) 2025-09-05 13:51:22 +08:00
ApplEOFDiscord
dac0a00d0f [BugFix] fix max streaming tokens invalid (#3774) (#3856)
* Update serving_chat.py

* Update serving_completion.py

Co-authored-by: ltd0924 <32387785+ltd0924@users.noreply.github.com>
2025-09-03 17:50:29 +08:00
ltd0924
c5591c45df [BugFix] fix max streaming tokens invalid (#3774)
* Update serving_chat.py

* Update serving_completion.py
2025-09-02 21:00:29 +08:00
chen
121ac85d7d fix (#3640) 2025-08-27 14:23:38 +08:00
chen
d233e3c97c [Precision] Change lm_head layer running in float32 (#3596)
* support lm_head fp32 bf16 fp16

* delete print

* code check

* check

* check

* code check

* check

* check
2025-08-26 20:20:06 +08:00
chen
2136990144 [Feature] Add temp_scaled_logprobs and top_p_normalized_logprobs parameters for logits and logprobs post processing (#3536)
* [feature] Add temp_scaled_logprobs and top_p_normalized_logprobs parameters for logits and logprobs post processing

* infer engine support temp_scaled_logprobs and top_p_normalized_logprobs

* code check

* code check

* fix tokenizer.decoder(-1), return 'Invalid Token'

* check seq len time shape

* logprob clip inf

* code check

---------

Co-authored-by: sunlei1024 <sunlei5788@gmail.com>
2025-08-25 14:11:18 +08:00
kevin
b7890cbe8d fix uvicorn multi worker error (#3339) 2025-08-25 11:24:07 +08:00
chenjian
bc388b65c7 [Bug fix] Fix bug in logprob in release 2.0.4 (#3445)
* fix bug for scheduler v0

* Fix logprob in release/2.0.4
2025-08-16 21:13:10 +08:00
Jiang-Jia-Jun
71af0ca04a [BugFix] Fix default log level of paddleformers (#3378) 2025-08-15 18:30:00 +08:00
YuBaoku
d66660a0d1 [CI] fix run_ci error in release/2.0.4 (#3411) 2025-08-14 22:44:17 +08:00
xiaolei373
f0519aec67 feat(log):add_request_and_response_log (#3391)
* feat(log):add_request_and_response_log

* [ci] Retrigger

* [ci] Retrigger
2025-08-14 19:12:42 +08:00
gaoziyuan
1f5983290c fix mapping (#3321) 2025-08-12 16:17:59 +08:00
chenjian
c6a133d573 [Bug fix] Fix block num in scheduler v1 for release2.0.4 (#3314)
* fix bug for scheduler v0

* fix block num setting in scheduler v1

* fix block num setting in scheduler v1

* fix block num setting in scheduler v1

* fix block num setting in scheduler v1

* fix block num setting in scheduler v1
2025-08-11 23:55:45 +08:00
chenjian
4646aff25c fix bug for scheduler v0 (#3307) 2025-08-11 23:55:20 +08:00
chenjian
a84a98b107 fix scheduler bug due to async running (#3293) 2025-08-10 13:54:59 +08:00
chenjian
c208086f61 fix scheduler bug for bs=1 (#3288) 2025-08-09 12:22:12 +08:00
sg263
ce1d4944e7 merge develop trace FD_START (#3253) (#3260)
Co-authored-by: shige <shige@baidu.com>
2025-08-07 16:06:58 +08:00
chenjian
5439fb6336 [Cherry-pick] FIx bug for scheduler V1 (#3167)
* [BUG FIX] Fix bug when preempted request rescheduled (#3080)

* Fix bug when preempted request rescheduled

* Fix bug when preempted request rescheduled

* Fix bug when preempted request rescheduled

* Fix bug for offline inference in scheduler v1 (#3117)
2025-08-04 17:08:12 +08:00
gaoziyuan
a592d17615 support qwen3 name_mapping (#3180) 2025-08-04 16:37:34 +08:00
李泳桦
eca8fc7ca6 [feat] extra parameters are all passed directly via http payload now, or in extra_body if using openai client (#3077)
* [feat] extra parameters are all passed directly via http payload now, or in extra_body if using openai client

* [fix] delete ci test case for enable_thinking

* [fix] add reasoning_parser when server starts

* [doc] update docs related to metadata

* [fix] fix ci consistency test error with reasoning parser

* [fix] cancel enable_thinking default value
2025-07-30 19:25:39 +08:00
李泳桦
0463797fc2 [feat] add disable_chat_template in chat api as a substitute for previous raw_request (#3023)
* [feat] add disable_chat_template in chat api as a substitute for previous raw_request

* [fix] pre-commit code check
2025-07-25 20:57:06 +08:00
Jiang-Jia-Jun
0ab8645fc4 Update setup.py 2025-07-25 10:27:51 +08:00
935 changed files with 15826 additions and 106073 deletions

View File

@@ -2,9 +2,7 @@ name: Codestyle-Check
on:
pull_request:
branches:
- develop
- 'release/*'
branches: ["develop"]
jobs:
pre-commit:
@@ -13,7 +11,7 @@ jobs:
runs-on: ubuntu-latest
env:
PR_ID: ${{ github.event.pull_request.number }}
BRANCH: ${{ github.event.pull_request.base.ref }}
BRANCH: develop
steps:
- name: Cleanup

View File

@@ -1,187 +0,0 @@
name: Accuracy Test
description: "Run Accuracy Tests"
on:
workflow_call:
inputs:
DOCKER_IMAGE:
description: "Build Images"
required: true
type: string
default: "ccr-2vdh3abv-pub.cnc.bj.baidubce.com/paddlepaddle/paddleqa:cuda126-py310"
FASTDEPLOY_ARCHIVE_URL:
description: "URL of the compressed FastDeploy code archive."
required: true
type: string
FASTDEPLOY_WHEEL_URL:
description: "URL of the FastDeploy Wheel."
required: true
type: string
CACHE_DIR:
description: "Cache Dir Use"
required: false
type: string
default: ""
MODEL_CACHE_DIR:
description: "Cache Dir Use"
required: false
type: string
default: ""
jobs:
accuracy_tests:
runs-on: [self-hosted, GPU-h20-1Cards]
timeout-minutes: 60
steps:
- name: Code Prepare
shell: bash
env:
docker_image: ${{ inputs.DOCKER_IMAGE }}
fd_archive_url: ${{ inputs.FASTDEPLOY_ARCHIVE_URL }}
run: |
set -x
REPO="https://github.com/${{ github.repository }}.git"
FULL_REPO="${{ github.repository }}"
REPO_NAME="${FULL_REPO##*/}"
BASE_BRANCH="${{ github.base_ref }}"
docker pull ${docker_image}
# Clean the repository directory before starting
docker run --rm --net=host -v $(pwd):/workspace -w /workspace \
-e "REPO_NAME=${REPO_NAME}" \
${docker_image} /bin/bash -c '
if [ -d ${REPO_NAME} ]; then
echo "Directory ${REPO_NAME} exists, removing it..."
rm -rf ${REPO_NAME}*
fi
'
wget -q ${fd_archive_url}
tar -xf FastDeploy.tar.gz
rm -rf FastDeploy.tar.gz
cd FastDeploy
git config --global user.name "FastDeployCI"
git config --global user.email "fastdeploy_ci@example.com"
git log -n 3 --oneline
- name: Run FastDeploy Base Tests
shell: bash
env:
docker_image: ${{ inputs.DOCKER_IMAGE }}
fastdeploy_wheel_url: ${{ inputs.FASTDEPLOY_WHEEL_URL }}
CACHE_DIR: ${{ inputs.CACHE_DIR }}
MODEL_CACHE_DIR: ${{ inputs.MODEL_CACHE_DIR }}
run: |
runner_name="${{ runner.name }}"
CARD_ID=$(echo "${runner_name}" | awk -F'-' '{print $NF}')
DEVICES=$(echo "$CARD_ID" | fold -w1 | paste -sd,)
DEVICE_PORT=$(echo "$DEVICES" | cut -d',' -f1)
FLASK_PORT=$((42068 + DEVICE_PORT * 100))
FD_API_PORT=$((42088 + DEVICE_PORT * 100))
FD_ENGINE_QUEUE_PORT=$((42058 + DEVICE_PORT * 100))
FD_METRICS_PORT=$((42078 + DEVICE_PORT * 100))
FD_CACHE_QUEUE_PORT=$((42098 + DEVICE_PORT * 100))
echo "Test ENV Parameter:"
echo "========================================================="
echo "FLASK_PORT=${FLASK_PORT}"
echo "FD_API_PORT=${FD_API_PORT}"
echo "FD_ENGINE_QUEUE_PORT=${FD_ENGINE_QUEUE_PORT}"
echo "FD_METRICS_PORT=${FD_METRICS_PORT}"
echo "FD_CACHE_QUEUE_PORT=${FD_CACHE_QUEUE_PORT}"
echo "DEVICES=${DEVICES}"
echo "========================================================="
CACHE_DIR="${CACHE_DIR:-$(dirname "$(dirname "${{ github.workspace }}")")}"
echo "CACHE_DIR is set to ${CACHE_DIR}"
if [ ! -f "${CACHE_DIR}/gitconfig" ]; then
touch "${CACHE_DIR}/gitconfig"
fi
if [ ! -d "${MODEL_CACHE_DIR}" ]; then
echo "Error: MODEL_CACHE_DIR '${MODEL_CACHE_DIR}' does not exist."
exit 1
fi
PORTS=($FLASK_PORT $FD_API_PORT $FD_ENGINE_QUEUE_PORT $FD_METRICS_PORT $FD_CACHE_QUEUE_PORT)
LOG_FILE="./port_cleanup_$(date +%Y%m%d_%H%M%S).log"
echo "==== LOG_FILE is ${LOG_FILE} ===="
echo "==== PORT CLEAN BEFORE TASK RUN ====" | tee -a $LOG_FILE
for port in "${PORTS[@]}"; do
PIDS=$(lsof -t -i :$port || true)
if [ -n "$PIDS" ]; then
echo "Port $port is occupied by PID(s): $PIDS" | tee -a $LOG_FILE
echo "$PIDS" | xargs -r kill -9
echo "Port $port cleared" | tee -a $LOG_FILE
else
echo "Port $port is free" | tee -a $LOG_FILE
fi
done
echo "==== PORT CLEAN COMPLETE ====" | tee -a $LOG_FILE
echo "========================================================="
echo "Ensuring no stale container named ${runner_name} ..."
if [ "$(docker ps -a -q -f name=${runner_name})" ]; then
echo "Removing stale container: ${runner_name}"
docker rm -f ${runner_name} || true
fi
docker run --rm --ipc=host --pid=host --net=host \
--name ${runner_name} \
-v $(pwd):/workspace \
-w /workspace \
-e fastdeploy_wheel_url=${fastdeploy_wheel_url} \
-e "FD_API_PORT=${FD_API_PORT}" \
-e "FD_ENGINE_QUEUE_PORT=${FD_ENGINE_QUEUE_PORT}" \
-e "FD_METRICS_PORT=${FD_METRICS_PORT}" \
-e "FD_CACHE_QUEUE_PORT=${FD_CACHE_QUEUE_PORT}" \
-e "FLASK_PORT=${FLASK_PORT}" \
-v "${MODEL_CACHE_DIR}:/MODELDATA" \
-v "${CACHE_DIR}/gitconfig:/etc/gitconfig:ro" \
-v "${CACHE_DIR}/.cache:/root/.cache" \
-v "${CACHE_DIR}/ConfigDir:/root/.config" \
-e TZ="Asia/Shanghai" \
--gpus '"device='"${DEVICES}"'"' ${docker_image} /bin/bash -xc '
python -m pip install --pre paddlepaddle-gpu -i https://www.paddlepaddle.org.cn/packages/nightly/cu126/
pip config set global.index-url https://mirrors.tuna.tsinghua.edu.cn/pypi/web/simple
python -m pip install ${fastdeploy_wheel_url}
python -m pip install pytest
wget https://paddle-qa.bj.bcebos.com/zhengtianyu/tools/llm-deploy-linux-amd64
chmod +x ./llm-deploy-linux-amd64
./llm-deploy-linux-amd64 -python python3.10 \
-model_name ERNIE-4.5-0.3B-Paddle \
-model_path /MODELDATA \
--skip install
git config --global --add safe.directory /workspace/FastDeploy
cd FastDeploy
pushd tests/ce/deploy
ps -ef | grep "${FD_ENGINE_QUEUE_PORT}" | grep -v grep | awk "{print \$2}" | xargs -r kill -9
python3.10 deploy.py > dd.log 2>&1 &
sleep 3
curl -X POST http://0.0.0.0:${FLASK_PORT}/start \
-H "Content-Type: application/json" \
-d "{\"--model\": \"/MODELDATA/ERNIE-4.5-0.3B-Paddle\"}"
curl -X POST http://localhost:${FLASK_PORT}/wait_for_infer?timeout=90
popd
pushd tests/ce/accuracy_cases
export URL=http://localhost:${FD_API_PORT}/v1/chat/completions
export TEMPLATE=TOKEN_LOGPROB
export MODEL_SIZE=0.3B
TEST_EXIT_CODE=0
python gsm8k.py || TEST_EXIT_CODE=1
popd
echo "TEST_EXIT_CODE=${TEST_EXIT_CODE}" >> /workspace/FastDeploy/exit_code.env
'
if [ -f ./FastDeploy/exit_code.env ]; then
source ./FastDeploy/exit_code.env
cat ./FastDeploy/exit_code.env >> $GITHUB_ENV
fi
echo "TEST_EXIT_CODE=${TEST_EXIT_CODE}"
exit ${TEST_EXIT_CODE}

View File

@@ -1,230 +0,0 @@
name: Base Test
description: "Run Base Tests"
on:
workflow_call:
inputs:
DOCKER_IMAGE:
description: "Build Images"
required: true
type: string
default: "ccr-2vdh3abv-pub.cnc.bj.baidubce.com/paddlepaddle/paddleqa:cuda126-py310"
FASTDEPLOY_ARCHIVE_URL:
description: "URL of the compressed FastDeploy code archive."
required: true
type: string
FASTDEPLOY_WHEEL_URL:
description: "URL of the FastDeploy Wheel."
required: true
type: string
CACHE_DIR:
description: "Cache Dir Use"
required: false
type: string
default: ""
MODEL_CACHE_DIR:
description: "Cache Dir Use"
required: false
type: string
default: ""
jobs:
base_tests:
runs-on: [self-hosted, GPU-h20-1Cards]
timeout-minutes: 60
steps:
- name: Code Prepare
shell: bash
env:
docker_image: ${{ inputs.DOCKER_IMAGE }}
fd_archive_url: ${{ inputs.FASTDEPLOY_ARCHIVE_URL }}
run: |
set -x
REPO="https://github.com/${{ github.repository }}.git"
FULL_REPO="${{ github.repository }}"
REPO_NAME="${FULL_REPO##*/}"
BASE_BRANCH="${{ github.base_ref }}"
docker pull ${docker_image}
# Clean the repository directory before starting
docker run --rm --net=host -v $(pwd):/workspace -w /workspace \
-e "REPO_NAME=${REPO_NAME}" \
${docker_image} /bin/bash -c '
if [ -d ${REPO_NAME} ]; then
echo "Directory ${REPO_NAME} exists, removing it..."
rm -rf ${REPO_NAME}*
fi
'
wget -q ${fd_archive_url}
tar -xf FastDeploy.tar.gz
rm -rf FastDeploy.tar.gz
cd FastDeploy
git config --global user.name "FastDeployCI"
git config --global user.email "fastdeploy_ci@example.com"
git log -n 3 --oneline
- name: Run FastDeploy Base Tests
shell: bash
env:
docker_image: ${{ inputs.DOCKER_IMAGE }}
fastdeploy_wheel_url: ${{ inputs.FASTDEPLOY_WHEEL_URL }}
CACHE_DIR: ${{ inputs.CACHE_DIR }}
MODEL_CACHE_DIR: ${{ inputs.MODEL_CACHE_DIR }}
run: |
runner_name="${{ runner.name }}"
CARD_ID=$(echo "${runner_name}" | awk -F'-' '{print $NF}')
DEVICES=$(echo "$CARD_ID" | fold -w1 | paste -sd,)
DEVICE_PORT=$(echo "$DEVICES" | cut -d',' -f1)
FLASK_PORT=$((42068 + DEVICE_PORT * 100))
FD_API_PORT=$((42088 + DEVICE_PORT * 100))
FD_ENGINE_QUEUE_PORT=$((42058 + DEVICE_PORT * 100))
FD_METRICS_PORT=$((42078 + DEVICE_PORT * 100))
FD_CACHE_QUEUE_PORT=$((42098 + DEVICE_PORT * 100))
echo "Test ENV Parameter:"
echo "========================================================="
echo "FLASK_PORT=${FLASK_PORT}"
echo "FD_API_PORT=${FD_API_PORT}"
echo "FD_ENGINE_QUEUE_PORT=${FD_ENGINE_QUEUE_PORT}"
echo "FD_METRICS_PORT=${FD_METRICS_PORT}"
echo "FD_CACHE_QUEUE_PORT=${FD_CACHE_QUEUE_PORT}"
echo "DEVICES=${DEVICES}"
echo "========================================================="
CACHE_DIR="${CACHE_DIR:-$(dirname "$(dirname "${{ github.workspace }}")")}"
echo "CACHE_DIR is set to ${CACHE_DIR}"
if [ ! -f "${CACHE_DIR}/gitconfig" ]; then
touch "${CACHE_DIR}/gitconfig"
fi
if [ ! -d "${MODEL_CACHE_DIR}" ]; then
echo "Error: MODEL_CACHE_DIR '${MODEL_CACHE_DIR}' does not exist."
exit 1
fi
PORTS=($FLASK_PORT $FD_API_PORT $FD_ENGINE_QUEUE_PORT $FD_METRICS_PORT $FD_CACHE_QUEUE_PORT)
LOG_FILE="./port_cleanup_$(date +%Y%m%d_%H%M%S).log"
echo "==== LOG_FILE is ${LOG_FILE} ===="
echo "==== PORT CLEAN BEFORE TASK RUN ====" | tee -a $LOG_FILE
for port in "${PORTS[@]}"; do
PIDS=$(lsof -t -i :$port || true)
if [ -n "$PIDS" ]; then
echo "Port $port is occupied by PID(s): $PIDS" | tee -a $LOG_FILE
echo "$PIDS" | xargs -r kill -9
echo "Port $port cleared" | tee -a $LOG_FILE
else
echo "Port $port is free" | tee -a $LOG_FILE
fi
done
echo "==== PORT CLEAN COMPLETE ====" | tee -a $LOG_FILE
echo "========================================================="
echo "Ensuring no stale container named ${runner_name} ..."
if [ "$(docker ps -a -q -f name=${runner_name})" ]; then
echo "Removing stale container: ${runner_name}"
docker rm -f ${runner_name} || true
fi
docker run --rm --ipc=host --pid=host --net=host \
--name ${runner_name} \
-v $(pwd):/workspace \
-w /workspace \
-e fastdeploy_wheel_url=${fastdeploy_wheel_url} \
-e "FD_API_PORT=${FD_API_PORT}" \
-e "FD_ENGINE_QUEUE_PORT=${FD_ENGINE_QUEUE_PORT}" \
-e "FD_METRICS_PORT=${FD_METRICS_PORT}" \
-e "FLASK_PORT=${FLASK_PORT}" \
-e "FD_CACHE_QUEUE_PORT=${FD_CACHE_QUEUE_PORT}" \
-v "${MODEL_CACHE_DIR}:/MODELDATA" \
-v "${CACHE_DIR}/gitconfig:/etc/gitconfig:ro" \
-v "${CACHE_DIR}/.cache:/root/.cache" \
-v "${CACHE_DIR}/ConfigDir:/root/.config" \
-e TZ="Asia/Shanghai" \
--gpus '"device='"${DEVICES}"'"' ${docker_image} /bin/bash -xc '
python -m pip install --pre paddlepaddle-gpu -i https://www.paddlepaddle.org.cn/packages/nightly/cu126/
pip config set global.index-url https://mirrors.tuna.tsinghua.edu.cn/pypi/web/simple
python -m pip install ${fastdeploy_wheel_url}
python -m pip install pytest
wget https://paddle-qa.bj.bcebos.com/zhengtianyu/tools/llm-deploy-linux-amd64
chmod +x ./llm-deploy-linux-amd64
./llm-deploy-linux-amd64 -python python3.10 \
-model_name ERNIE-4.5-0.3B-Paddle \
-model_path /MODELDATA \
--skip install
git config --global --add safe.directory /workspace/FastDeploy
cd FastDeploy
pushd tests/ce/deploy
ps -ef | grep "${FD_ENGINE_QUEUE_PORT}" | grep -v grep | awk "{print \$2}" | xargs -r kill -9
python3.10 deploy.py > dd.log 2>&1 &
sleep 3
curl -X POST http://0.0.0.0:${FLASK_PORT}/start \
-H "Content-Type: application/json" \
-d "{\"--model\": \"/MODELDATA/ERNIE-4.5-0.3B-Paddle\"}"
check_service() {
local timeout=${1:-90}
local url="http://localhost:${FLASK_PORT}/wait_for_infer?timeout=${timeout}"
local resp
resp=$(curl -s -X POST "$url")
if echo "$resp" | grep -q "服务启动超时"; then
exit 8
fi
}
check_service 90
popd
pushd tests/ce/server
export URL=http://localhost:${FD_API_PORT}/v1/chat/completions
export TEMPLATE=TOKEN_LOGPROB
TEST_EXIT_CODE=0
python -m pytest -sv test_base_chat.py test_compare_top_logprobs.py test_logprobs.py test_params_boundary.py test_seed_usage.py test_stream.py test_evil_cases.py test_completions.py test_return_token_ids.py || TEST_EXIT_CODE=1
curl -X POST http://0.0.0.0:${FLASK_PORT}/switch \
-H "Content-Type: application/json" \
-d "{\"--model\": \"/MODELDATA/ERNIE-4.5-0.3B-Paddle\", \"--early-stop-config\": \"{\\\"enable_early_stop\\\":true, \\\"window_size\\\":6, \\\"threshold\\\":0.93}\"}"
check_service 90
python -m pytest -sv test_repetition_early_stop.py || TEST_EXIT_CODE=1
curl -X POST http://0.0.0.0:${FLASK_PORT}/switch \
-H "Content-Type: application/json" \
-d "{ \"--model\": \"/MODELDATA/ERNIE-4.5-0.3B-Paddle\", \"--max-concurrency\": 5, \"--max-waiting-time\": 1 }"
check_service 90
python -m pytest -sv test_max_concurrency.py || TEST_EXIT_CODE=1
curl -X POST http://0.0.0.0:${FLASK_PORT}/switch \
-H "Content-Type: application/json" \
-d "{ \"--model\": \"/MODELDATA/ERNIE-4.5-0.3B-Paddle\", \"--max-concurrency\": 5000, \"--max-waiting-time\": 1 }"
check_service 90
python -m pytest -sv test_max_waiting_time.py || TEST_EXIT_CODE=1
curl -X POST http://0.0.0.0:${FLASK_PORT}/switch \
-H "Content-Type: application/json" \
-d "{\"--model\": \"/MODELDATA/ernie-4_5-21b-a3b-bf16-paddle\", \"--config\": \"21b_mtp.yaml\", \"--enable-logprob\": \"False\"}"
check_service 180
export TEMPLATE=TOKEN_NORMAL
python -m pytest -sv test_seed_usage.py -k "not test_seed_stream" || TEST_EXIT_CODE=1
curl -X POST http://0.0.0.0:${FLASK_PORT}/switch \
-H "Content-Type: application/json" \
-d "{\"--model\": \"/MODELDATA/ernie-4_5-21b-a3b-bf16-paddle\", \"--config\": \"21b_sot.yaml\", \"--enable-logprob\": \"False\"}"
check_service 360
export TEMPLATE=TOKEN_NORMAL
python -m pytest -sv test_seed_usage.py -k "not test_seed_stream" || TEST_EXIT_CODE=1
popd
echo "TEST_EXIT_CODE=${TEST_EXIT_CODE}" >> /workspace/FastDeploy/exit_code.env
'
if [ -f ./FastDeploy/exit_code.env ]; then
source ./FastDeploy/exit_code.env
cat ./FastDeploy/exit_code.env >> $GITHUB_ENV
fi
echo "TEST_EXIT_CODE=${TEST_EXIT_CODE}"
exit ${TEST_EXIT_CODE}

View File

@@ -22,22 +22,12 @@ on:
description: "Enable nightly build mode (e.g. add date suffix to version)"
required: false
type: string
default: "OFF"
default: "ON"
FD_VERSION:
description: "FastDeploy Package Version"
required: false
type: string
default: ""
PADDLEVERSION:
description: "Paddle Version Build Use"
required: false
type: string
default: ""
PADDLE_WHL_URL:
description: "Paddle Wheel Package URL"
required: false
type: string
default: ""
UPLOAD:
description: "Upload Package"
required: false
@@ -54,8 +44,7 @@ on:
value: ${{ jobs.fd-build.outputs.wheel_path }}
jobs:
fd-build:
runs-on: [self-hosted, GPU-Build]
timeout-minutes: 360
runs-on: [self-hosted, GPU-h1z1-4Cards]
outputs:
wheel_path: ${{ steps.set_output.outputs.wheel_path }}
steps:
@@ -96,17 +85,13 @@ jobs:
compile_arch: ${{ inputs.COMPILE_ARCH }}
fd_version: ${{ inputs.FD_VERSION }}
CACHE_DIR: ${{ inputs.CACHE_DIR }}
BRANCH_REF: ${{ github.ref_name }}
PADDLEVERSION: ${{ inputs.PADDLEVERSION }}
PADDLE_WHL_URL: ${{ inputs.PADDLE_WHL_URL }}
WITH_NIGHTLY_BUILD: ${{ inputs.WITH_NIGHTLY_BUILD }}
run: |
set -x
runner_name="${{ runner.name }}"
CARD_ID=$(echo "${runner_name}" | awk -F'-' '{print $NF}')
CARD_ID=$(echo "${runner_name}" | cut -d'-' -f2)
gpu_id=$(echo "$CARD_ID" | fold -w1 | paste -sd,)
CACHE_DIR="${CACHE_DIR:-$(dirname "$(dirname "${{ github.workspace }}")")}"
CACHE_DIR=${CACHE_DIR:-${{ github.workspace }}}
echo "CACHE_DIR is set to ${CACHE_DIR}"
if [ ! -f "${CACHE_DIR}/gitconfig" ]; then
touch "${CACHE_DIR}/gitconfig"
@@ -118,15 +103,11 @@ jobs:
-v $(pwd):/workspace -w /workspace \
-v "${CACHE_DIR}/gitconfig:/etc/gitconfig:ro" \
-v "${CACHE_DIR}/.cache:/root/.cache" \
-v "${CACHE_DIR}/.ccache:/root/.ccache" \
-v "${CACHE_DIR}/ConfigDir:/root/.config" \
-e TZ="Asia/Shanghai" \
-e "COMPILE_ARCH=${compile_arch}" \
-e "FD_VERSION=${fd_version}" \
-e "WITH_NIGHTLY_BUILD=${WITH_NIGHTLY_BUILD}" \
-e "PADDLEVERSION=${PADDLEVERSION}" \
-e "PADDLE_WHL_URL=${PADDLE_WHL_URL}" \
-e "BRANCH_REF=${BRANCH_REF}" \
--gpus "\"device=${gpu_id}\"" ${docker_image} /bin/bash -c '
if [[ -n "${FD_VERSION}" ]]; then
export FASTDEPLOY_VERSION=${FD_VERSION}
@@ -134,7 +115,6 @@ jobs:
fi
git config --global --add safe.directory /workspace/FastDeploy
chown -R $(whoami) /workspace/FastDeploy
cd FastDeploy
if [[ "${WITH_NIGHTLY_BUILD}" == "ON" ]];then
GIT_COMMIT_TIME=$(git --no-pager show -s --format=%ci HEAD)
@@ -143,20 +123,14 @@ jobs:
echo "Date Only: $DATE_ONLY"
export FASTDEPLOY_VERSION="${FASTDEPLOY_VERSION}.dev${DATE_ONLY}"
fi
# 针对不同分支和tag使用不同的PaddlePaddle安装包
if [[ "${PADDLE_WHL_URL}" != "" ]];then
python -m pip install ${PADDLE_WHL_URL}
elif [[ "${PADDLEVERSION}" != "" ]];then
python -m pip install paddlepaddle-gpu==${PADDLEVERSION} -i https://www.paddlepaddle.org.cn/packages/stable/cu126/
else
python -m pip install --pre paddlepaddle-gpu -i https://www.paddlepaddle.org.cn/packages/nightly/cu126/
fi
pip config set global.index-url https://mirrors.tuna.tsinghua.edu.cn/pypi/web/simple
pip config set global.index-url http://pip.baidu.com/root/baidu/+simple/
pip config set install.trusted-host pip.baidu.com
pip config set global.extra-index-url https://mirrors.tuna.tsinghua.edu.cn/pypi/web/simple
python -m pip install --upgrade pip
python -m pip install -r requirements.txt
python -m pip install wheel
python -m pip install --pre paddlepaddle-gpu -i https://www.paddlepaddle.org.cn/packages/nightly/cu126/
# 编译RDMA
export ENABLE_FD_RDMA=1
bash build.sh 1 python false [${COMPILE_ARCH}]

View File

@@ -1,73 +0,0 @@
name: Docker Build
description: "FastDeploy CI Image Build"
on:
workflow_call:
inputs:
CI_DOCKER_IMAGE_NAME:
description: "Build Images"
required: true
type: string
default: "ccr-2vdh3abv-pub.cnc.bj.baidubce.com/paddlepaddle/paddleqa:cuda126-py310"
FASTDEPLOY_ARCHIVE_URL:
description: "URL of the compressed FastDeploy code archive."
required: true
type: string
DOCKER_IMAGE_NAME:
description: "Build Images"
required: false
type: string
default: "ccr-2vdh3abv-pub.cnc.bj.baidubce.com/paddlepaddle/paddleqa:fastdeploy-ciuse-cuda126-dailyupdate"
outputs:
docker_name_precheck:
description: "Output path of the generated wheel"
value: ${{ jobs.docker_build.outputs.docker_name_precheck }}
jobs:
docker_build:
runs-on: [self-hosted, Docker-Build]
outputs:
docker_name_precheck: ${{ steps.docker_build.outputs.docker_name_precheck }}
steps:
- name: Docker Build
id: docker_build
shell: bash
env:
docker_image_name: ${{ inputs.CI_DOCKER_IMAGE_NAME }}
docker_image: ${{ inputs.DOCKER_IMAGE_NAME }}
fd_archive_url: ${{ inputs.FASTDEPLOY_ARCHIVE_URL }}
run: |
set -x
REPO="https://github.com/${{ github.repository }}.git"
FULL_REPO="${{ github.repository }}"
REPO_NAME="${FULL_REPO##*/}"
BASE_BRANCH="${{ github.base_ref }}"
# Clean the repository directory before starting
docker run --rm --net=host -v $(pwd):/workspace -w /workspace \
-e "REPO_NAME=${REPO_NAME}" \
${docker_image} /bin/bash -c '
if [ -d ${REPO_NAME} ]; then
echo "Directory ${REPO_NAME} exists, removing it..."
rm -rf ${REPO_NAME}*
fi
'
wget -q ${fd_archive_url}
tar -xf FastDeploy.tar.gz
rm -rf FastDeploy.tar.gz
cd FastDeploy
git config --global user.name "FastDeployCI"
git config --global user.email "fastdeploy_ci@example.com"
git log -n 3 --oneline
# Docker Build
cd tools/dockerfile/
set -e
cp ../../requirements.txt ./
cp ../../scripts/unittest_requirement.txt ./
docker build -t ${docker_image_name} -f Dockerfile.ci . \
--network host \
--no-cache
docker push ${docker_image_name}
echo "docker_name_precheck=${docker_image_name}" >> $GITHUB_OUTPUT

View File

@@ -68,7 +68,7 @@ jobs:
branch_name=${{ github.ref_name }}
target_path=paddle-github-action/BRANCH/FastDeploy/${branch_name}/${commit_id}
fi
wget -O bos_tools.py -q --no-proxy --no-check-certificate https://paddle-qa.bj.bcebos.com/CodeSync/develop/PaddlePaddle/PaddleTest/tools/bos_tools.py
wget -q --no-proxy --no-check-certificate https://paddle-qa.bj.bcebos.com/CodeSync/develop/PaddlePaddle/PaddleTest/tools/bos_tools.py
push_file=$(realpath bos_tools.py)
python -m pip install bce-python-sdk==0.9.29
ls

View File

@@ -1,185 +0,0 @@
name: Run FastDeploy LogProb Tests
description: "Run FastDeploy LogProb Tests"
on:
workflow_call:
inputs:
DOCKER_IMAGE:
description: "Build Images"
required: true
type: string
default: "ccr-2vdh3abv-pub.cnc.bj.baidubce.com/paddlepaddle/paddleqa:cuda126-py310"
PADDLETEST_ARCHIVE_URL:
description: "URL of the compressed FastDeploy code archive."
required: true
type: string
default: "https://xly-devops.bj.bcebos.com/PaddleTest/PaddleTest.tar.gz"
FASTDEPLOY_WHEEL_URL:
description: "URL of the FastDeploy Wheel."
required: true
type: string
CACHE_DIR:
description: "Cache Dir Use"
required: false
type: string
default: ""
MODEL_CACHE_DIR:
description: "Cache Dir Use"
required: false
type: string
default: ""
jobs:
run_tests_logprob:
runs-on: [self-hosted, GPU-h20-1Cards]
steps:
- name: Code Prepare
shell: bash
env:
docker_image: ${{ inputs.DOCKER_IMAGE }}
paddletest_archive_url: ${{ inputs.PADDLETEST_ARCHIVE_URL }}
run: |
docker pull ${docker_image}
# Clean the repository directory before starting
docker run --rm --net=host -v $(pwd):/workspace -w /workspace \
-e "REPO_NAME=${REPO_NAME}" \
-e "BASE_BRANCH=${BASE_BRANCH}" \
${docker_image} /bin/bash -c '
rm -rf /workspace/*
'
wget -q ${paddletest_archive_url}
tar -xf PaddleTest.tar.gz
rm -rf PaddleTest.tar.gz
cd PaddleTest
git config --global user.name "FastDeployCI"
git config --global user.email "fastdeploy_ci@example.com"
git log -n 3 --oneline
- name: logprob test
shell: bash
env:
docker_image: ${{ inputs.DOCKER_IMAGE }}
fastdeploy_wheel_url: ${{ inputs.FASTDEPLOY_WHEEL_URL }}
CACHE_DIR: ${{ inputs.CACHE_DIR }}
MODEL_CACHE_DIR: ${{ inputs.MODEL_CACHE_DIR }}
run: |
runner_name="${{ runner.name }}"
CARD_ID=$(echo "${runner_name}" | awk -F'-' '{print $NF}')
DEVICES=$(echo "$CARD_ID" | fold -w1 | paste -sd,)
DEVICE_PORT=$(echo "$DEVICES" | cut -d',' -f1)
FLASK_PORT=$((42068 + DEVICE_PORT * 100))
FD_API_PORT=$((42088 + DEVICE_PORT * 100))
FD_ENGINE_QUEUE_PORT=$((42058 + DEVICE_PORT * 100))
FD_METRICS_PORT=$((42078 + DEVICE_PORT * 100))
FD_CACHE_QUEUE_PORT=$((42098 + DEVICE_PORT * 100))
echo "Test ENV Parameter:"
echo "========================================================="
echo "FLASK_PORT=${FLASK_PORT}"
echo "FD_API_PORT=${FD_API_PORT}"
echo "FD_ENGINE_QUEUE_PORT=${FD_ENGINE_QUEUE_PORT}"
echo "FD_METRICS_PORT=${FD_METRICS_PORT}"
echo "FD_CACHE_QUEUE_PORT=${FD_CACHE_QUEUE_PORT}"
echo "DEVICES=${DEVICES}"
echo "========================================================="
CACHE_DIR="${CACHE_DIR:-$(dirname "$(dirname "${{ github.workspace }}")")}"
echo "CACHE_DIR is set to ${CACHE_DIR}"
if [ ! -f "${CACHE_DIR}/gitconfig" ]; then
touch "${CACHE_DIR}/gitconfig"
fi
if [ ! -d "${MODEL_CACHE_DIR}" ]; then
echo "Error: MODEL_CACHE_DIR '${MODEL_CACHE_DIR}' does not exist."
exit 1
fi
PORTS=($FLASK_PORT $FD_API_PORT $FD_ENGINE_QUEUE_PORT $FD_METRICS_PORT $FD_CACHE_QUEUE_PORT)
LOG_FILE="./port_cleanup_$(date +%Y%m%d_%H%M%S).log"
echo "==== LOG_FILE is ${LOG_FILE} ===="
echo "==== PORT CLEAN BEFORE TASK RUN ====" | tee -a $LOG_FILE
for port in "${PORTS[@]}"; do
PIDS=$(lsof -t -i :$port || true)
if [ -n "$PIDS" ]; then
echo "Port $port is occupied by PID(s): $PIDS" | tee -a $LOG_FILE
echo "$PIDS" | xargs -r kill -9
echo "Port $port cleared" | tee -a $LOG_FILE
else
echo "Port $port is free" | tee -a $LOG_FILE
fi
done
echo "==== PORT CLEAN COMPLETE ====" | tee -a $LOG_FILE
echo "========================================================="
echo "Ensuring no stale container named ${runner_name} ..."
if [ "$(docker ps -a -q -f name=${runner_name})" ]; then
echo "Removing stale container: ${runner_name}"
docker rm -f ${runner_name} || true
fi
docker run --rm --ipc=host --pid=host --net=host \
--name ${runner_name} \
-v $(pwd):/workspace \
-w /workspace \
-e fastdeploy_wheel_url=${fastdeploy_wheel_url} \
-e "FD_API_PORT=${FD_API_PORT}" \
-e "FD_ENGINE_QUEUE_PORT=${FD_ENGINE_QUEUE_PORT}" \
-e "FD_METRICS_PORT=${FD_METRICS_PORT}" \
-e "FD_CACHE_QUEUE_PORT=${FD_CACHE_QUEUE_PORT}" \
-e "FLASK_PORT=${FLASK_PORT}" \
-v "${MODEL_CACHE_DIR}:/MODELDATA" \
-v "${CACHE_DIR}/gitconfig:/etc/gitconfig:ro" \
-v "${CACHE_DIR}/.cache:/root/.cache" \
-v "${CACHE_DIR}/ConfigDir:/root/.config" \
-e TZ="Asia/Shanghai" \
--gpus '"device='"${DEVICES}"'"' ${docker_image} /bin/bash -xc '
python -m pip install --pre paddlepaddle-gpu -i https://www.paddlepaddle.org.cn/packages/nightly/cu126/
pip config set global.index-url https://mirrors.tuna.tsinghua.edu.cn/pypi/web/simple
python -m pip install ${fastdeploy_wheel_url}
wget https://paddle-qa.bj.bcebos.com/zhengtianyu/tools/llm-deploy-linux-amd64
chmod +x ./llm-deploy-linux-amd64
./llm-deploy-linux-amd64 -python python3.10 \
-model_name ERNIE-4.5-0.3B-Paddle \
-model_path /MODELDATA \
--skip install
cd PaddleTest/framework/ServeTest
ps -ef | grep "${FD_ENGINE_QUEUE_PORT}" | grep -v grep | awk "{print \$2}" | xargs -r kill -9
python3.10 deploy.py > dd.log 2>&1 &
sleep 3
curl -X POST http://0.0.0.0:${FLASK_PORT}/start \
-H "Content-Type: application/json" \
-d "{\"--model\": \"/MODELDATA/ERNIE-4.5-0.3B-Paddle\"}"
curl -X POST http://localhost:${FLASK_PORT}/wait_for_infer?timeout=90
curl -s -o /dev/null -w "%{http_code}" -m 2 "http://0.0.0.0:${FD_API_PORT}/health"
curl -X POST "http://0.0.0.0:${FD_API_PORT}/v1/chat/completions" \
-H "Content-Type: application/json" \
-d "{\"messages\": [{\"role\": \"user\", \"content\": \"1+1=?\"}], \"logprobs\": true}"
set +e
rm -rf ./baseline_output
cp -r baseline/ERNIE-4.5-0.3B-Paddle ./baseline_output
LOGPROB_EXIT_CODE=0
python3.10 lanucher.py --request_template TOKEN_LOGPROB --url http://localhost:${FD_API_PORT}/v1/chat/completions --case ./cases/demo.yaml --concurrency 1 --name demo --exe logprob || LOGPROB_EXIT_CODE=$?
echo "LOGPROB_EXIT_CODE=${LOGPROB_EXIT_CODE}" > /workspace/exit_code.env
curl -X POST http://localhost:${FLASK_PORT}/stop
sleep 10s
cat *result.log
exit 0
'
if [ $? -ne 0 ];then
exit 1
fi
if [ -f exit_code.env ]; then
cat exit_code.env >> $GITHUB_ENV
fi
- name: logprob test result
if: ${{ env.LOGPROB_EXIT_CODE != 0 }}
shell: bash
run: |
echo "logprob test failed with exit code ${{ env.LOGPROB_EXIT_CODE }}"
exit 8

View File

@@ -1,148 +0,0 @@
name: Pre-CE-Test
on:
workflow_call:
inputs:
DOCKER_IMAGE:
description: "Build Images"
required: true
type: string
default: "ccr-2vdh3abv-pub.cnc.bj.baidubce.com/paddlepaddle/paddle:fastdeploy-ciuse-cuda126"
FASTDEPLOY_ARCHIVE_URL:
description: "URL of the compressed FastDeploy code archive."
required: true
type: string
FASTDEPLOY_WHEEL_URL:
description: "URL of the FastDeploy Wheel."
required: true
type: string
CACHE_DIR:
description: "Cache Dir Use"
required: false
type: string
default: ""
MODEL_CACHE_DIR:
description: "Cache Dir Use"
required: false
type: string
default: ""
jobs:
run_ce_cases:
runs-on: [self-hosted, PRE_CE_RUN_2Card]
timeout-minutes: 60
steps:
- name: Print current runner name
run: |
echo "Current runner name: ${{ runner.name }}"
- name: Code Prepare
shell: bash
env:
docker_image: ${{ inputs.DOCKER_IMAGE }}
fd_archive_url: ${{ inputs.FASTDEPLOY_ARCHIVE_URL }}
run: |
set -x
REPO="https://github.com/${{ github.repository }}.git"
FULL_REPO="${{ github.repository }}"
REPO_NAME="${FULL_REPO##*/}"
BASE_BRANCH="${{ github.base_ref }}"
docker pull ${docker_image}
# Clean the repository directory before starting
docker run --rm --net=host -v $(pwd):/workspace -w /workspace \
-e "REPO_NAME=${REPO_NAME}" \
${docker_image} /bin/bash -c '
if [ -d ${REPO_NAME} ]; then
echo "Directory ${REPO_NAME} exists, removing it..."
rm -rf ${REPO_NAME}*
fi
'
wget -q ${fd_archive_url}
tar -xf FastDeploy.tar.gz
rm -rf FastDeploy.tar.gz
cd FastDeploy
git config --global user.name "FastDeployCI"
git config --global user.email "fastdeploy_ci@example.com"
git log -n 3 --oneline
- name: Run CI unittest
env:
docker_image: ${{ inputs.DOCKER_IMAGE }}
fd_wheel_url: ${{ inputs.FASTDEPLOY_WHEEL_URL }}
CACHE_DIR: ${{ inputs.CACHE_DIR }}
MODEL_CACHE_DIR: ${{ inputs.MODEL_CACHE_DIR }}
run: |
runner_name="${{ runner.name }}"
CARD_ID=$(echo "${runner_name}" | awk -F'-' '{print $NF}')
DEVICES=$(echo "$CARD_ID" | fold -w1 | paste -sd,)
DEVICE_PORT=$(echo "$DEVICES" | cut -d',' -f1)
FLASK_PORT=$((42068 + DEVICE_PORT * 100))
FD_API_PORT=$((42088 + DEVICE_PORT * 100))
FD_ENGINE_QUEUE_PORT=$((42058 + DEVICE_PORT * 100))
FD_METRICS_PORT=$((42078 + DEVICE_PORT * 100))
FD_CACHE_QUEUE_PORT=$((42098 + DEVICE_PORT * 100))
echo "Test ENV Parameter:"
echo "========================================================="
echo "FLASK_PORT=${FLASK_PORT}"
echo "FD_API_PORT=${FD_API_PORT}"
echo "FD_ENGINE_QUEUE_PORT=${FD_ENGINE_QUEUE_PORT}"
echo "FD_METRICS_PORT=${FD_METRICS_PORT}"
echo "FD_CACHE_QUEUE_PORT=${FD_CACHE_QUEUE_PORT}"
echo "DEVICES=${DEVICES}"
echo "========================================================="
CACHE_DIR="${CACHE_DIR:-$(dirname "$(dirname "${{ github.workspace }}")")}"
echo "CACHE_DIR is set to ${CACHE_DIR}"
if [ ! -f "${CACHE_DIR}/gitconfig" ]; then
touch "${CACHE_DIR}/gitconfig"
fi
PORTS=($FLASK_PORT $FD_API_PORT $FD_ENGINE_QUEUE_PORT $FD_METRICS_PORT $FD_CACHE_QUEUE_PORT)
LOG_FILE="./port_cleanup_$(date +%Y%m%d_%H%M%S).log"
echo "==== LOG_FILE is ${LOG_FILE} ===="
echo "==== PORT CLEAN BEFORE TASK RUN ====" | tee -a $LOG_FILE
for port in "${PORTS[@]}"; do
PIDS=$(lsof -t -i :$port || true)
if [ -n "$PIDS" ]; then
echo "Port $port is occupied by PID(s): $PIDS" | tee -a $LOG_FILE
echo "$PIDS" | xargs -r kill -9
echo "Port $port cleared" | tee -a $LOG_FILE
else
echo "Port $port is free" | tee -a $LOG_FILE
fi
done
echo "==== PORT CLEAN COMPLETE ====" | tee -a $LOG_FILE
echo "========================================================="
echo "Ensuring no stale container named ${runner_name} ..."
if [ "$(docker ps -a -q -f name=${runner_name})" ]; then
echo "Removing stale container: ${runner_name}"
docker rm -f ${runner_name} || true
fi
docker run --rm --net=host \
--name ${runner_name} \
-v $(pwd):/workspace \
-w /workspace \
-v "${CACHE_DIR}/gitconfig:/etc/gitconfig:ro" \
-v "${CACHE_DIR}/.cache:/root/.cache" \
-v "${CACHE_DIR}/ConfigDir:/root/.config" \
-v "${MODEL_CACHE_DIR}:/ModelData:ro" \
-e "MODEL_PATH=/ModelData" \
-e "FD_API_PORT=${FD_API_PORT}" \
-e "FD_ENGINE_QUEUE_PORT=${FD_ENGINE_QUEUE_PORT}" \
-e "FD_METRICS_PORT=${FD_METRICS_PORT}" \
-e "FD_CACHE_QUEUE_PORT=${FD_CACHE_QUEUE_PORT}" \
-e "FLASK_PORT=${FLASK_PORT}" \
-e "fd_wheel_url=${fd_wheel_url}" \
--gpus "\"device=${DEVICES}\"" ${docker_image} /bin/bash -c '
git config --global --add safe.directory /workspace/FastDeploy
cd FastDeploy
python -m pip install --pre paddlepaddle-gpu -i https://www.paddlepaddle.org.cn/packages/nightly/cu126/
python -m pip install ${fd_wheel_url}
bash scripts/run_pre_ce.sh
'

View File

@@ -1,170 +0,0 @@
name: Stable Test
description: "Run Stable Tests"
on:
workflow_call:
inputs:
DOCKER_IMAGE:
description: "Build Images"
required: true
type: string
default: "ccr-2vdh3abv-pub.cnc.bj.baidubce.com/paddlepaddle/paddleqa:cuda126-py310"
FASTDEPLOY_ARCHIVE_URL:
description: "URL of the compressed FastDeploy code archive."
required: true
type: string
FASTDEPLOY_WHEEL_URL:
description: "URL of the FastDeploy Wheel."
required: true
type: string
CACHE_DIR:
description: "Cache Dir Use"
required: false
type: string
default: ""
MODEL_CACHE_DIR:
description: "Cache Dir Use"
required: false
type: string
default: ""
jobs:
stable_tests:
runs-on: [self-hosted, GPU-h1z1-2Cards]
timeout-minutes: 60
steps:
- name: Code Prepare
shell: bash
env:
docker_image: ${{ inputs.DOCKER_IMAGE }}
fd_archive_url: ${{ inputs.FASTDEPLOY_ARCHIVE_URL }}
run: |
set -x
REPO="https://github.com/${{ github.repository }}.git"
FULL_REPO="${{ github.repository }}"
REPO_NAME="${FULL_REPO##*/}"
BASE_BRANCH="${{ github.base_ref }}"
docker pull ${docker_image}
# Clean the repository directory before starting
docker run --rm --net=host -v $(pwd):/workspace -w /workspace \
-e "REPO_NAME=${REPO_NAME}" \
${docker_image} /bin/bash -c '
if [ -d ${REPO_NAME} ]; then
echo "Directory ${REPO_NAME} exists, removing it..."
rm -rf ${REPO_NAME}*
fi
'
wget -q ${fd_archive_url}
tar -xf FastDeploy.tar.gz
rm -rf FastDeploy.tar.gz
cd FastDeploy
git config --global user.name "FastDeployCI"
git config --global user.email "fastdeploy_ci@example.com"
git log -n 3 --oneline
- name: Run FastDeploy Stable Tests
shell: bash
env:
docker_image: ${{ inputs.DOCKER_IMAGE }}
fastdeploy_wheel_url: ${{ inputs.FASTDEPLOY_WHEEL_URL }}
CACHE_DIR: ${{ inputs.CACHE_DIR }}
MODEL_CACHE_DIR: ${{ inputs.MODEL_CACHE_DIR }}
run: |
runner_name="${{ runner.name }}"
CARD_ID=$(echo "${runner_name}" | awk -F'-' '{print $NF}')
DEVICES=$(echo "$CARD_ID" | fold -w1 | paste -sd,)
DEVICE_PORT=$(echo "$DEVICES" | cut -d',' -f1)
FLASK_PORT=$((42068 + DEVICE_PORT * 100))
FD_API_PORT=$((42088 + DEVICE_PORT * 100))
FD_ENGINE_QUEUE_PORT=$((42058 + DEVICE_PORT * 100))
FD_METRICS_PORT=$((42078 + DEVICE_PORT * 100))
FD_CACHE_QUEUE_PORT=$((42038 + DEVICE_PORT * 100))
FD_INFERENCE_MSG_QUEUE_ID=$(( 42048 + DEVICE_PORT * 100))
echo "Test ENV Parameter:"
echo "========================================================="
echo "FLASK_PORT=${FLASK_PORT}"
echo "FD_API_PORT=${FD_API_PORT}"
echo "FD_ENGINE_QUEUE_PORT=${FD_ENGINE_QUEUE_PORT}"
echo "FD_METRICS_PORT=${FD_METRICS_PORT}"
echo "FD_INFERENCE_MSG_QUEUE_ID=${FD_INFERENCE_MSG_QUEUE_ID}"
echo "FD_CACHE_QUEUE_PORT=${FD_CACHE_QUEUE_PORT}"
echo "DEVICES=${DEVICES}"
echo "========================================================="
CACHE_DIR="${CACHE_DIR:-$(dirname "$(dirname "${{ github.workspace }}")")}"
echo "CACHE_DIR is set to ${CACHE_DIR}"
if [ ! -f "${CACHE_DIR}/gitconfig" ]; then
touch "${CACHE_DIR}/gitconfig"
fi
if [ ! -d "${MODEL_CACHE_DIR}" ]; then
echo "Error: MODEL_CACHE_DIR '${MODEL_CACHE_DIR}' does not exist."
exit 1
fi
PORTS=($FLASK_PORT $FD_API_PORT $FD_ENGINE_QUEUE_PORT $FD_METRICS_PORT)
LOG_FILE="./port_cleanup_$(date +%Y%m%d_%H%M%S).log"
echo "==== LOG_FILE is ${LOG_FILE} ===="
echo "==== PORT CLEAN BEFORE TASK RUN ====" | tee -a $LOG_FILE
for port in "${PORTS[@]}"; do
PIDS=$(lsof -t -i :$port || true)
if [ -n "$PIDS" ]; then
echo "Port $port is occupied by PID(s): $PIDS" | tee -a $LOG_FILE
echo "$PIDS" | xargs -r kill -9
echo "Port $port cleared" | tee -a $LOG_FILE
else
echo "Port $port is free" | tee -a $LOG_FILE
fi
done
echo "==== PORT CLEAN COMPLETE ====" | tee -a $LOG_FILE
echo "========================================================="
echo "Ensuring no stale container named ${runner_name} ..."
if [ "$(docker ps -a -q -f name=${runner_name})" ]; then
echo "Removing stale container: ${runner_name}"
docker rm -f ${runner_name} || true
fi
docker run --rm --ipc=host --pid=host --net=host \
--name ${runner_name} \
-v $(pwd):/workspace \
-w /workspace \
-e fastdeploy_wheel_url=${fastdeploy_wheel_url} \
-e "FD_API_PORT=${FD_API_PORT}" \
-e "FD_ENGINE_QUEUE_PORT=${FD_ENGINE_QUEUE_PORT}" \
-e "FD_METRICS_PORT=${FD_METRICS_PORT}" \
-e "FLASK_PORT=${FLASK_PORT}" \
-e "FD_INFERENCE_MSG_QUEUE_ID=${FD_INFERENCE_MSG_QUEUE_ID}" \
-e "FD_CACHE_QUEUE_PORT=${FD_CACHE_QUEUE_PORT}" \
-v "${MODEL_CACHE_DIR}:/MODELDATA" \
-v "${CACHE_DIR}/gitconfig:/etc/gitconfig:ro" \
-v "${CACHE_DIR}/.cache:/root/.cache" \
-v "${CACHE_DIR}/ConfigDir:/root/.config" \
-e TZ="Asia/Shanghai" \
--gpus '"device='"${DEVICES}"'"' ${docker_image} /bin/bash -xc '
python -m pip install --pre paddlepaddle-gpu -i https://www.paddlepaddle.org.cn/packages/nightly/cu126/
pip config set global.index-url https://mirrors.tuna.tsinghua.edu.cn/pypi/web/simple
python -m pip install ${fastdeploy_wheel_url}
python -m pip install pytest
git config --global --add safe.directory /workspace/FastDeploy
cd FastDeploy
TEST_EXIT_CODE=0
pushd tests/ce/stable_cases
bash launch_model.sh /MODELDATA
bash run.sh || TEST_EXIT_CODE=1
popd
echo "TEST_EXIT_CODE=${TEST_EXIT_CODE}" >> /workspace/FastDeploy/exit_code.env
'
if [ -f ./FastDeploy/exit_code.env ]; then
source ./FastDeploy/exit_code.env
cat ./FastDeploy/exit_code.env >> $GITHUB_ENV
fi
echo "TEST_EXIT_CODE=${TEST_EXIT_CODE}"
exit ${TEST_EXIT_CODE}

View File

@@ -1,319 +0,0 @@
name: Coverage Check
description: "Run FastDeploy Unit Tests and Coverage"
on:
workflow_call:
inputs:
DOCKER_IMAGE:
description: "Build Images"
required: true
type: string
default: "ccr-2vdh3abv-pub.cnc.bj.baidubce.com/paddlepaddle/paddleqa:cuda126-py310"
FASTDEPLOY_ARCHIVE_URL:
description: "URL of the compressed FastDeploy code archive."
required: true
type: string
FASTDEPLOY_WHEEL_URL:
description: "URL of the FastDeploy Wheel."
required: true
type: string
CACHE_DIR:
description: "Cache Dir Use"
required: false
type: string
default: ""
MODEL_CACHE_DIR:
description: "Cache Dir Use"
required: false
type: string
default: ""
secrets:
github-token:
required: true
jobs:
check_cov_skip:
uses: ./.github/workflows/check-bypass.yml
secrets:
github-token: ${{ secrets.github-token }}
with:
workflow-name: coverage
run_tests_with_coverage:
runs-on: [self-hosted, GPU-h1z1-2Cards]
timeout-minutes: 90
needs: check_cov_skip
if: needs.check_cov_skip.outputs.can-skip != 'true'
outputs:
diff_cov_file_url: ${{ steps.cov_upload.outputs.diff_cov_file_url }}
unittest_failed_url: ${{ steps.cov_upload.outputs.unittest_failed_url }}
diff_cov_result_json_url: ${{ steps.cov_upload.outputs.diff_cov_result_json_url }}
steps:
- name: Code Prepare
shell: bash
env:
docker_image: ${{ inputs.DOCKER_IMAGE }}
fd_archive_url: ${{ inputs.FASTDEPLOY_ARCHIVE_URL }}
run: |
set -x
REPO="https://github.com/${{ github.repository }}.git"
FULL_REPO="${{ github.repository }}"
REPO_NAME="${FULL_REPO##*/}"
BASE_BRANCH="${{ github.base_ref }}"
docker pull ${docker_image}
# Clean the repository directory before starting
docker run --rm --net=host -v $(pwd):/workspace -w /workspace \
-e "REPO_NAME=${REPO_NAME}" \
${docker_image} /bin/bash -c '
if [ -d ${REPO_NAME} ]; then
echo "Directory ${REPO_NAME} exists, removing it..."
rm -rf ${REPO_NAME}*
fi
'
wget -q ${fd_archive_url}
tar -xf FastDeploy.tar.gz
rm -rf FastDeploy.tar.gz
cd FastDeploy
git config --global user.name "FastDeployCI"
git config --global user.email "fastdeploy_ci@example.com"
git log -n 3 --oneline
- name: Run FastDeploy Unit Tests and Coverage
shell: bash
env:
docker_image: ${{ inputs.DOCKER_IMAGE }}
fd_wheel_url: ${{ inputs.FASTDEPLOY_WHEEL_URL }}
CACHE_DIR: ${{ inputs.CACHE_DIR }}
BASE_REF: ${{ github.event.pull_request.base.ref }}
MODEL_CACHE_DIR: ${{ inputs.MODEL_CACHE_DIR }}
IS_PR: ${{ github.event_name == 'pull_request' }}
run: |
if [[ "$IS_PR" == "true" ]]; then
echo "Running on PR"
else
echo "Not a PR"
fi
runner_name="${{ runner.name }}"
CARD_ID=$(echo "${runner_name}" | awk -F'-' '{print $NF}')
DEVICES=$(echo "$CARD_ID" | fold -w1 | paste -sd,)
DEVICE_PORT=$(echo "$DEVICES" | cut -d',' -f1)
FLASK_PORT=$((42068 + DEVICE_PORT * 100))
FD_API_PORT=$((42088 + DEVICE_PORT * 100))
FD_ENGINE_QUEUE_PORT=$((42058 + DEVICE_PORT * 100))
FD_METRICS_PORT=$((42078 + DEVICE_PORT * 100))
FD_CACHE_QUEUE_PORT=$((42098 + DEVICE_PORT * 100))
echo "Test ENV Parameter:"
echo "========================================================="
echo "FLASK_PORT=${FLASK_PORT}"
echo "FD_API_PORT=${FD_API_PORT}"
echo "FD_ENGINE_QUEUE_PORT=${FD_ENGINE_QUEUE_PORT}"
echo "FD_METRICS_PORT=${FD_METRICS_PORT}"
echo "FD_CACHE_QUEUE_PORT=${FD_CACHE_QUEUE_PORT}"
echo "DEVICES=${DEVICES}"
echo "========================================================="
CACHE_DIR="${CACHE_DIR:-$(dirname "$(dirname "${{ github.workspace }}")")}"
echo "CACHE_DIR is set to ${CACHE_DIR}"
if [ ! -f "${CACHE_DIR}/gitconfig" ]; then
touch "${CACHE_DIR}/gitconfig"
fi
PORTS=($FLASK_PORT $FD_API_PORT $FD_ENGINE_QUEUE_PORT $FD_METRICS_PORT $FD_CACHE_QUEUE_PORT)
LOG_FILE="./port_cleanup_$(date +%Y%m%d_%H%M%S).log"
echo "==== LOG_FILE is ${LOG_FILE} ===="
echo "==== PORT CLEAN BEFORE TASK RUN ====" | tee -a $LOG_FILE
for port in "${PORTS[@]}"; do
PIDS=$(lsof -t -i :$port || true)
if [ -n "$PIDS" ]; then
echo "Port $port is occupied by PID(s): $PIDS" | tee -a $LOG_FILE
echo "$PIDS" | xargs -r kill -9
echo "Port $port cleared" | tee -a $LOG_FILE
else
echo "Port $port is free" | tee -a $LOG_FILE
fi
done
echo "==== PORT CLEAN COMPLETE ====" | tee -a $LOG_FILE
echo "========================================================="
echo "Ensuring no stale container named ${runner_name} ..."
if [ "$(docker ps -a -q -f name=${runner_name})" ]; then
echo "Removing stale container: ${runner_name}"
docker rm -f ${runner_name} || true
fi
docker run --rm --net=host \
--name ${runner_name} \
--cap-add=SYS_PTRACE --shm-size=64G \
-v $(pwd):/workspace -w /workspace \
-v "${CACHE_DIR}/gitconfig:/etc/gitconfig:ro" \
-v "${CACHE_DIR}/.cache:/root/.cache" \
-v "${CACHE_DIR}/ConfigDir:/root/.config" \
-v "${MODEL_CACHE_DIR}:/ModelData:ro" \
-e "MODEL_PATH=/ModelData" \
-e "FD_API_PORT=${FD_API_PORT}" \
-e "FD_ENGINE_QUEUE_PORT=${FD_ENGINE_QUEUE_PORT}" \
-e "FD_METRICS_PORT=${FD_METRICS_PORT}" \
-e "FLASK_PORT=${FLASK_PORT}" \
-e "FD_CACHE_QUEUE_PORT=${FD_CACHE_QUEUE_PORT}" \
-e TZ="Asia/Shanghai" \
-e "fd_wheel_url=${fd_wheel_url}" \
-e "BASE_REF=${BASE_REF}" \
-e "IS_PR=${IS_PR}" \
--gpus "\"device=${DEVICES}\"" ${docker_image} /bin/bash -c '
git config --global --add safe.directory /workspace/FastDeploy
cd FastDeploy
git diff origin/${BASE_REF}..HEAD --unified=0 > diff.txt
python -m pip install --pre paddlepaddle-gpu -i https://www.paddlepaddle.org.cn/packages/nightly/cu126/
pip config set global.extra-index-url https://mirrors.tuna.tsinghua.edu.cn/pypi/web/simple
python -m pip install -r scripts/unittest_requirement.txt
python -m pip install ${fd_wheel_url}
rm -rf fastdeploy
# coverage subprocess use
python -m pip install ${fd_wheel_url} --no-deps --target=/workspace/FastDeploy
export PYTHONPATH=/workspace/FastDeploy/
if [ -d "tests/plugins" ]; then
cd tests/plugins
python setup.py install
cd ../..
else
echo "Warning: tests/plugins directory not found, skipping setup.py install"
fi
export COVERAGE_FILE=/workspace/FastDeploy/coveragedata/.coverage
export COVERAGE_RCFILE=/workspace/FastDeploy/scripts/.coveragerc
TEST_EXIT_CODE=0
bash scripts/coverage_run.sh || TEST_EXIT_CODE=8
echo "TEST_EXIT_CODE=${TEST_EXIT_CODE}" >> exit_code.env
coverage combine coveragedata/ || echo "No data to combine"
coverage report
coverage xml -o python_coverage_all.xml
COVERAGE_EXIT_CODE=0
if [[ "$IS_PR" == "true" ]]; then
diff-cover python_coverage_all.xml --diff-file=diff.txt --fail-under=80 --json-report diff_coverage.json || COVERAGE_EXIT_CODE=9
python scripts/generate_diff_coverage_xml.py diff.txt python_coverage_all.xml
else
echo "Not a PR, skipping diff-cover"
fi
echo "COVERAGE_EXIT_CODE=${COVERAGE_EXIT_CODE}" >> exit_code.env
'
if [ -f FastDeploy/exit_code.env ]; then
cat FastDeploy/exit_code.env >> $GITHUB_ENV
fi
- name: Upload unit resule and diff coverage to bos
id: cov_upload
shell: bash
run: |
cd FastDeploy
commit_id=${{ github.event.pull_request.head.sha }}
pr_num=${{ github.event.pull_request.number }}
target_path=paddle-github-action/PR/FastDeploy/${pr_num}/${commit_id}/SM${compile_arch//,/_}
wget -q --no-proxy --no-check-certificate https://paddle-qa.bj.bcebos.com/CodeSync/develop/PaddlePaddle/PaddleTest/tools/bos_tools.py -O bos_tools.py
push_file=$(realpath bos_tools.py)
python -m pip install bce-python-sdk==0.9.29
diff_cov_file="diff_coverage.xml"
if [ -f ${diff_cov_file} ];then
python ${push_file} ${diff_cov_file} ${target_path}/CoverageData
target_path_stripped="${target_path#paddle-github-action/}"
DIFF_COV_FILE_URL=https://paddle-github-action.bj.bcebos.com/${target_path_stripped}/CoverageData/${diff_cov_file}
echo "diff_cov_file_url=${DIFF_COV_FILE_URL}" >> $GITHUB_OUTPUT
echo "diff_cov_file_url=${DIFF_COV_FILE_URL}" >> $GITHUB_ENV
fi
diff_cov_result_json="diff_coverage.json"
if [ -f ${diff_cov_result_json} ];then
python ${push_file} ${diff_cov_result_json} ${target_path}/CoverageData
target_path_stripped="${target_path#paddle-github-action/}"
DIFF_COV_JSON_URL=https://paddle-github-action.bj.bcebos.com/${target_path_stripped}/CoverageData/${diff_cov_result_json}
echo "diff_cov_result_json_url=${DIFF_COV_JSON_URL}" >> $GITHUB_OUTPUT
echo "diff_cov_result_json_url=${DIFF_COV_JSON_URL}" >> $GITHUB_ENV
fi
unittest_result="failed_tests.log"
if [ -s ${unittest_result} ];then
python ${push_file} ${unittest_result} ${target_path}/UnitTestResult
target_path_stripped="${target_path#paddle-github-action/}"
UNIT_TEST_RESULT_URL=https://paddle-github-action.bj.bcebos.com/${target_path_stripped}/UnitTestResult/${unittest_result}
echo "unittest_failed_url=${UNIT_TEST_RESULT_URL}" >> $GITHUB_OUTPUT
echo "unittest_failed_url=${UNIT_TEST_RESULT_URL}" >> $GITHUB_ENV
fi
- name: Check Unit Test Success
shell: bash
run: |
cd FastDeploy
if [ "$TEST_EXIT_CODE" -eq 8 ]; then
filename=$(basename "$unittest_failed_url")
if [ -z "${unittest_failed_url}" ]; then
echo "No diff unit failed file URL provided."
else
rm -rf "${filename}"
wget -O ${filename} ${unittest_failed_url} || echo "Download unittest file failed, but continuing..."
fi
echo "Unit tests failed (exit code 8)"
if [ -f "${filename}" ];then
echo "Failed test cases:"
cat "${filename}"
fi
exit "$TEST_EXIT_CODE"
fi
echo "All tests passed"
- name: Verify Code Coverage Threshold (80%)
if: ${{ github.event_name == 'pull_request' }}
shell: bash
run: |
cd FastDeploy
if [ "$COVERAGE_EXIT_CODE" -eq 9 ]; then
echo "Coverage generation failed (exit code 9)"
filename=$(basename "$diff_cov_result_json_url")
if [ -z "${diff_cov_result_json_url}" ]; then
echo "No diff cov result file URL provided."
else
rm -rf "${filename}"
wget -O ${filename} ${diff_cov_result_json_url} || echo "Download cov json file failed, but continuing..."
fi
if [ -f "${filename}" ];then
echo "Failed test cases:"
if command -v jq >/dev/null 2>&1; then
jq . "${filename}"
else
cat "${filename}"
fi
fi
exit "$COVERAGE_EXIT_CODE"
fi
echo "coverage passed"
exit 0
diff_coverage_report:
needs: run_tests_with_coverage
if: always()
runs-on: ubuntu-latest
env:
fd_archive_url: ${{ inputs.FASTDEPLOY_ARCHIVE_URL }}
steps:
- name: coverage diff file download
shell: bash
env:
diff_cov_file_url: ${{ needs.run_tests_with_coverage.outputs.diff_cov_file_url }}
run: |
wget ${fd_archive_url}
tar -xf FastDeploy.tar.gz
cd FastDeploy
if [ -z "${diff_cov_file_url}" ]; then
echo "No diff coverage file URL provided."
exit 0
fi
wget "${diff_cov_file_url}" -O ./diff_coverage.xml || echo "Download cov file failed, but continuing..."
- name: Upload diff coverage report
if: ${{ needs.run_tests_with_coverage.outputs.diff_cov_file_url != null && needs.run_tests_with_coverage.outputs.diff_cov_file_url != '' }}
uses: codecov/codecov-action@v5
with:
files: ./FastDeploy/diff_coverage.xml
name: python diff coverage
verbose: true
disable_search: true
commit_parent: false
flags: diff

View File

@@ -1,42 +0,0 @@
name: Approval
on:
pull_request:
branches:
- develop
- 'release/*'
env:
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
jobs:
Approval:
name: Approval
if: ${{ github.repository_owner == 'PaddlePaddle' }}
runs-on: ubuntu-latest
env:
PR_ID: ${{ github.event.pull_request.number }}
BRANCH: ${{ github.event.pull_request.base.ref }}
steps:
- name: Checkout base repo
uses: actions/checkout@v4
with:
ref: ${{ github.event.pull_request.base.ref }}
fetch-depth: 1000
- name: Merge PR to test branch
run: |
git fetch origin pull/${PR_ID}/merge
git checkout -b test FETCH_HEAD
git log -n 3 --oneline
git remote add upstream https://github.com/PaddlePaddle/FastDeploy.git
git fetch upstream $BRANCH
- name: Setup python3.10
uses: actions/setup-python@v5
with:
python-version: '3.10'
- name: Run approval check script
run: |
bash scripts/check_approval.sh

View File

@@ -1,248 +0,0 @@
name: CE Compile Job
on:
workflow_dispatch:
push:
branches:
- develop
- 'release/*'
permissions: read-all
concurrency:
group: CE-Job-${{ github.ref }}-${{ github.sha }}
cancel-in-progress: true
jobs:
ce_job_pre_check:
runs-on: ubuntu-latest
env:
COMPILE_BRANCH: ${{ vars.COMPILE_BRANCH }}
CE_COMPILE_SELECTION: ${{ vars.CE_COMPILE_SELECTION }}
COMPILE_USE_PADDLE_WHL_URL_MAPPINGS: ${{ vars.COMPILE_USE_PADDLE_WHL_URL_MAPPINGS }}
outputs:
branch_match: ${{ steps.set_output.outputs.branch_match }}
compile_use_paddle_whl_url: ${{ steps.set_output.outputs.compile_use_paddle_whl_url }}
sm8689_match: ${{ steps.set_output.outputs.sm8689_match }}
sm8090_match: ${{ steps.set_output.outputs.sm8090_match }}
steps:
- name: Set Version
id: set_output
env:
COMPILE_BRANCH: ${{ env.COMPILE_BRANCH }}
CE_COMPILE_SELECTION: ${{ env.CE_COMPILE_SELECTION }}
COMPILE_USE_PADDLE_WHL_URL_MAPPINGS: ${{ env.COMPILE_USE_PADDLE_WHL_URL_MAPPINGS }}
GITHUB_REF_NAME: ${{ github.ref_name }}
run: |
# 选择要触发编译任务的分支 done
# 选择指定分支要编译的任务 8090或者8689
# 指定分支编译要使用的Paddle的安装包,默认使用nightly最新的
IFS=',' read -ra BRANCHES <<< "$COMPILE_BRANCH"
MATCH=false
for b in "${BRANCHES[@]}"; do
if [[ "$b" == "${GITHUB_REF_NAME}" ]]; then
MATCH=true
break
fi
done
echo "branch_match=$MATCH" >> $GITHUB_OUTPUT
# 通过变量CE_COMPILE_SELECTION中的映射关系,决定分支是编译sm8090还是sm8689
for pair in $(echo "$CE_COMPILE_SELECTION" | tr ';' ' '); do
branch=$(echo "$pair" | cut -d',' -f1)
compile_task_list=$(echo "$pair" | cut -d',' -f2)
if [[ "$branch" == "$GITHUB_REF_NAME" ]]; then
# 判断里面是否包含 sm8090 或 sm8689
if [[ "$compile_task_list" == *"sm8090"* ]]; then
echo "sm8090_match=true" >> $GITHUB_OUTPUT
fi
if [[ "$compile_task_list" == *"sm8689"* ]]; then
echo "sm8689_match=true" >> $GITHUB_OUTPUT
fi
break
fi
done
# 通过变量COMPILE_USE_PADDLE_WHL_URL_MAPPINGS中的映射关系,决定是否是安装指定版本的Paddle还是直接安装URL
for pair in $(echo $COMPILE_USE_PADDLE_WHL_URL_MAPPINGS | tr ';' ' '); do
branch=$(echo "$pair" | cut -d',' -f1)
paddle_whl_url=$(echo "$pair" | cut -d',' -f2)
if [[ "$branch" == "${{ github.ref_name }}" ]]; then
FOUND_PADDLE_URL="$paddle_whl_url"
echo "compile_use_paddle_whl_url=${FOUND_PADDLE_URL}" >> $GITHUB_OUTPUT
break
fi
done
print_ce_job_pre_check_outputs:
runs-on: ubuntu-latest
needs: ce_job_pre_check
steps:
- name: Print outputs as JSON
run: |
echo '${{ toJSON(needs.ce_job_pre_check.outputs) }}'
clone:
environment: CodeSync
name: FD-Clone-Linux
runs-on: ubuntu-latest
needs: ce_job_pre_check
if: ${{ needs.ce_job_pre_check.outputs.branch_match == 'true' }}
outputs:
repo_archive_url: ${{ steps.set_output.outputs.repo_archive_url }}
steps:
- name: Clone FastDeploy
uses: actions/checkout@v4
with:
ref: ${{ github.event_name == 'pull_request'
&& github.event.pull_request.base.ref
|| github.ref_name }}
submodules: 'recursive'
fetch-depth: 1000
- name: Python Setup
uses: actions/setup-python@v5
with:
python-version: '3.10'
- name: Code Info Show and Upload
id: set_output
env:
AK: ${{ secrets.BOS_AK }}
SK: ${{ secrets.BOS_SK }}
run: |
git config --unset http.https://github.com/.extraheader
git submodule foreach --recursive sh -c "git config --local --unset-all 'http.https://github.com/.extraheader'"
git submodule foreach --recursive sh -c "git config remote.origin.fetch '+refs/heads/*:refs/remotes/origin/*'"
echo "Current HEAD Log:"
git log --oneline -n 5
ls
cd ..
tar -zcf FastDeploy.tar.gz FastDeploy
commit_id=${{ github.sha }}
branch_name=${{ github.ref_name }}
target_path=paddle-qa/BRANCH/FastDeploy/${branch_name}/${commit_id}
wget -q --no-proxy --no-check-certificate https://paddle-qa.bj.bcebos.com/CodeSync/develop/PaddlePaddle/PaddleTest/tools/bos_tools.py
push_file=$(realpath bos_tools.py)
python -m pip install bce-python-sdk==0.9.29
ls
python ${push_file} FastDeploy.tar.gz ${target_path}
target_path_stripped="${target_path#paddle-qa/}"
REPO_ARCHIVE_URL=https://paddle-qa.bj.bcebos.com/${target_path_stripped}/FastDeploy.tar.gz
echo "repo_archive_url=${REPO_ARCHIVE_URL}" >> $GITHUB_OUTPUT
resultshow:
name: Show Code Archive Output
needs: clone
runs-on: ubuntu-latest
steps:
- name: Print wheel path
run: |
echo "The code archive is located at: ${{ needs.clone.outputs.repo_archive_url }}"
build_sm8090:
name: BUILD_SM8090
needs: [clone, ce_job_pre_check]
if: ${{ needs.ce_job_pre_check.outputs.sm8090_match == 'true' }}
uses: ./.github/workflows/_build_linux.yml
with:
DOCKER_IMAGE: ccr-2vdh3abv-pub.cnc.bj.baidubce.com/paddlepaddle/paddleqa:fastdeploy-ciuse-cuda126-dailyupdate
FASTDEPLOY_ARCHIVE_URL: ${{ needs.clone.outputs.repo_archive_url }}
COMPILE_ARCH: "80,90"
WITH_NIGHTLY_BUILD: OFF
FD_VERSION: 0.0.0
PADDLE_WHL_URL: ${{ needs.ce_job_pre_check.outputs.compile_use_paddle_whl_url }}
build_sm8689:
name: BUILD_SM8689
needs: [clone, ce_job_pre_check]
if: ${{ needs.ce_job_pre_check.outputs.sm8689_match == 'true' }}
uses: ./.github/workflows/_build_linux.yml
with:
DOCKER_IMAGE: ccr-2vdh3abv-pub.cnc.bj.baidubce.com/paddlepaddle/paddleqa:fastdeploy-ciuse-cuda126-dailyupdate
FASTDEPLOY_ARCHIVE_URL: ${{ needs.clone.outputs.repo_archive_url }}
COMPILE_ARCH: "86,89"
WITH_NIGHTLY_BUILD: OFF
FD_VERSION: 0.0.0
PADDLE_WHL_URL: ${{ needs.ce_job_pre_check.outputs.compile_use_paddle_whl_url }}
ce_upload_sm8090:
environment: CodeSync
name: CE_UPLOAD
needs: build_sm8090
runs-on: ubuntu-latest
env:
AK: ${{ secrets.BOS_AK }}
SK: ${{ secrets.BOS_SK }}
FASTDEPLOY_WHEEL_URL: ${{ needs.build_sm8090.outputs.wheel_path }}
COMPILE_ARCH: "80,90"
steps:
- uses: actions/setup-python@v5
with:
python-version: '3.10'
- name: Wheel Info Show and Upload
run: |
echo "The wheel is located at: ${{ needs.build_sm8090.outputs.wheel_path }}"
wget -q --no-check-certificate ${{ needs.build_sm8090.outputs.wheel_path }}
filename=$(basename ${{ needs.build_sm8090.outputs.wheel_path }})
commit_id=${{ github.sha }}
branch_name=${{ github.ref_name }}
target_path=paddle-qa/paddle-pipeline/FastDeploy_ActionCE/SM${COMPILE_ARCH//,/_}/${branch_name}/${commit_id}
wget -q --no-proxy --no-check-certificate https://paddle-qa.bj.bcebos.com/CodeSync/develop/PaddlePaddle/PaddleTest/tools/bos_tools.py
push_file=$(realpath bos_tools.py)
python -m pip install bce-python-sdk==0.9.29
ls
python ${push_file} ${filename} ${target_path}
target_path_stripped="${target_path#paddle-qa/}"
WHEEL_PATH=https://paddle-qa.bj.bcebos.com/${target_path_stripped}/${filename}
target_path_latest=paddle-qa/paddle-pipeline/FastDeploy_ActionCE/SM${COMPILE_ARCH//,/_}/${branch_name}/latest
python ${push_file} ${filename} ${target_path_latest}
target_path_stripped_latest="${target_path_latest#paddle-qa/}"
WHEEL_PATH_LATEST=https://paddle-qa.bj.bcebos.com/${target_path_stripped_latest}/${filename}
echo "commit wheel url is ${WHEEL_PATH}"
echo "latest wheel url is ${WHEEL_PATH_LATEST}"
ce_upload_sm8689:
environment: CodeSync
name: CE_UPLOAD
needs: build_sm8689
runs-on: ubuntu-latest
env:
AK: ${{ secrets.BOS_AK }}
SK: ${{ secrets.BOS_SK }}
FASTDEPLOY_WHEEL_URL: ${{ needs.build_sm8689.outputs.wheel_path }}
COMPILE_ARCH: "86,89"
steps:
- uses: actions/setup-python@v5
with:
python-version: '3.10'
- name: Wheel Info Show and Upload
run: |
echo "The wheel is located at: ${{ needs.build_sm8689.outputs.wheel_path }}"
wget -q --no-check-certificate ${{ needs.build_sm8689.outputs.wheel_path }}
filename=$(basename ${{ needs.build_sm8689.outputs.wheel_path }})
commit_id=${{ github.sha }}
branch_name=${{ github.ref_name }}
target_path=paddle-qa/paddle-pipeline/FastDeploy_ActionCE/SM${COMPILE_ARCH//,/_}/${branch_name}/${commit_id}
wget -q --no-proxy --no-check-certificate https://paddle-qa.bj.bcebos.com/CodeSync/develop/PaddlePaddle/PaddleTest/tools/bos_tools.py
push_file=$(realpath bos_tools.py)
python -m pip install bce-python-sdk==0.9.29
ls
python ${push_file} ${filename} ${target_path}
target_path_stripped="${target_path#paddle-qa/}"
WHEEL_PATH=https://paddle-qa.bj.bcebos.com/${target_path_stripped}/${filename}
target_path_latest=paddle-qa/paddle-pipeline/FastDeploy_ActionCE/SM${COMPILE_ARCH//,/_}/${branch_name}/latest
python ${push_file} ${filename} ${target_path_latest}
target_path_stripped_latest="${target_path_latest#paddle-qa/}"
WHEEL_PATH_LATEST=https://paddle-qa.bj.bcebos.com/${target_path_stripped_latest}/${filename}
echo "commit wheel url is ${WHEEL_PATH}"
echo "latest wheel url is ${WHEEL_PATH_LATEST}"

View File

@@ -1,51 +0,0 @@
on:
workflow_call:
inputs:
workflow-name:
required: true
type: string
secrets:
github-token:
required: true
outputs:
can-skip:
description: "Whether the workflow can be skipped."
value: ${{ jobs.check-bypass.outputs.can-skip }}
jobs:
check-bypass:
name: Check bypass
runs-on: ubuntu-latest
permissions:
contents: read
env:
CI_TEAM_MEMBERS: '["yuanlehome","YuanRisheng","Jiang-Jia-Jun","DDDivano","XieYunshen"]'
outputs:
can-skip: ${{ steps.check-bypass.outputs.can-skip }}
steps:
- name: Cleanup
run: |
rm -rf * .[^.]*
- id: check-bypass
name: Check Bypass
uses: PFCCLab/ci-bypass@v1
with:
github-token: ${{ secrets.github-token }}
non-pull-request-event-strategy: 'never-skipped'
type: 'composite'
composite-rule: |
{
"any": [
{
"type": "labeled",
"label": ["skip-ci: ${{ inputs.workflow-name }}", "skip-ci: all"],
"username": ${{ env.CI_TEAM_MEMBERS }}
},
{
"type": "commented",
"comment-pattern": [".*/skip-ci ${{ inputs.workflow-name }}.*", ".*/skip-ci all.*"],
"username": ${{ env.CI_TEAM_MEMBERS }}
}
]
}

109
.github/workflows/ci.yml vendored Normal file
View File

@@ -0,0 +1,109 @@
name: CI
on:
pull_request:
branches:
- develop
- 'release/*'
workflow_dispatch:
concurrency:
group: ${{ github.event.pull_request.number }}
cancel-in-progress: true
jobs:
build:
runs-on: [self-hosted, GPU-L20-4Card]
steps:
- name: Print current runner name
run: |
echo "Current runner name: ${{ runner.name }}"
# Because the system version is lower than 2.23, the checkout cannot be used.
# - name: Checkout code
# uses: actions/checkout@v4
- name: Code Checkout
env:
docker_image: ccr-2vdh3abv-pub.cnc.bj.baidubce.com/paddlepaddle/paddle:fastdeploy-ciuse-cuda126
run: |
REPO="https://github.com/${{ github.repository }}.git"
FULL_REPO="${{ github.repository }}"
REPO_NAME="${FULL_REPO##*/}"
BASE_BRANCH="${{ github.base_ref }}"
# Clean the repository directory before starting
docker run --rm --net=host -v $(pwd):/workspace -w /workspace \
-e "REPO_NAME=${REPO_NAME}" \
-e "BASE_BRANCH=${BASE_BRANCH}" \
${docker_image} /bin/bash -c '
if [ -d ${REPO_NAME} ]; then
echo "Directory ${REPO_NAME} exists, removing it..."
rm -rf ${REPO_NAME}
fi
'
git config --global user.name "FastDeployCI"
git config --global user.email "fastdeploy_ci@example.com"
git clone ${REPO} ${REPO_NAME} -b ${BASE_BRANCH}
cd FastDeploy
if [ "${{ github.event_name }}" = "pull_request" ]; then
git fetch origin pull/${{ github.event.pull_request.number }}/head:pr/${{ github.event.pull_request.number }}
git merge pr/${{ github.event.pull_request.number }}
git log -n 3 --oneline
else
git checkout ${{ github.sha }}
git log -n 3 --oneline
fi
- name: Run CI unittest
env:
docker_image: ccr-2vdh3abv-pub.cnc.bj.baidubce.com/paddlepaddle/paddle:fastdeploy-ciuse-cuda126
run: |
runner_name="${{ runner.name }}"
last_char="${runner_name: -1}"
if [ "${last_char}" = "1" ]; then
gpu_id=2
DEVICES="2,3"
else
gpu_id=0
DEVICES="0,1"
fi
FLASK_PORT=$((41068 + gpu_id * 100))
FD_API_PORT=$((41088 + gpu_id * 100))
FD_ENGINE_QUEUE_PORT=$((41058 + gpu_id * 100))
FD_METRICS_PORT=$((41078 + gpu_id * 100))
PORTS=($FLASK_PORT $FD_API_PORT $FD_ENGINE_QUEUE_PORT $FD_METRICS_PORT)
LOG_FILE="./port_cleanup_$(date +%Y%m%d_%H%M%S).log"
echo "==== LOG_FILE is ${LOG_FILE} ===="
echo "==== PORT CLEAN BEFORE TASK RUN ====" | tee -a $LOG_FILE
for port in "${PORTS[@]}"; do
PIDS=$(lsof -t -i :$port || true)
if [ -n "$PIDS" ]; then
echo "Port $port is occupied by PID(s): $PIDS" | tee -a $LOG_FILE
echo "$PIDS" | xargs -r kill -9
echo "Port $port cleared" | tee -a $LOG_FILE
else
echo "Port $port is free" | tee -a $LOG_FILE
fi
done
echo "==== PORT CLEAN COMPLETE ====" | tee -a $LOG_FILE
PARENT_DIR=$(dirname "$WORKSPACE")
echo "PARENT_DIR:$PARENT_DIR"
docker run --rm --net=host -v $(pwd):/workspace -w /workspace \
-v "/ssd4/GithubActions/gitconfig:/etc/gitconfig:ro" \
-v "/ssd4/GithubActions/ModelData:/ModelData:ro" \
-v "/ssd4/GithubActions/CacheDir:/root/.cache" \
-v "/ssd4/GithubActions/ConfigDir:/root/.config" \
-e "MODEL_PATH=/ModelData" \
-e "FD_API_PORT=${FD_API_PORT}" \
-e "FD_ENGINE_QUEUE_PORT=${FD_ENGINE_QUEUE_PORT}" \
-e "FD_METRICS_PORT=${FD_METRICS_PORT}" \
--gpus '"device='"${DEVICES}"'"' ${docker_image} /bin/bash -c "
git config --global --add safe.directory /workspace/FastDeploy
cd FastDeploy
bash scripts/run_ci.sh
"

View File

@@ -1,98 +0,0 @@
name: CI_GCU
on:
pull_request:
branches:
- develop
- 'release/*'
workflow_dispatch:
concurrency:
group: ${{ github.event.pull_request.number }}-gcu-ci
cancel-in-progress: true
jobs:
CI_GCU:
runs-on:
group: GCU
steps:
- name: Print current runner name
run: |
echo "Current runner name: ${{ runner.name }}"
- name: Code Checkout
env:
docker_image: ccr-2vdh3abv-pub.cnc.bj.baidubce.com/device/paddle-gcu:topsrider3.5.102-ubuntu20-x86_64-gcc84
run: |
REPO="https://github.com/${{ github.repository }}.git"
FULL_REPO="${{ github.repository }}"
REPO_NAME="${FULL_REPO##*/}"
BASE_BRANCH="${{ github.base_ref }}"
# Clean the repository directory before starting
docker run --rm --net=host -v $(pwd):/workspace \
-v ${{ github.workspace }}/../../..:${{ github.workspace }}/../../.. \
-w /workspace \
-e "REPO_NAME=${REPO_NAME}" \
-e "BASE_BRANCH=${BASE_BRANCH}" \
${docker_image} /bin/bash -c '
if [ -d ${REPO_NAME} ]; then
echo "Directory ${REPO_NAME} exists, removing it..."
rm -rf ${REPO_NAME}
fi
'
git config --global user.name "FastDeployCI"
git config --global user.email "fastdeploy_ci@example.com"
source ${{ github.workspace }}/../../../proxy
git clone ${REPO} ${REPO_NAME} -b ${BASE_BRANCH}
cd FastDeploy
if [ "${{ github.event_name }}" = "pull_request" ]; then
git fetch origin pull/${{ github.event.pull_request.number }}/head:pr/${{ github.event.pull_request.number }}
git merge pr/${{ github.event.pull_request.number }}
git log -n 3 --oneline
else
git checkout ${{ github.sha }}
git log -n 3 --oneline
fi
echo "Copy models..."
sudo mkdir -p ci_models && sudo cp -r /work/deps/ERNIE-4.5-21B-A3B-Paddle ci_models
echo "Copy models done."
- name: Run CI unittest
env:
docker_image: ccr-2vdh3abv-pub.cnc.bj.baidubce.com/device/paddle-gcu:topsrider3.5.102-ubuntu20-x86_64-gcc84
run: |
runner_name="${{ runner.name }}"
last_char="${runner_name: -1}"
if [[ "$last_char" =~ [0-3] ]]; then
gcu_id="$last_char"
else
gcu_id="0"
fi
FD_API_PORT=$((9180 + gcu_id * 100))
FD_ENGINE_QUEUE_PORT=$((9150 + gcu_id * 100))
FD_METRICS_PORT=$((9170 + gcu_id * 100))
PARENT_DIR=$(dirname "$WORKSPACE")
echo "PARENT_DIR:$PARENT_DIR"
echo "Install drivers..."
cd /work/deps
sudo bash TopsRider_i3x_*_deb_amd64.run --driver --no-auto-load -y
cd -
echo "Create docker..."
docker run --rm --network=host --ipc=host --privileged \
-v $(pwd):/workspace \
-v /home:/home \
-v /work:/work \
-w /workspace \
-e "MODEL_PATH=./ci_models" \
-e "http_proxy=$(git config --global --get http.proxy)" \
-e "https_proxy=$(git config --global --get https.proxy)" \
-e "FD_API_PORT=${FD_API_PORT}" \
-e "FD_ENGINE_QUEUE_PORT=${FD_ENGINE_QUEUE_PORT}" \
-e "FD_METRICS_PORT=${FD_METRICS_PORT}" \
${docker_image} /bin/bash -c "
git config --global --add safe.directory /workspace/FastDeploy
cd FastDeploy
bash scripts/run_ci_gcu.sh
"

View File

@@ -11,8 +11,7 @@ concurrency:
jobs:
CI_ILUVATAR:
runs-on:
group: IXUCA
runs-on: [self-hosted, IXUCA]
steps:
- name: Print current runner name
run: |

View File

@@ -1,174 +0,0 @@
name: CI Images Build
on:
workflow_dispatch:
schedule:
- cron: '0 18 * * *' # 2:00 AM China Standard Time (UTC+8)
permissions: read-all
concurrency:
group: CI-Images-Build-${{ github.ref }}-${{ github.sha }}
cancel-in-progress: true
jobs:
clone:
environment: CodeSync
name: FD-Clone-Linux
runs-on: ubuntu-latest
outputs:
repo_archive_url: ${{ steps.set_output.outputs.repo_archive_url }}
steps:
- name: Clone FastDeploy
uses: actions/checkout@v4
with:
ref: ${{ github.ref_name }}
submodules: 'recursive'
fetch-depth: 1000
- name: Python Setup
uses: actions/setup-python@v5
with:
python-version: '3.10'
- name: Code Info Show and Upload
id: set_output
env:
AK: ${{ secrets.BOS_AK }}
SK: ${{ secrets.BOS_SK }}
run: |
git config --unset http.https://github.com/.extraheader
git submodule foreach --recursive sh -c "git config --local --unset-all 'http.https://github.com/.extraheader'"
git submodule foreach --recursive sh -c "git config remote.origin.fetch '+refs/heads/*:refs/remotes/origin/*'"
echo "Current HEAD Log:"
git log --oneline -n 5
ls
cd ..
tar -zcf FastDeploy.tar.gz FastDeploy
if [[ "${{ github.ref_type }}" == "tag" ]]; then
commit_id=${{ github.sha }}
tag_name=${{ github.ref_name }}
target_path=paddle-qa/TAG/FastDeploy/${tag_name}/${commit_id}
else
commit_id=${{ github.sha }}
branch_name=${{ github.ref_name }}
target_path=paddle-qa/BRANCH/FastDeploy/${branch_name}/${commit_id}
fi
wget -q --no-proxy --no-check-certificate https://paddle-qa.bj.bcebos.com/CodeSync/develop/PaddlePaddle/PaddleTest/tools/bos_tools.py
push_file=$(realpath bos_tools.py)
python -m pip install bce-python-sdk==0.9.29
ls
python ${push_file} FastDeploy.tar.gz ${target_path}
target_path_stripped="${target_path#paddle-qa/}"
REPO_ARCHIVE_URL=https://paddle-qa.bj.bcebos.com/${target_path_stripped}/FastDeploy.tar.gz
echo "repo_archive_url=${REPO_ARCHIVE_URL}" >> $GITHUB_OUTPUT
resultshow:
name: Show Code Archive Output
needs: clone
runs-on: ubuntu-latest
steps:
- name: Print wheel path
run: |
echo "The code archive is located at: ${{ needs.clone.outputs.repo_archive_url }}"
ci_image_build:
name: CI Images Build
needs: clone
uses: ./.github/workflows/_ci_image_build.yml
with:
CI_DOCKER_IMAGE_NAME: ccr-2vdh3abv-pub.cnc.bj.baidubce.com/paddlepaddle/paddleqa:fastdeploy-ciuse-cuda126-dailyupdate-precheck
FASTDEPLOY_ARCHIVE_URL: ${{ needs.clone.outputs.repo_archive_url }}
build_sm8090:
name: BUILD_SM8090
needs: [clone, ci_image_build]
uses: ./.github/workflows/_build_linux.yml
with:
DOCKER_IMAGE: ${{ needs.ci_image_build.outputs.docker_name_precheck }}
FASTDEPLOY_ARCHIVE_URL: ${{ needs.clone.outputs.repo_archive_url }}
COMPILE_ARCH: "90"
WITH_NIGHTLY_BUILD: ${{ needs.publish_pre_check.outputs.with_nightly_build }}
FD_VERSION: ${{ needs.publish_pre_check.outputs.fd_version }}
PADDLEVERSION: ${{ needs.publish_pre_check.outputs.compile_use_paddle_version }}
PADDLE_WHL_URL: ${{ needs.publish_pre_check.outputs.compile_use_paddle_whl_url }}
unittest_coverage:
name: Run FastDeploy Unit Tests and Coverage
needs: [clone,build_sm8090,ci_image_build]
uses: ./.github/workflows/_unit_test_coverage.yml
with:
DOCKER_IMAGE: ${{ needs.ci_image_build.outputs.docker_name_precheck }}
FASTDEPLOY_ARCHIVE_URL: ${{ needs.clone.outputs.repo_archive_url }}
FASTDEPLOY_WHEEL_URL: ${{ needs.build_sm8090.outputs.wheel_path }}
MODEL_CACHE_DIR: "/ssd2/actions-runner/ModelData"
secrets:
github-token: ${{ secrets.GITHUB_TOKEN }}
logprob_test:
name: Run FastDeploy LogProb Tests
needs: [build_sm8090,ci_image_build]
uses: ./.github/workflows/_logprob_test_linux.yml
with:
DOCKER_IMAGE: ${{ needs.ci_image_build.outputs.docker_name_precheck }}
PADDLETEST_ARCHIVE_URL: "https://xly-devops.bj.bcebos.com/PaddleTest/PaddleTest.tar.gz"
FASTDEPLOY_WHEEL_URL: ${{ needs.build_sm8090.outputs.wheel_path }}
MODEL_CACHE_DIR: "/ssd2/actions-runner/ModelData"
pre_ce_test:
name: Extracted partial CE model tasks to run in CI.
needs: [clone,build_sm8090,ci_image_build]
uses: ./.github/workflows/_pre_ce_test.yml
with:
DOCKER_IMAGE: ${{ needs.ci_image_build.outputs.docker_name_precheck }}
FASTDEPLOY_ARCHIVE_URL: ${{ needs.clone.outputs.repo_archive_url }}
FASTDEPLOY_WHEEL_URL: ${{ needs.build_sm8090.outputs.wheel_path }}
MODEL_CACHE_DIR: "/ssd2/actions-runner/ModelData"
base_test:
name: Run Base Tests
needs: [clone,build_sm8090,ci_image_build]
uses: ./.github/workflows/_base_test.yml
with:
DOCKER_IMAGE: ${{ needs.ci_image_build.outputs.docker_name_precheck }}
FASTDEPLOY_ARCHIVE_URL: ${{ needs.clone.outputs.repo_archive_url }}
FASTDEPLOY_WHEEL_URL: ${{ needs.build_sm8090.outputs.wheel_path }}
MODEL_CACHE_DIR: "/ssd2/actions-runner/ModelData"
accuracy_test:
name: Run Accuracy Tests
needs: [clone,build_sm8090,ci_image_build]
uses: ./.github/workflows/_accuracy_test.yml
with:
DOCKER_IMAGE: ${{ needs.ci_image_build.outputs.docker_name_precheck }}
FASTDEPLOY_ARCHIVE_URL: ${{ needs.clone.outputs.repo_archive_url }}
FASTDEPLOY_WHEEL_URL: ${{ needs.build_sm8090.outputs.wheel_path }}
MODEL_CACHE_DIR: "/ssd2/actions-runner/ModelData"
stable_test:
name: Run Stable Tests
needs: [clone,build_sm8090,ci_image_build]
uses: ./.github/workflows/_stable_test.yml
with:
DOCKER_IMAGE: ${{ needs.ci_image_build.outputs.docker_name_precheck }}
FASTDEPLOY_ARCHIVE_URL: ${{ needs.clone.outputs.repo_archive_url }}
FASTDEPLOY_WHEEL_URL: ${{ needs.build_sm8090.outputs.wheel_path }}
MODEL_CACHE_DIR: "/ssd2/actions-runner/ModelData"
publish_pre_check:
name: Publish Docker Images Pre Check
needs: [ci_image_build, unittest_coverage,logprob_test,pre_ce_test,base_test,accuracy_test,stable_test]
runs-on: [self-hosted, Docker-Build]
steps:
- name: Images Uploading
env:
images_name: ${{ needs.ci_image_build.outputs.docker_name_precheck }}
ci_image_name: "ccr-2vdh3abv-pub.cnc.bj.baidubce.com/paddlepaddle/paddleqa:fastdeploy-ciuse-cuda126-dailyupdate"
run: |
echo "images_name=${images_name}"
docker images ${ci_image_name}
docker tag ${images_name} ${ci_image_name}
docker push ${ci_image_name}

View File

@@ -24,7 +24,7 @@ jobs:
- name: Code Checkout
env:
docker_image: ccr-2vdh3abv-pub.cnc.bj.baidubce.com/paddlepaddle/fastdeploy-xpu:2.1.0
docker_image: ccr-2vdh3abv-pub.cnc.bj.baidubce.com/paddlepaddle/fastdeploy-xpu:2.0.0
run: |
REPO="https://github.com/${{ github.repository }}.git"
FULL_REPO="${{ github.repository }}"
@@ -55,7 +55,7 @@ jobs:
- name: Run CI unittest
env:
docker_image: ccr-2vdh3abv-pub.cnc.bj.baidubce.com/paddlepaddle/fastdeploy-xpu:2.1.0
docker_image: ccr-2vdh3abv-pub.cnc.bj.baidubce.com/paddlepaddle/fastdeploy-xpu:2.0.0
run: |
runner_name="${{ runner.name }}"
last_char="${runner_name: -1}"
@@ -77,7 +77,6 @@ jobs:
-e "MODEL_PATH=/ssd3/model" \
-e "http_proxy=$(git config --global --get http.proxy)" \
-e "https_proxy=$(git config --global --get https.proxy)" \
-e "no_proxy=bcebos.com,mirrors.tuna.tsinghua.edu.cn,127.0.0.1,localhost" \
-e "FD_API_PORT=${FD_API_PORT}" \
-e "FD_ENGINE_QUEUE_PORT=${FD_ENGINE_QUEUE_PORT}" \
-e "FD_METRICS_PORT=${FD_METRICS_PORT}" \

View File

@@ -15,7 +15,7 @@ jobs:
- uses: actions/setup-python@v5
with:
python-version: 3.x
- run: pip install mkdocs-material mkdocs-get-deps mkdocs-material-extensions mkdocs-multilang mkdocs-static-i18n
- run: pip install mkdocs-material mkdocs-get-deps mkdocs-material-extensions mkdocs-multilang
- name: Deploy to GitHub Pages
env:
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}

View File

@@ -19,7 +19,7 @@ jobs:
needs: clone
uses: ./.github/workflows/_build_linux.yml
with:
DOCKER_IMAGE: ccr-2vdh3abv-pub.cnc.bj.baidubce.com/paddlepaddle/paddleqa:fastdeploy-ciuse-cuda126-dailyupdate
DOCKER_IMAGE: ccr-2vdh3abv-pub.cnc.bj.baidubce.com/paddlepaddle/paddleqa:cuda126-py310
FASTDEPLOY_ARCHIVE_URL: ${{ needs.clone.outputs.repo_archive_url }}
COMPILE_ARCH: "90"
WITH_NIGHTLY_BUILD: "OFF"
@@ -33,65 +33,3 @@ jobs:
- name: Print wheel path
run: |
echo "The built wheel is located at: ${{ needs.build.outputs.wheel_path }}"
unittest_coverage:
name: Run FastDeploy Unit Tests and Coverage
needs: [clone,build]
uses: ./.github/workflows/_unit_test_coverage.yml
with:
DOCKER_IMAGE: ccr-2vdh3abv-pub.cnc.bj.baidubce.com/paddlepaddle/paddleqa:fastdeploy-ciuse-cuda126-dailyupdate
FASTDEPLOY_ARCHIVE_URL: ${{ needs.clone.outputs.repo_archive_url }}
FASTDEPLOY_WHEEL_URL: ${{ needs.build.outputs.wheel_path }}
MODEL_CACHE_DIR: "/ssd2/actions-runner/ModelData"
secrets:
github-token: ${{ secrets.GITHUB_TOKEN }}
logprob_test:
name: Run FastDeploy LogProb Tests
needs: [build]
uses: ./.github/workflows/_logprob_test_linux.yml
with:
DOCKER_IMAGE: ccr-2vdh3abv-pub.cnc.bj.baidubce.com/paddlepaddle/paddleqa:fastdeploy-ciuse-cuda126-dailyupdate
PADDLETEST_ARCHIVE_URL: "https://xly-devops.bj.bcebos.com/PaddleTest/PaddleTest.tar.gz"
FASTDEPLOY_WHEEL_URL: ${{ needs.build.outputs.wheel_path }}
MODEL_CACHE_DIR: "/ssd2/actions-runner/ModelData"
pre_ce_test:
name: Extracted partial CE model tasks to run in CI.
needs: [clone,build]
uses: ./.github/workflows/_pre_ce_test.yml
with:
DOCKER_IMAGE: ccr-2vdh3abv-pub.cnc.bj.baidubce.com/paddlepaddle/paddleqa:fastdeploy-ciuse-cuda126-dailyupdate
FASTDEPLOY_ARCHIVE_URL: ${{ needs.clone.outputs.repo_archive_url }}
FASTDEPLOY_WHEEL_URL: ${{ needs.build.outputs.wheel_path }}
MODEL_CACHE_DIR: "/ssd2/actions-runner/ModelData"
base_test:
name: Run Base Tests
needs: [clone,build]
uses: ./.github/workflows/_base_test.yml
with:
DOCKER_IMAGE: ccr-2vdh3abv-pub.cnc.bj.baidubce.com/paddlepaddle/paddleqa:fastdeploy-ciuse-cuda126-dailyupdate
FASTDEPLOY_ARCHIVE_URL: ${{ needs.clone.outputs.repo_archive_url }}
FASTDEPLOY_WHEEL_URL: ${{ needs.build.outputs.wheel_path }}
MODEL_CACHE_DIR: "/ssd2/actions-runner/ModelData"
accuracy_test:
name: Run Accuracy Tests
needs: [clone,build]
uses: ./.github/workflows/_accuracy_test.yml
with:
DOCKER_IMAGE: ccr-2vdh3abv-pub.cnc.bj.baidubce.com/paddlepaddle/paddleqa:fastdeploy-ciuse-cuda126-dailyupdate
FASTDEPLOY_ARCHIVE_URL: ${{ needs.clone.outputs.repo_archive_url }}
FASTDEPLOY_WHEEL_URL: ${{ needs.build.outputs.wheel_path }}
MODEL_CACHE_DIR: "/ssd2/actions-runner/ModelData"
stable_test:
name: Run Stable Tests
needs: [clone,build]
uses: ./.github/workflows/_stable_test.yml
with:
DOCKER_IMAGE: ccr-2vdh3abv-pub.cnc.bj.baidubce.com/paddlepaddle/paddleqa:fastdeploy-ciuse-cuda126-dailyupdate
FASTDEPLOY_ARCHIVE_URL: ${{ needs.clone.outputs.repo_archive_url }}
FASTDEPLOY_WHEEL_URL: ${{ needs.build.outputs.wheel_path }}
MODEL_CACHE_DIR: "/ssd2/actions-runner/ModelData"

View File

@@ -1,331 +0,0 @@
name: Publish Job
on:
workflow_dispatch:
schedule:
- cron: '0 18 * * *' # 2:00 AM China Standard Time (UTC+8)
push:
# branches:
# - develop
tags:
- '*'
permissions: read-all
concurrency:
group: Publish-Job-${{ github.ref }}-${{ github.sha }}
cancel-in-progress: true
jobs:
publish_pre_check:
runs-on: ubuntu-latest
if: |
github.event.repository.fork == false &&
(
(github.event_name == 'schedule' && github.ref_name == 'develop') ||
(github.event_name == 'push' && github.ref_type == 'tag') ||
((github.event_name == 'workflow_dispatch') &&
(github.ref_name == 'develop' || github.ref_type == 'tag'))
)
env:
TAG_VERSION_MAPPINGS: ${{ vars.TAG_VERSION_MAPPINGS }}
FD_VERSION_DEV: ${{ vars.FD_VERSION_DEV }}
COMPILE_USE_PADDLE_WHL_URL_MAPPINGS: ${{ vars.COMPILE_USE_PADDLE_WHL_URL_MAPPINGS }}
outputs:
compile_use_paddle_version: ${{ steps.set_output.outputs.compile_use_paddle_version }}
compile_continue: ${{ steps.set_output.outputs.compile_continue }}
fd_version: ${{ steps.set_output.outputs.fd_version }}
with_nightly_build: ${{ steps.set_output.outputs.with_nightly_build }}
compile_use_paddle_whl_url: ${{ steps.set_output.outputs.compile_use_paddle_whl_url }}
steps:
- name: Get tag version
if: github.ref_type == 'tag'
run: |
TAG_NAME="${GITHUB_REF##*/}" # 提取 tag 名称,比如 v2.1.0
TAG_VERSION="${TAG_NAME#v}" # 去掉前缀 v
echo "FD_VERSION=$TAG_VERSION" >> $GITHUB_ENV
- name: Check FD version to Paddle version mapping
if: github.ref_type == 'tag'
env:
TARGET_FD: ${{ env.FD_VERSION }}
run: |
FOUND_PADDLE=""
# 遍历映射
for pair in $(echo $TAG_VERSION_MAPPINGS | tr ';' ' '); do
fd=$(echo "$pair" | cut -d',' -f1)
paddle=$(echo "$pair" | cut -d',' -f2)
if [[ "$fd" == "$TARGET_FD" ]]; then
FOUND_PADDLE="$paddle"
break
fi
done
if [[ -z "$FOUND_PADDLE" ]]; then
echo "No Paddle version found for FD $TARGET_FD"
else
echo "FD $TARGET_FD maps to Paddle $FOUND_PADDLE"
echo "PADDLE_VERSION=$FOUND_PADDLE" >> $GITHUB_ENV
fi
- name: Set Version
id: set_output
env:
PADDLE_VERSION: ${{ env.PADDLE_VERSION }}
FD_VERSION: ${{ env.FD_VERSION }}
run: |
if [[ "${{ github.ref_type }}" == "tag" ]]; then
if [[ -z "$PADDLE_VERSION" ]]; then
compile_continue=false
else
compile_use_paddle_version=$PADDLE_VERSION
compile_continue=true
fi
fd_version=$FD_VERSION
fi
if [[ "${{ github.ref_name }}" == "develop" ]];then
compile_continue=true
compile_use_paddle_version=""
fd_version=${FD_VERSION_DEV}
with_nightly_build=ON
fi
# Todo
# 通过变量COMPILE_USE_PADDLE_WHL_URL_MAPPINGS中的映射关系,决定是否是安装指定版本的Paddle还是直接安装URL
for pair in $(echo $COMPILE_USE_PADDLE_WHL_URL_MAPPINGS | tr ';' ' '); do
branch=$(echo "$pair" | cut -d',' -f1)
paddle_whl_url=$(echo "$pair" | cut -d',' -f2)
if [[ "$branch" == "${{ github.ref_name }}" ]]; then
FOUND_PADDLE_URL="$paddle_whl_url"
echo "compile_use_paddle_whl_url=${FOUND_PADDLE_URL}" >> $GITHUB_OUTPUT
compile_continue=true
break
fi
done
echo "compile_continue=${compile_continue}" >> $GITHUB_OUTPUT
echo "compile_use_paddle_version=${compile_use_paddle_version}" >> $GITHUB_OUTPUT
echo "fd_version=${fd_version}" >> $GITHUB_OUTPUT
echo "with_nightly_build=${with_nightly_build:-OFF}" >> $GITHUB_OUTPUT
print_publish_pre_check_outputs:
runs-on: ubuntu-latest
needs: publish_pre_check
steps:
- name: Print outputs as JSON
run: |
echo '${{ toJSON(needs.publish_pre_check.outputs) }}'
clone:
environment: CodeSync
name: FD-Clone-Linux
runs-on: ubuntu-latest
needs: publish_pre_check
if: ${{ needs.publish_pre_check.outputs.compile_continue == 'true' }}
outputs:
repo_archive_url: ${{ steps.set_output.outputs.repo_archive_url }}
steps:
- name: Clone FastDeploy
uses: actions/checkout@v4
with:
ref: ${{ github.ref_name }}
submodules: 'recursive'
fetch-depth: 1000
- name: Python Setup
uses: actions/setup-python@v5
with:
python-version: '3.10'
- name: Code Info Show and Upload
id: set_output
env:
AK: ${{ secrets.BOS_AK }}
SK: ${{ secrets.BOS_SK }}
run: |
git config --unset http.https://github.com/.extraheader
git submodule foreach --recursive sh -c "git config --local --unset-all 'http.https://github.com/.extraheader'"
git submodule foreach --recursive sh -c "git config remote.origin.fetch '+refs/heads/*:refs/remotes/origin/*'"
echo "Current HEAD Log:"
git log --oneline -n 5
ls
cd ..
tar -zcf FastDeploy.tar.gz FastDeploy
if [[ "${{ github.ref_type }}" == "tag" ]]; then
commit_id=${{ github.sha }}
tag_name=${{ github.ref_name }}
target_path=paddle-qa/TAG/FastDeploy/${tag_name}/${commit_id}
else
commit_id=${{ github.sha }}
branch_name=${{ github.ref_name }}
target_path=paddle-qa/BRANCH/FastDeploy/${branch_name}/${commit_id}
fi
wget -q --no-proxy --no-check-certificate https://paddle-qa.bj.bcebos.com/CodeSync/develop/PaddlePaddle/PaddleTest/tools/bos_tools.py
push_file=$(realpath bos_tools.py)
python -m pip install bce-python-sdk==0.9.29
ls
python ${push_file} FastDeploy.tar.gz ${target_path}
target_path_stripped="${target_path#paddle-qa/}"
REPO_ARCHIVE_URL=https://paddle-qa.bj.bcebos.com/${target_path_stripped}/FastDeploy.tar.gz
echo "repo_archive_url=${REPO_ARCHIVE_URL}" >> $GITHUB_OUTPUT
resultshow:
name: Show Code Archive Output
needs: clone
runs-on: ubuntu-latest
steps:
- name: Print wheel path
run: |
echo "The code archive is located at: ${{ needs.clone.outputs.repo_archive_url }}"
build_sm8090:
name: BUILD_SM8090
needs: [clone, publish_pre_check]
uses: ./.github/workflows/_build_linux.yml
with:
DOCKER_IMAGE: ccr-2vdh3abv-pub.cnc.bj.baidubce.com/paddlepaddle/paddleqa:fastdeploy-ciuse-cuda126-dailyupdate
FASTDEPLOY_ARCHIVE_URL: ${{ needs.clone.outputs.repo_archive_url }}
COMPILE_ARCH: "80,90"
WITH_NIGHTLY_BUILD: ${{ needs.publish_pre_check.outputs.with_nightly_build }}
FD_VERSION: ${{ needs.publish_pre_check.outputs.fd_version }}
PADDLEVERSION: ${{ needs.publish_pre_check.outputs.compile_use_paddle_version }}
PADDLE_WHL_URL: ${{ needs.publish_pre_check.outputs.compile_use_paddle_whl_url }}
build_sm8689:
name: BUILD_SM8689
needs: [clone, publish_pre_check]
uses: ./.github/workflows/_build_linux.yml
with:
DOCKER_IMAGE: ccr-2vdh3abv-pub.cnc.bj.baidubce.com/paddlepaddle/paddleqa:fastdeploy-ciuse-cuda126-dailyupdate
FASTDEPLOY_ARCHIVE_URL: ${{ needs.clone.outputs.repo_archive_url }}
COMPILE_ARCH: "86,89"
WITH_NIGHTLY_BUILD: ${{ needs.publish_pre_check.outputs.with_nightly_build }}
FD_VERSION: ${{ needs.publish_pre_check.outputs.fd_version }}
PADDLEVERSION: ${{ needs.publish_pre_check.outputs.compile_use_paddle_version }}
PADDLE_WHL_URL: ${{ needs.publish_pre_check.outputs.compile_use_paddle_whl_url }}
paddle_pypi_upload_sm8090:
environment: PaddleSourceUpload
name: PADDLE_PYPI_UPLOAD_8090
needs: build_sm8090
runs-on: ubuntu-latest
env:
AK: ${{ secrets.BOS_AK }}
SK: ${{ secrets.BOS_SK }}
FASTDEPLOY_WHEEL_URL: ${{ needs.build_sm8090.outputs.wheel_path }}
COMPILE_ARCH: "80,90"
steps:
- uses: actions/setup-python@v5
with:
python-version: '3.10'
- name: Wheel Info Show and Upload
if: github.ref_name == 'develop' || github.ref_type == 'tag'
run: |
echo "The wheel is located at: ${FASTDEPLOY_WHEEL_URL}"
wget -q --no-check-certificate ${FASTDEPLOY_WHEEL_URL}
filename=$(basename ${FASTDEPLOY_WHEEL_URL})
if [[ "${{ github.ref_name }}" == "develop" ]];then
target_path=paddle-whl/nightly/fastdeploy-gpu-${COMPILE_ARCH//,/_}/fastdeploy-gpu
elif [[ "${{ github.ref_type }}" == "tag" ]]; then
target_path=paddle-whl/stable/fastdeploy-gpu-${COMPILE_ARCH//,/_}/fastdeploy-gpu
else
echo "Not develop or tag, do nothing"
fi
wget -q --no-proxy --no-check-certificate https://paddle-qa.bj.bcebos.com/CodeSync/develop/PaddlePaddle/PaddleTest/tools/bos_tools.py
push_file=$(realpath bos_tools.py)
python -m pip install bce-python-sdk==0.9.29
ls
python ${push_file} ${filename} ${target_path}
paddle_pypi_upload_sm8689:
environment: PaddleSourceUpload
name: PADDLE_PYPI_UPLOAD_8689
needs: build_sm8689
runs-on: ubuntu-latest
env:
AK: ${{ secrets.BOS_AK }}
SK: ${{ secrets.BOS_SK }}
FASTDEPLOY_WHEEL_URL: ${{ needs.build_sm8689.outputs.wheel_path }}
COMPILE_ARCH: "86,89"
steps:
- uses: actions/setup-python@v5
with:
python-version: '3.10'
- name: Wheel Info Show and Upload
if: github.ref_name == 'develop' || github.ref_type == 'tag'
run: |
echo "The wheel is located at: ${FASTDEPLOY_WHEEL_URL}"
wget -q --no-check-certificate ${FASTDEPLOY_WHEEL_URL}
filename=$(basename ${FASTDEPLOY_WHEEL_URL})
if [[ "${{ github.ref_name }}" == "develop" ]];then
target_path=paddle-whl/nightly/fastdeploy-gpu-${COMPILE_ARCH//,/_}/fastdeploy-gpu
elif [[ "${{ github.ref_type }}" == "tag" ]]; then
target_path=paddle-whl/stable/fastdeploy-gpu-${COMPILE_ARCH//,/_}/fastdeploy-gpu
else
echo "Not develop or tag, do nothing"
fi
wget -q --no-proxy --no-check-certificate https://paddle-qa.bj.bcebos.com/CodeSync/develop/PaddlePaddle/PaddleTest/tools/bos_tools.py
push_file=$(realpath bos_tools.py)
python -m pip install bce-python-sdk==0.9.29
ls
python ${push_file} ${filename} ${target_path}
unittest_coverage:
name: Run FastDeploy Unit Tests and Coverage
needs: [clone,build_sm8090]
uses: ./.github/workflows/_unit_test_coverage.yml
with:
DOCKER_IMAGE: ccr-2vdh3abv-pub.cnc.bj.baidubce.com/paddlepaddle/paddleqa:fastdeploy-ciuse-cuda126-dailyupdate
FASTDEPLOY_ARCHIVE_URL: ${{ needs.clone.outputs.repo_archive_url }}
FASTDEPLOY_WHEEL_URL: ${{ needs.build_sm8090.outputs.wheel_path }}
MODEL_CACHE_DIR: "/ssd2/actions-runner/ModelData"
secrets:
github-token: ${{ secrets.GITHUB_TOKEN }}
logprob_test:
name: Run FastDeploy LogProb Tests
needs: [build_sm8090]
uses: ./.github/workflows/_logprob_test_linux.yml
with:
DOCKER_IMAGE: ccr-2vdh3abv-pub.cnc.bj.baidubce.com/paddlepaddle/paddleqa:fastdeploy-ciuse-cuda126-dailyupdate
PADDLETEST_ARCHIVE_URL: "https://xly-devops.bj.bcebos.com/PaddleTest/PaddleTest.tar.gz"
FASTDEPLOY_WHEEL_URL: ${{ needs.build_sm8090.outputs.wheel_path }}
MODEL_CACHE_DIR: "/ssd2/actions-runner/ModelData"
pre_ce_test:
name: Extracted partial CE model tasks to run in CI.
needs: [clone,build_sm8090]
uses: ./.github/workflows/_pre_ce_test.yml
with:
DOCKER_IMAGE: ccr-2vdh3abv-pub.cnc.bj.baidubce.com/paddlepaddle/paddleqa:fastdeploy-ciuse-cuda126-dailyupdate
FASTDEPLOY_ARCHIVE_URL: ${{ needs.clone.outputs.repo_archive_url }}
FASTDEPLOY_WHEEL_URL: ${{ needs.build_sm8090.outputs.wheel_path }}
MODEL_CACHE_DIR: "/ssd2/actions-runner/ModelData"
base_test:
name: Run Base Tests
needs: [clone,build_sm8090]
uses: ./.github/workflows/_base_test.yml
with:
DOCKER_IMAGE: ccr-2vdh3abv-pub.cnc.bj.baidubce.com/paddlepaddle/paddleqa:fastdeploy-ciuse-cuda126-dailyupdate
FASTDEPLOY_ARCHIVE_URL: ${{ needs.clone.outputs.repo_archive_url }}
FASTDEPLOY_WHEEL_URL: ${{ needs.build_sm8090.outputs.wheel_path }}
MODEL_CACHE_DIR: "/ssd2/actions-runner/ModelData"
accuracy_test:
name: Run Accuracy Tests
needs: [clone,build_sm8090]
uses: ./.github/workflows/_accuracy_test.yml
with:
DOCKER_IMAGE: ccr-2vdh3abv-pub.cnc.bj.baidubce.com/paddlepaddle/paddleqa:fastdeploy-ciuse-cuda126-dailyupdate
FASTDEPLOY_ARCHIVE_URL: ${{ needs.clone.outputs.repo_archive_url }}
FASTDEPLOY_WHEEL_URL: ${{ needs.build_sm8090.outputs.wheel_path }}
MODEL_CACHE_DIR: "/ssd2/actions-runner/ModelData"
stable_test:
name: Run Stable Tests
needs: [clone,build_sm8090]
uses: ./.github/workflows/_stable_test.yml
with:
DOCKER_IMAGE: ccr-2vdh3abv-pub.cnc.bj.baidubce.com/paddlepaddle/paddleqa:fastdeploy-ciuse-cuda126-dailyupdate
FASTDEPLOY_ARCHIVE_URL: ${{ needs.clone.outputs.repo_archive_url }}
FASTDEPLOY_WHEEL_URL: ${{ needs.build.outputs.wheel_path }}
MODEL_CACHE_DIR: "/ssd2/actions-runner/ModelData"

14
.gitignore vendored
View File

@@ -121,7 +121,7 @@ dmypy.json
FETCH_HEAD
#log
log/
log*/
checkpoints/
checkpoints_origin/
@@ -156,12 +156,6 @@ nohup.out
custom_ops/gpu_ops/fp8_deep_gemm/deep_gemm/include/cutlass
custom_ops/gpu_ops/fp8_deep_gemm/deep_gemm/include/cute
#marlin_kernel
custom_ops/gpu_ops/moe/moe_wna16_marlin_utils/kernel_*.cu
#machete_kernel
custom_ops/gpu_ops/machete/generated
# buff
custom_ops/tmp*
@@ -170,9 +164,3 @@ build
.ccls-cache
third_party
custom_ops/gpu_ops/w4afp8_gemm/w4afp8_gemm_*.cu
custom_ops/gpu_ops/w4afp8_gemm/w4afp8_gemm_template.h
custom_ops/gpu_ops/wfp8afp8_sparse_gemm/wfp8Afp8_sparse_gemm_*.cu
custom_ops/gpu_ops/wfp8afp8_sparse_gemm/wfp8Afp8_sparse_gemm_template.h

10
.gitmodules vendored
View File

@@ -1,10 +0,0 @@
[submodule "custom_ops/third_party/DeepGEMM"]
path = custom_ops/third_party/DeepGEMM
url = https://github.com/deepseek-ai/DeepGEMM.git
ignore = all
[submodule "custom_ops/third_party/cutlass"]
path = custom_ops/third_party/cutlass
url = https://github.com/NVIDIA/cutlass.git
[submodule "custom_ops/third_party/nlohmann_json"]
path = custom_ops/third_party/nlohmann_json
url = https://github.com/nlohmann/json.git

View File

@@ -1,4 +1,3 @@
English | [简体中文](README_CN.md)
<p align="center">
<a href="https://github.com/PaddlePaddle/FastDeploy/releases"><img src="https://github.com/user-attachments/assets/42b0039f-39e3-4279-afda-6d1865dfbffb" width="500"></a>
</p>
@@ -23,12 +22,11 @@ English | [简体中文](README_CN.md)
</p>
--------------------------------------------------------------------------------
# FastDeploy : Inference and Deployment Toolkit for LLMs and VLMs based on PaddlePaddle
# FastDeploy 2.0: Inference and Deployment Toolkit for LLMs and VLMs based on PaddlePaddle
## News
**[2025-09] 🔥 FastDeploy v2.2 is newly released!** It now offers compatibility with models in the HuggingFace ecosystem, has further optimized performance, and newly adds support for [baidu/ERNIE-21B-A3B-Thinking](https://huggingface.co/baidu/ERNIE-4.5-21B-A3B-Thinking)!
**[2025-08] 🔥 Released FastDeploy v2.1:** A brand-new KV Cache scheduling strategy has been introduced, and expanded support for PD separation and CUDA Graph across more models. Enhanced hardware support has been added for platforms like Kunlun and Hygon, along with comprehensive optimizations to improve the performance of both the service and inference engine.
**[2025-07] 《FastDeploy2.0推理部署实测》专题活动已上线!** 完成文心4.5系列开源模型的推理部署等任务即可获得骨瓷马克杯等FastDeploy2.0官方周边及丰富奖金!🎁 欢迎大家体验反馈~ 📌[报名地址](https://www.wjx.top/vm/meSsp3L.aspx#) 📌[活动详情](https://github.com/PaddlePaddle/FastDeploy/discussions/2728)
**[2025-07] The FastDeploy 2.0 Inference Deployment Challenge is now live!** Complete the inference deployment task for the ERNIE 4.5 series open-source models to win official FastDeploy 2.0 merch and generous prizes! 🎁 You're welcome to try it out and share your feedback! 📌[Sign up here](https://www.wjx.top/vm/meSsp3L.aspx#) 📌[Event details](https://github.com/PaddlePaddle/FastDeploy/discussions/2728)
@@ -52,16 +50,14 @@ English | [简体中文](README_CN.md)
## Installation
FastDeploy supports inference deployment on **NVIDIA GPUs**, **Kunlunxin XPUs**, **Iluvatar GPUs**, **Enflame GCUs**, **Hygon DCUs** and other hardware. For detailed installation instructions:
FastDeploy supports inference deployment on **NVIDIA GPUs**, **Kunlunxin XPUs**, **Iluvatar GPUs**, **Enflame GCUs**, and other hardware. For detailed installation instructions:
- [NVIDIA GPU](./docs/get_started/installation/nvidia_gpu.md)
- [Kunlunxin XPU](./docs/get_started/installation/kunlunxin_xpu.md)
- [Iluvatar GPU](./docs/get_started/installation/iluvatar_gpu.md)
- [Enflame GCU](./docs/get_started/installation/Enflame_gcu.md)
- [Hygon DCU](./docs/get_started/installation/hygon_dcu.md)
- [MetaX GPU](./docs/get_started/installation/metax_gpu.md)
**Note:** We are actively working on expanding hardware support. Additional hardware platforms including Ascend NPU are currently under development and testing. Stay tuned for updates!
**Note:** We are actively working on expanding hardware support. Additional hardware platforms including Ascend NPU, Hygon DCU, and MetaX GPU are currently under development and testing. Stay tuned for updates!
## Get Started
@@ -71,12 +67,19 @@ Learn how to use FastDeploy through our documentation:
- [ERNIE-4.5-VL Multimodal Model Deployment](./docs/get_started/ernie-4.5-vl.md)
- [Offline Inference Development](./docs/offline_inference.md)
- [Online Service Deployment](./docs/online_serving/README.md)
- [Best Practices](./docs/best_practices/README.md)
- [Full Supported Models List](./docs/supported_models.md)
## Supported Models
Learn how to download models, enable using the torch format, and more:
- [Full Supported Models List](./docs/supported_models.md)
| Model | Data Type | PD Disaggregation | Chunked Prefill | Prefix Caching | MTP | CUDA Graph | Maximum Context Length |
|:--- | :------- | :---------- | :-------- | :-------- | :----- | :----- | :----- |
|ERNIE-4.5-300B-A47B | BF16/WINT4/WINT8/W4A8C8/WINT2/FP8 | ✅| ✅ | ✅|✅(WINT4)| WIP |128K |
|ERNIE-4.5-300B-A47B-Base| BF16/WINT4/WINT8 | ✅| ✅ | ✅|✅(WINT4)| WIP | 128K |
|ERNIE-4.5-VL-424B-A47B | BF16/WINT4/WINT8 | WIP | ✅ | WIP | ❌ | WIP |128K |
|ERNIE-4.5-VL-28B-A3B | BF16/WINT4/WINT8 | ❌ | ✅ | WIP | ❌ | WIP |128K |
|ERNIE-4.5-21B-A3B | BF16/WINT4/WINT8/FP8 | ❌ | ✅ | ✅ | WIP | ✅|128K |
|ERNIE-4.5-21B-A3B-Base | BF16/WINT4/WINT8/FP8 | ❌ | ✅ | ✅ | WIP | ✅|128K |
|ERNIE-4.5-0.3B | BF16/WINT8/FP8 | ❌ | ✅ | ✅ | ❌ | ✅| 128K |
## Advanced Usage

View File

@@ -1,89 +0,0 @@
[English](README.md) | 简体中文
<p align="center">
<a href="https://github.com/PaddlePaddle/FastDeploy/releases"><img src="https://github.com/user-attachments/assets/42b0039f-39e3-4279-afda-6d1865dfbffb" width="500"></a>
</p>
<p align="center">
<a href=""><img src="https://img.shields.io/badge/python-3.10-aff.svg"></a>
<a href=""><img src="https://img.shields.io/badge/os-linux-pink.svg"></a>
<a href="https://github.com/PaddlePaddle/FastDeploy/graphs/contributors"><img src="https://img.shields.io/github/contributors/PaddlePaddle/FastDeploy?color=9ea"></a>
<a href="https://github.com/PaddlePaddle/FastDeploy/commits"><img src="https://img.shields.io/github/commit-activity/m/PaddlePaddle/FastDeploy?color=3af"></a>
<a href="https://github.com/PaddlePaddle/FastDeploy/issues"><img src="https://img.shields.io/github/issues/PaddlePaddle/FastDeploy?color=9cc"></a>
<a href="https://github.com/PaddlePaddle/FastDeploy/stargazers"><img src="https://img.shields.io/github/stars/PaddlePaddle/FastDeploy?color=ccf"></a>
</p>
<p align="center">
<a href="https://trendshift.io/repositories/4046" target="_blank"><img src="https://trendshift.io/api/badge/repositories/4046" alt="PaddlePaddle%2FFastDeploy | Trendshift" style="width: 250px; height: 55px;" width="250" height="55"/></a></br>
<a href="https://paddlepaddle.github.io/FastDeploy/zh/get_started/installation/nvidia_gpu/"><b> 安装指导 </b></a>
|
<a href="https://paddlepaddle.github.io/FastDeploy/zh/get_started/quick_start"><b> 快速入门 </b></a>
|
<a href="https://paddlepaddle.github.io/FastDeploy/zh/supported_models/"><b> 支持模型列表 </b></a>
</p>
--------------------------------------------------------------------------------
# FastDeploy :基于飞桨的大语言模型与视觉语言模型推理部署工具包
## 最新活动
**[2025-09] 🔥 FastDeploy v2.2 全新发布**: HuggingFace生态模型兼容性能进一步优化更新增对[baidu/ERNIE-21B-A3B-Thinking](https://huggingface.co/baidu/ERNIE-4.5-21B-A3B-Thinking)支持!
**[2025-08] FastDeploy v2.1 发布**:全新的KV Cache调度策略更多模型支持PD分离和CUDA Graph昆仑、海光等更多硬件支持增强全方面优化服务和推理引擎的性能。
**[2025-07] 《FastDeploy2.0推理部署实测》专题活动已上线!** 完成文心4.5系列开源模型的推理部署等任务即可获得骨瓷马克杯等FastDeploy2.0官方周边及丰富奖金!🎁 欢迎大家体验反馈~ 📌[报名地址](https://www.wjx.top/vm/meSsp3L.aspx#) 📌[活动详情](https://github.com/PaddlePaddle/FastDeploy/discussions/2728)
## 关于
**FastDeploy** 是基于飞桨PaddlePaddle的大语言模型LLM与视觉语言模型VLM推理部署工具包提供**开箱即用的生产级部署方案**,核心技术特性包括:
- 🚀 **负载均衡式PD分解**工业级解决方案支持上下文缓存与动态实例角色切换在保障SLO达标和吞吐量的同时优化资源利用率
- 🔄 **统一KV缓存传输**轻量级高性能传输库支持智能NVLink/RDMA选择
- 🤝 **OpenAI API服务与vLLM兼容**:单命令部署,兼容[vLLM](https://github.com/vllm-project/vllm/)接口
- 🧮 **全量化格式支持**W8A16、W8A8、W4A16、W4A8、W2A16、FP8等
-**高级加速技术**推测解码、多令牌预测MTP及分块预填充
- 🖥️ **多硬件支持**NVIDIA GPU、昆仑芯XPU、海光DCU、昇腾NPU、天数智芯GPU、燧原GCU、沐曦GPU等
## 要求
- 操作系统: Linux
- Python: 3.10 ~ 3.12
## 安装
FastDeploy 支持在**英伟达NVIDIAGPU**、**昆仑芯KunlunxinXPU**、**天数IluvatarGPU**、**燧原EnflameGCU**、**海光HygonDCU** 以及其他硬件上进行推理部署。详细安装说明如下:
- [英伟达 GPU](./docs/zh/get_started/installation/nvidia_gpu.md)
- [昆仑芯 XPU](./docs/zh/get_started/installation/kunlunxin_xpu.md)
- [天数 CoreX](./docs/zh/get_started/installation/iluvatar_gpu.md)
- [燧原 S60](./docs/zh/get_started/installation/Enflame_gcu.md)
- [海光 DCU](./docs/zh/get_started/installation/hygon_dcu.md)
- [沐曦 GPU](./docs/zh/get_started/installation/metax_gpu.md)
**注意:** 我们正在积极拓展硬件支持范围。目前包括昇腾AscendNPU 等其他硬件平台正在开发测试中。敬请关注更新!
## 入门指南
通过我们的文档了解如何使用 FastDeploy
- [10分钟快速部署](./docs/zh/get_started/quick_start.md)
- [ERNIE-4.5 部署](./docs/zh/get_started/ernie-4.5.md)
- [ERNIE-4.5-VL 部署](./docs/zh/get_started/ernie-4.5-vl.md)
- [离线推理](./docs/zh/offline_inference.md)
- [在线服务](./docs/zh/online_serving/README.md)
- [最佳实践](./docs/zh/best_practices/README.md)
## 支持模型列表
通过我们的文档了解如何下载模型如何支持torch格式等
- [模型支持列表](./docs/zh/supported_models.md)
## 进阶用法
- [量化](./docs/zh/quantization/README.md)
- [分离式部署](./docs/zh/features/disaggregated.md)
- [投机解码](./docs/zh/features/speculative_decoding.md)
- [前缀缓存](./docs/zh/features/prefix_caching.md)
- [分块预填充](./docs/zh/features/chunked_prefill.md)
## 致谢
FastDeploy 依据 [Apache-2.0 开源许可证](./LICENSE). 进行授权。在开发过程中,我们参考并借鉴了 [vLLM](https://github.com/vllm-project/vllm) 的部分代码,以保持接口兼容性,在此表示衷心感谢。

View File

@@ -361,7 +361,8 @@ async def benchmark(
if not test_output.success:
raise ValueError(
f"Initial test run failed - Please make sure that 1. benchmark arguments are correctly specified and 2. the http_proxy and https_proxy are turned off. Error: {test_output.error}"
"Initial test run failed - Please make sure benchmark arguments "
f"are correctly specified. Error: {test_output.error}"
)
else:
print("Initial test run completed. Starting main benchmark run...")

View File

@@ -6,4 +6,3 @@ tensor_parallel_size: 8
max_num_batched_tokens: 4096
max_num_partial_prefills: 3
max_long_partial_prefills: 3
quantization: wint4

View File

@@ -1,6 +0,0 @@
tensor_parallel_size: 1
max_model_len: 131072
max_num_seqs: 32
quantization: wint4
max_num_batched_tokens: 8192
plas_attention_config: '{"plas_encoder_top_k_left": 50, "plas_encoder_top_k_right": 60, "plas_decoder_top_k_left": 100, "plas_decoder_top_k_right": 120}'

View File

@@ -6,4 +6,3 @@ tensor_parallel_size: 8
max_num_batched_tokens: 4096
max_num_partial_prefills: 3
max_long_partial_prefills: 3
quantization: wint8

View File

@@ -1,5 +0,0 @@
max_model_len: 32768
max_num_seqs: 256
kv_cache_ratio: 0.75
tensor_parallel_size: 4
gpu_memory_utilization: 0.9

View File

@@ -13,4 +13,3 @@ pd_comm_port: "2334"
max_num_batched_tokens: 384
max_num_partial_prefills: 3
max_long_partial_prefills: 3
quantization: wint4

View File

@@ -10,4 +10,3 @@ engine_worker_queue_port: 6677
cache_transfer_protocol: "rdma,ipc"
rdma_comm_ports: "7675,7676,7677,7678"
pd_comm_port: "2333"
quantization: wint4

View File

@@ -1,6 +0,0 @@
num_gpu_blocks_override: 1024
max_model_len: 8192
max_num_seqs: 64
data_parallel_size: 8
tensor_parallel_size: 1
enable_expert_parallel: True

View File

@@ -1,11 +0,0 @@
enable_mm: True
max_model_len: 131072
max_num_seqs: 56
gpu_memory_utilization: 0.8
kv_cache_ratio: 0.8
tensor_parallel_size: 8
quantization: wint4
limit_mm_per_prompt: '{"image": 100, "video": 100}'
enable_chunked_prefill: True
max_num_batched_tokens: 384
reasoning_parser: ernie-45-vl

View File

@@ -1,7 +1,7 @@
enable_mm: True
max_model_len: 32768
max_num_seqs: 36
gpu_memory_utilization: 0.9
gpu_memory_utilization: 0.95
kv_cache_ratio: 0.8
tensor_parallel_size: 8
quantization: wint8

View File

@@ -1,7 +1,7 @@
enable_mm: True
max_model_len: 32768
max_num_seqs: 36
gpu_memory_utilization: 0.85
gpu_memory_utilization: 0.8
kv_cache_ratio: 0.8
tensor_parallel_size: 8
quantization: wint8

View File

@@ -1,9 +0,0 @@
enable_mm: True
max_model_len: 32768
max_num_seqs: 128
gpu_memory_utilization: 0.9
kv_cache_ratio: 0.71
tensor_parallel_size: 1
enable_chunked_prefill: True
max_num_batched_tokens: 384
reasoning_parser: ernie-45-vl

View File

@@ -1,10 +0,0 @@
enable_mm: True
max_model_len: 32768
max_num_seqs: 128
gpu_memory_utilization: 0.9
kv_cache_ratio: 0.71
tensor_parallel_size: 1
enable_chunked_prefill: True
max_num_batched_tokens: 384
quantization: wint4
reasoning_parser: ernie-45-vl

View File

@@ -1,10 +0,0 @@
enable_mm: True
max_model_len: 32768
max_num_seqs: 128
gpu_memory_utilization: 0.9
kv_cache_ratio: 0.71
tensor_parallel_size: 1
enable_chunked_prefill: True
max_num_batched_tokens: 384
quantization: wint8
reasoning_parser: ernie-45-vl

View File

@@ -1 +0,0 @@
max_tokens: 131071

View File

@@ -1 +0,0 @@
max_tokens: 12288

View File

@@ -1,8 +0,0 @@
top_p: 0.95
temperature: 0.6
metadata:
min_tokens: 1
max_tokens: 131071
repetition_penalty: 1.0
frequency_penalty: 0
presence_penalty: 0

View File

@@ -1,10 +0,0 @@
reasoning-parser: ernie_x1
tool_call_parser: ernie_x1
tensor_parallel_size: 4
max_model_len: 65536
max_num_seqs: 128
enable_prefix_caching: True
enable_chunked_prefill: True
gpu_memory_utilization: 0.85
use_cudagraph: True
enable_custom_all_reduce: True

View File

@@ -1,6 +0,0 @@
tensor_parallel_size: 1
max_model_len: 131072
max_num_seqs: 32
reasoning_parser: ernie_x1
tool_call_parser: ernie_x1
load_choices: "default_v1"

View File

@@ -34,6 +34,7 @@ EGG_DIR="fastdeploy.egg-info"
# custom_ops directory config
OPS_SRC_DIR="custom_ops"
OPS_TMP_DIR_BASE="tmp_base"
OPS_TMP_DIR="tmp"
# command line log config
@@ -70,20 +71,25 @@ function copy_ops(){
PY_VERSION="py${PY_MAIN_VERSION}.${PY_SUB_VERSION}"
SYSTEM_VERSION=`${python} -c "import platform; print(platform.system().lower())"`
PROCESSOR_VERSION=`${python} -c "import platform; print(platform.processor())"`
WHEEL_BASE_NAME="fastdeploy_base_ops-${OPS_VERSION}-${PY_VERSION}-${SYSTEM_VERSION}-${PROCESSOR_VERSION}.egg"
WHEEL_NAME="fastdeploy_ops-${OPS_VERSION}-${PY_VERSION}-${SYSTEM_VERSION}-${PROCESSOR_VERSION}.egg"
WHEEL_CPU_NAME="fastdeploy_cpu_ops-${OPS_VERSION}-${PY_VERSION}-${SYSTEM_VERSION}-${PROCESSOR_VERSION}.egg"
is_rocm=`$python -c "import paddle; print(paddle.is_compiled_with_rocm())"`
if [ "$is_rocm" = "True" ]; then
DEVICE_TYPE="rocm"
mkdir -p ../fastdeploy/model_executor/ops/base
cp -r ./${OPS_TMP_DIR_BASE}/${WHEEL_BASE_NAME}/* ../fastdeploy/model_executor/ops/base
cp -r ./${OPS_TMP_DIR}/${WHEEL_NAME}/* ../fastdeploy/model_executor/ops/gpu
echo -e "ROCM ops have been copy to fastdeploy"
echo -e "BASE and ROCM ops have been copy to fastdeploy"
return
fi
mkdir -p ../fastdeploy/model_executor/ops/base
is_cuda=`$python -c "import paddle; print(paddle.is_compiled_with_cuda())"`
if [ "$is_cuda" = "True" ]; then
DEVICE_TYPE="gpu"
cp -r ./${OPS_TMP_DIR_BASE}/${WHEEL_BASE_NAME}/* ../fastdeploy/model_executor/ops/base
cp -r ./${OPS_TMP_DIR}/${WHEEL_NAME}/* ../fastdeploy/model_executor/ops/gpu
echo -e "CUDA ops have been copy to fastdeploy"
echo -e "BASE and CUDA ops have been copy to fastdeploy"
return
fi
@@ -106,8 +112,9 @@ function copy_ops(){
if_corex=`$python -c "import paddle; print(paddle.is_compiled_with_custom_device(\"iluvatar_gpu\"))"`
if [ "$if_corex" = "True" ]; then
DEVICE_TYPE="iluvatar-gpu"
cp -r ./${OPS_TMP_DIR_BASE}/${WHEEL_BASE_NAME}/* ../fastdeploy/model_executor/ops/base
cp -r ./${OPS_TMP_DIR}/${WHEEL_NAME}/* ../fastdeploy/model_executor/ops/iluvatar
echo -e "Iluvatar ops have been copy to fastdeploy"
echo -e "BASE and Iluvatar ops have been copy to fastdeploy"
return
fi
@@ -119,33 +126,27 @@ function copy_ops(){
return
fi
is_maca=`$python -c "import paddle; print(paddle.device.is_compiled_with_custom_device('metax_gpu'))"`
if [ "$is_maca" = "True" ]; then
DEVICE_TYPE="metax_gpu"
mkdir -p ../fastdeploy/model_executor/ops/base
cp -r ./${OPS_TMP_DIR_BASE}/${WHEEL_BASE_NAME}/* ../fastdeploy/model_executor/ops/base
cp -r ./${OPS_TMP_DIR}/${WHEEL_NAME}/* ../fastdeploy/model_executor/ops/gpu
echo -e "MACA ops have been copy to fastdeploy"
return
fi
DEVICE_TYPE="cpu"
cp -r ./${OPS_TMP_DIR_BASE}/${WHEEL_BASE_NAME}/* ../fastdeploy/model_executor/ops/base
cd ../../../../
cp -r ${OPS_TMP_DIR}/${WHEEL_CPU_NAME}/* ../fastdeploy/model_executor/ops/cpu
echo -e "CPU ops have been copy to fastdeploy"
echo -e "BASE and CPU ops have been copy to fastdeploy"
return
}
function build_and_install_ops() {
cd $OPS_SRC_DIR
export no_proxy=bcebos.com,paddlepaddle.org.cn,${no_proxy}
echo -e "${BLUE}[build]${NONE} build and install fastdeploy_base_ops..."
${python} setup_ops_base.py install --install-lib ${OPS_TMP_DIR_BASE}
find ${OPS_TMP_DIR_BASE} -type f -name "*.o" -exec rm -f {} \;
echo -e "${BLUE}[build]${NONE} build and install fastdeploy_ops..."
TMP_DIR_REAL_PATH=`readlink -f ${OPS_TMP_DIR}`
is_xpu=`$python -c "import paddle; print(paddle.is_compiled_with_xpu())"`
if [ "$is_xpu" = "True" ]; then
cd xpu_ops
cd xpu_ops/src
bash build.sh ${TMP_DIR_REAL_PATH}
cd ..
cd ../..
elif [ "$FD_CPU_USE_BF16" == "true" ]; then
if [ "$FD_BUILDING_ARCS" == "" ]; then
FD_CPU_USE_BF16=True ${python} setup_ops.py install --install-lib ${OPS_TMP_DIR}
@@ -212,6 +213,7 @@ function cleanup() {
fi
rm -rf $OPS_SRC_DIR/$BUILD_DIR $OPS_SRC_DIR/$EGG_DIR
rm -rf $OPS_SRC_DIR/$OPS_TMP_DIR_BASE
rm -rf $OPS_SRC_DIR/$OPS_TMP_DIR
}

View File

@@ -84,6 +84,7 @@ std::vector<paddle::Tensor> GetPaddingOffset(const paddle::Tensor &input_ids,
seq_length,
bsz);
return {x_remove_padding,
cum_offsets_out,
padding_offset,
cu_seqlens_q,
cu_seqlens_k};
@@ -96,7 +97,7 @@ std::vector<std::vector<int64_t>> GetPaddingOffsetInferShape(
const std::vector<int64_t> &seq_len_shape) {
int64_t bsz = seq_len_shape[0];
int64_t seq_len = input_ids_shape[1];
return {{-1}, {-1}, {bsz + 1}, {bsz + 1}};
return {{-1}, {bsz}, {-1}, {bsz + 1}, {bsz + 1}};
}
std::vector<paddle::DataType> GetPaddingOffsetInferDtype(
@@ -105,6 +106,7 @@ std::vector<paddle::DataType> GetPaddingOffsetInferDtype(
const paddle::DataType &token_num_dtype,
const paddle::DataType &seq_len_dtype) {
return {input_ids_dtype,
seq_len_dtype,
seq_len_dtype,
seq_len_dtype,
seq_len_dtype};
@@ -113,6 +115,7 @@ std::vector<paddle::DataType> GetPaddingOffsetInferDtype(
PD_BUILD_STATIC_OP(get_padding_offset_cpu)
.Inputs({"input_ids", "cum_offsets", "token_num", "seq_len"})
.Outputs({"x_remove_padding",
"cum_offsets_out",
"padding_offset",
"cu_seqlens_q",
"cu_seqlens_k"})

View File

@@ -1,4 +1,4 @@
// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
@@ -19,11 +19,10 @@
#define PD_BUILD_STATIC_OP(name) PD_BUILD_OP(static_op_##name)
#endif
template <typename T>
void RebuildPaddingCPUImpl(T *output_data,
const T *input_data,
const int *cu_seqlens_q_data,
const int *cum_offsets_data,
const int *seq_len_this_time_data,
const int *seq_lens_decoder_data,
const int *seq_lens_encoder_data,
@@ -41,12 +40,11 @@ void RebuildPaddingCPUImpl(T *output_data,
if (seq_lens_decoder_data[bi] == 0 && seq_lens_encoder_data[bi] == 0) {
continue;
}
if (seq_lens_encoder_data[bi] > 0) {
seq_id = seq_lens_encoder_data[bi] - 1;
}
const int ori_token_idx = cu_seqlens_q_data[bi] + seq_id;
const int ori_token_idx =
bi * max_input_length - cum_offsets_data[bi] + seq_id;
const int src_offset = ori_token_idx * dim_embed + bias_idx;
output_data[i] = input_data[src_offset];
@@ -56,7 +54,7 @@ void RebuildPaddingCPUImpl(T *output_data,
template <typename T>
void RebuildAppendPaddingCPUImpl(T *output_data,
const T *input_data,
const int *cu_seqlens_q_data,
const int *cum_offsets_data,
const int *seq_len_this_time_data,
const int *seq_lens_decoder_data,
const int *seq_lens_encoder_data,
@@ -71,32 +69,30 @@ void RebuildAppendPaddingCPUImpl(T *output_data,
int bi = ori_token_id / max_input_length;
if (seq_len_this_time_data[bi] == 0 ||
(seq_lens_decoder_data[bi] == 0 &&
seq_lens_encoder_data[bi] == 0)) {
continue;
}
seq_lens_encoder_data[bi] == 0)) {
continue;
}
int seq_id = 0;
if (seq_lens_encoder_data[bi] > 0) {
seq_id = seq_lens_encoder_data[bi] - 1;
}
int input_token_id = cu_seqlens_q_data[bi] + seq_id;
int input_token_id = ori_token_id - cum_offsets_data[bi] + seq_id;
int bias_idx = i % dim_embed;
int src_offset = input_token_id * dim_embed + bias_idx;
output_data[i] = input_data[src_offset];
}
}
std::vector<paddle::Tensor> RebuildPaddingCPU(
const paddle::Tensor &tmp_out,
const paddle::Tensor &cu_seqlens_q,
const paddle::Tensor &cum_offsets,
const paddle::Tensor &seq_len_this_time,
const paddle::Tensor &seq_lens_decoder,
const paddle::Tensor &seq_lens_encoder,
const paddle::optional<paddle::Tensor> &output_padding_offset,
int max_input_length) {
auto tmp_out_cpu = tmp_out.copy_to(paddle::CPUPlace(), true);
auto cu_seqlens_q_cpu = cu_seqlens_q.copy_to(paddle::CPUPlace(), true);
auto cum_offsets_cpu = cum_offsets.copy_to(paddle::CPUPlace(), true);
auto seq_len_this_time_cpu =
seq_len_this_time.copy_to(paddle::CPUPlace(), true);
auto seq_lens_decoder_cpu =
@@ -111,7 +107,7 @@ std::vector<paddle::Tensor> RebuildPaddingCPU(
int token_num = tmp_out_cpu.shape()[0];
int dim_embed = tmp_out_cpu.shape()[1];
int bsz = cu_seqlens_q_cpu.shape()[0] - 1;
int bsz = cum_offsets_cpu.shape()[0];
paddle::Tensor out;
if (output_padding_offset_cpu) {
@@ -132,7 +128,7 @@ std::vector<paddle::Tensor> RebuildPaddingCPU(
{bsz, dim_embed}, 0, tmp_out_cpu.dtype(), paddle::CPUPlace());
}
const int *cu_seqlens_q_data = cu_seqlens_q_cpu.data<int>();
const int *cum_offsets_data = cum_offsets_cpu.data<int>();
const int *seq_len_this_time_data = seq_len_this_time_cpu.data<int>();
const int *seq_lens_decoder_data = seq_lens_decoder_cpu.data<int>();
const int *seq_lens_encoder_data = seq_lens_encoder_cpu.data<int>();
@@ -145,7 +141,7 @@ std::vector<paddle::Tensor> RebuildPaddingCPU(
case paddle::DataType::FLOAT32:
RebuildAppendPaddingCPUImpl<float>(out.data<float>(),
tmp_out_cpu.data<float>(),
cu_seqlens_q_data,
cum_offsets_data,
seq_len_this_time_data,
seq_lens_decoder_data,
seq_lens_encoder_data,
@@ -158,7 +154,7 @@ std::vector<paddle::Tensor> RebuildPaddingCPU(
RebuildAppendPaddingCPUImpl<paddle::float16>(
out.data<paddle::float16>(),
tmp_out_cpu.data<paddle::float16>(),
cu_seqlens_q_data,
cum_offsets_data,
seq_len_this_time_data,
seq_lens_decoder_data,
seq_lens_encoder_data,
@@ -171,7 +167,7 @@ std::vector<paddle::Tensor> RebuildPaddingCPU(
RebuildAppendPaddingCPUImpl<paddle::bfloat16>(
out.data<paddle::bfloat16>(),
tmp_out_cpu.data<paddle::bfloat16>(),
cu_seqlens_q_data,
cum_offsets_data,
seq_len_this_time_data,
seq_lens_decoder_data,
seq_lens_encoder_data,
@@ -190,7 +186,7 @@ std::vector<paddle::Tensor> RebuildPaddingCPU(
case paddle::DataType::FLOAT32:
RebuildPaddingCPUImpl<float>(out.data<float>(),
tmp_out_cpu.data<float>(),
cu_seqlens_q_data,
cum_offsets_data,
seq_len_this_time_data,
seq_lens_decoder_data,
seq_lens_encoder_data,
@@ -202,7 +198,7 @@ std::vector<paddle::Tensor> RebuildPaddingCPU(
RebuildPaddingCPUImpl<paddle::float16>(
out.data<paddle::float16>(),
tmp_out_cpu.data<paddle::float16>(),
cu_seqlens_q_data,
cum_offsets_data,
seq_len_this_time_data,
seq_lens_decoder_data,
seq_lens_encoder_data,
@@ -211,10 +207,11 @@ std::vector<paddle::Tensor> RebuildPaddingCPU(
elem_nums);
break;
case paddle::DataType::BFLOAT16:
RebuildPaddingCPUImpl<paddle::bfloat16>(
out.data<paddle::bfloat16>(),
tmp_out_cpu.data<paddle::bfloat16>(),
cu_seqlens_q_data,
cum_offsets_data,
seq_len_this_time_data,
seq_lens_decoder_data,
seq_lens_encoder_data,
@@ -233,7 +230,7 @@ std::vector<paddle::Tensor> RebuildPaddingCPU(
std::vector<std::vector<int64_t>> RebuildPaddingInferShape(
const std::vector<int64_t> &tmp_out_shape,
const std::vector<int64_t> &cu_seqlens_q_shape,
const std::vector<int64_t> &cum_offsets_shape,
const std::vector<int64_t> &seq_len_this_time_shape,
const std::vector<int64_t> &seq_lens_decoder_shape,
const std::vector<int64_t> &seq_lens_encoder_shape,
@@ -242,14 +239,14 @@ std::vector<std::vector<int64_t>> RebuildPaddingInferShape(
if (output_padding_offset_shape) {
return {{-1, dim_embed}};
} else {
int64_t bsz = cu_seqlens_q_shape[0] - 1;
int64_t bsz = cum_offsets_shape[0];
return {{bsz, dim_embed}};
}
}
std::vector<paddle::DataType> RebuildPaddingInferDtype(
const paddle::DataType &tmp_out_dtype,
const paddle::DataType &cu_seqlens_q_dtype,
const paddle::DataType &cum_offsets_dtype,
const paddle::DataType &seq_len_this_time_dtype,
const paddle::DataType &seq_lens_decoder_dtype,
const paddle::DataType &seq_lens_encoder_dtype,
@@ -259,7 +256,7 @@ std::vector<paddle::DataType> RebuildPaddingInferDtype(
PD_BUILD_STATIC_OP(rebuild_padding_cpu)
.Inputs({"tmp_out",
"cu_seqlens_q",
"cum_offsets",
"seq_len_this_time",
"seq_lens_decoder",
"seq_lens_encoder",

View File

@@ -14,7 +14,7 @@
#include "paddle/extension.h"
void set_value_by_flags_and_idx(const bool *stop_flags,
void set_value_by_flag_and_id(const bool *stop_flags,
int64_t *pre_ids_all,
const int64_t *input_ids,
const int *seq_lens_encoder,
@@ -50,7 +50,7 @@ void SetValueByFlagsAndIdx(const paddle::Tensor &pre_ids_all,
int length = pre_ids_all_shape[1];
int length_input_ids = input_ids.shape()[1];
set_value_by_flags_and_idx(stop_flags.data<bool>(),
set_value_by_flag_and_id(stop_flags.data<bool>(),
const_cast<int64_t *>(pre_ids_all.data<int64_t>()),
input_ids.data<int64_t>(),
seq_lens_encoder.data<int>(),

View File

@@ -46,7 +46,7 @@ void update_inputs_kernel(bool *not_need_stop,
not_need_stop[0] = stop_sum < stop_nums[0];
}
void UpdateInputs(const paddle::Tensor &stop_flags,
void UpdateInputes(const paddle::Tensor &stop_flags,
const paddle::Tensor &not_need_stop,
const paddle::Tensor &seq_lens_this_time,
const paddle::Tensor &seq_lens_encoder,
@@ -90,4 +90,4 @@ PD_BUILD_STATIC_OP(update_inputs_cpu)
{"seq_lens_encoder", "seq_lens_encoder_out"},
{"seq_lens_decoder", "seq_lens_decoder_out"},
{"input_ids", "input_ids_out"}})
.SetKernelFn(PD_KERNEL(UpdateInputs));
.SetKernelFn(PD_KERNEL(UpdateInputes));

View File

@@ -38,7 +38,7 @@ class type2value<phi::dtype::float16> {
template <paddle::DataType D>
void AppendAttentionKernel(
std::vector<paddle::Tensor> AppendAttentionKernel(
const AppendAttnMetaData& meta_data,
const paddle::Tensor& qkv,
const paddle::Tensor& key_cache,
@@ -60,7 +60,6 @@ void AppendAttentionKernel(
const paddle::Tensor& decoder_num_blocks,
const paddle::Tensor& set_max_lengths,
const paddle::Tensor& max_len_kv,
paddle::Tensor& fmha_out,
const paddle::optional<paddle::Tensor>& rotary_embs,
const paddle::optional<paddle::Tensor>& attn_mask,
const paddle::optional<paddle::Tensor>& qkv_bias,
@@ -73,11 +72,7 @@ void AppendAttentionKernel(
const paddle::optional<paddle::Tensor>& cache_v_zp,
const paddle::optional<paddle::Tensor>& out_linear_shifts,
const paddle::optional<paddle::Tensor>& out_linear_smooths,
const paddle::optional<paddle::Tensor>& mask_offset,
const paddle::optional<paddle::Tensor>& kv_signal_data,
const paddle::optional<paddle::Tensor>& q_norm_weight,
const paddle::optional<paddle::Tensor>& k_norm_weight,
const float rms_norm_eps,
const std::string& cache_quant_type_str,
const bool use_neox_rotary_style,
const bool rope_3d,
@@ -123,6 +118,27 @@ void AppendAttentionKernel(
} else {
qkv_out = qkv;
}
paddle::Tensor fmha_out;
if (out_linear_in_scale > 0.0) {
if (fabs(quant_max_bound - 127.0f) < 0.000001) {
fmha_out = GetEmptyTensor(
{meta_data.token_nums, meta_data.q_num_heads * meta_data.head_dims},
paddle::DataType::INT8,
qkv.place());
} else if (fabs(quant_max_bound - 448.0f) < 0.000001) {
fmha_out = GetEmptyTensor(
{meta_data.token_nums, meta_data.q_num_heads * meta_data.head_dims},
paddle::DataType::FLOAT8_E4M3FN,
qkv.place());
}else{
PD_THROW("Only supported attr of quant_max_bound in ['127', '448'].");
}
} else {
fmha_out = GetEmptyTensor(
{meta_data.token_nums, meta_data.q_num_heads * meta_data.head_dims},
D,
qkv.place());
}
auto dispatch_CascadeAppendAttentionKernel = [&](auto temp_args,
const paddle::Tensor& lambda_batch_ids,
@@ -140,8 +156,8 @@ void AppendAttentionKernel(
key_cache,
value_cache,
attn_mask,
cache_quant_type_str == "block_wise_fp8" ? cache_k_quant_scales : cache_k_dequant_scales,
cache_quant_type_str == "block_wise_fp8" ? cache_v_quant_scales : cache_v_dequant_scales,
cache_k_dequant_scales,
cache_v_dequant_scales,
cache_k_zp,
cache_v_zp,
out_linear_shifts,
@@ -207,10 +223,7 @@ void AppendAttentionKernel(
main_stream,
&qkv_out,
const_cast<paddle::Tensor*>(&key_cache),
const_cast<paddle::Tensor*>(&value_cache),
q_norm_weight,
k_norm_weight,
rms_norm_eps);
const_cast<paddle::Tensor*>(&value_cache));
};
if (qkv_out_scales) {
@@ -257,6 +270,54 @@ void AppendAttentionKernel(
if (speculate_decoder) {
if (qkv_out_scales) {
SpeculateWriteCacheWithRoPEKernel<data_t, int>(
meta_data,
qkv, // [token_num, num_heads, head_dim]
seq_lens_decoder,
seq_lens_encoder,
batch_id_per_token,
cu_seqlens_q,
block_tables,
rotary_embs,
qkv_out_scales,
qkv_bias,
cache_k_quant_scales,
cache_v_quant_scales,
cache_k_zp,
cache_v_zp,
cache_quant_type_str,
use_neox_rotary_style,
max_input_length,
exec_stream,
&qkv_out,
const_cast<paddle::Tensor*>(&key_cache),
const_cast<paddle::Tensor*>(&value_cache));
} else {
SpeculateWriteCacheWithRoPEKernel<data_t, data_t>(
meta_data,
qkv_out, // [token_num, num_heads, head_dim]
seq_lens_decoder,
seq_lens_encoder,
batch_id_per_token,
cu_seqlens_q,
block_tables,
rotary_embs,
qkv_out_scales,
qkv_bias,
cache_k_quant_scales,
cache_v_quant_scales,
cache_k_zp,
cache_v_zp,
cache_quant_type_str,
use_neox_rotary_style,
max_input_length,
exec_stream,
&qkv_out,
const_cast<paddle::Tensor*>(&key_cache),
const_cast<paddle::Tensor*>(&value_cache));
}
} else {
if (qkv_out_scales) {
DecoderWriteCacheWithRoPEKernel<data_t, int>(
meta_data,
qkv, // [token_num, num_heads, head_dim]
seq_lens_decoder,
@@ -278,12 +339,9 @@ void AppendAttentionKernel(
exec_stream,
&qkv_out,
const_cast<paddle::Tensor*>(&key_cache),
const_cast<paddle::Tensor*>(&value_cache),
q_norm_weight,
k_norm_weight,
rms_norm_eps);
const_cast<paddle::Tensor*>(&value_cache));
} else {
SpeculateWriteCacheWithRoPEKernel<data_t, data_t>(
DecoderWriteCacheWithRoPEKernel<data_t, data_t>(
meta_data,
qkv_out, // [token_num, num_heads, head_dim]
seq_lens_decoder,
@@ -305,64 +363,7 @@ void AppendAttentionKernel(
exec_stream,
&qkv_out,
const_cast<paddle::Tensor*>(&key_cache),
const_cast<paddle::Tensor*>(&value_cache),
q_norm_weight,
k_norm_weight,
rms_norm_eps);
}
} else {
if (qkv_out_scales) {
DecoderWriteCacheWithRoPEKernel<data_t, int>(
meta_data,
qkv, // [token_num, num_heads, head_dim]
seq_lens_decoder,
seq_lens_encoder,
cu_seqlens_q,
block_tables,
rotary_embs,
qkv_out_scales,
qkv_bias,
cache_k_quant_scales,
cache_v_quant_scales,
cache_k_zp,
cache_v_zp,
cache_quant_type_str,
use_neox_rotary_style,
rope_3d,
max_input_length,
exec_stream,
&qkv_out,
const_cast<paddle::Tensor*>(&key_cache),
const_cast<paddle::Tensor*>(&value_cache),
q_norm_weight,
k_norm_weight,
rms_norm_eps);
} else {
DecoderWriteCacheWithRoPEKernel<data_t, data_t>(
meta_data,
qkv_out, // [token_num, num_heads, head_dim]
seq_lens_decoder,
seq_lens_encoder,
cu_seqlens_q,
block_tables,
rotary_embs,
qkv_out_scales,
qkv_bias,
cache_k_quant_scales,
cache_v_quant_scales,
cache_k_zp,
cache_v_zp,
cache_quant_type_str,
use_neox_rotary_style,
rope_3d,
max_input_length,
exec_stream,
&qkv_out,
const_cast<paddle::Tensor*>(&key_cache),
const_cast<paddle::Tensor*>(&value_cache),
q_norm_weight,
k_norm_weight,
rms_norm_eps);
const_cast<paddle::Tensor*>(&value_cache));
}
}
@@ -391,6 +392,8 @@ void AppendAttentionKernel(
cudaStreamWaitEvent(main_stream, decoder_event);
}
}
return {fmha_out, qkv_out};
}
std::vector<paddle::Tensor> AppendAttention(
@@ -426,11 +429,7 @@ std::vector<paddle::Tensor> AppendAttention(
const paddle::optional<paddle::Tensor>& cache_v_zp,
const paddle::optional<paddle::Tensor>& out_linear_shifts,
const paddle::optional<paddle::Tensor>& out_linear_smooths,
const paddle::optional<paddle::Tensor>& mask_offset,
const paddle::optional<paddle::Tensor>& kv_signal_data,
const paddle::optional<paddle::Tensor>& q_norm_weight,
const paddle::optional<paddle::Tensor>& k_norm_weight,
const float rms_norm_eps,
const std::string& compute_dtype,
const std::string& cache_quant_type_str,
const bool use_neox_rotary_style,
@@ -465,60 +464,8 @@ std::vector<paddle::Tensor> AppendAttention(
meta_data.block_size = key_cache.dims()[2];
meta_data.batch_size = seq_lens_this_time.dims()[0];
// template dtype generation
phi::DataType dtype_id;
switch (qkv.dtype()) {
case paddle::DataType::FLOAT16: {dtype_id = phi::DataType::FLOAT16; break;}
case paddle::DataType::BFLOAT16: {dtype_id = phi::DataType::BFLOAT16; break;}
case paddle::DataType::INT32: {
if (compute_dtype == "bf16") {
dtype_id = phi::DataType::BFLOAT16;
break;
} else if (compute_dtype == "fp16") {
dtype_id = phi::DataType::FLOAT16;
break;
} else {
PD_THROW("Only supported attr of compute_dtype in ['fp16', 'bf16'].");
break;
}
}
default: {
PD_THROW(
"NOT supported data type. "
"Only float16 and bfloat16 are supported. ");
break;
}
}
// fmha_out generation, rewrite from AppendAttentionKernel
paddle::Tensor fmha_out;
if (out_linear_in_scale > 0.0) {
if (fabs(quant_max_bound - 127.0f) < 0.000001) {
fmha_out = GetEmptyTensor(
{meta_data.token_nums, meta_data.q_num_heads * meta_data.head_dims},
paddle::DataType::INT8,
qkv.place());
} else if (fabs(quant_max_bound - 448.0f) < 0.000001) {
fmha_out = GetEmptyTensor(
{meta_data.token_nums, meta_data.q_num_heads * meta_data.head_dims},
paddle::DataType::FLOAT8_E4M3FN,
qkv.place());
} else{
PD_THROW("Only supported attr of quant_max_bound in ['127', '448'].");
}
} else {
fmha_out = GetEmptyTensor(
{meta_data.token_nums, meta_data.q_num_heads * meta_data.head_dims},
dtype_id,
qkv.place());
}
if (mask_offset) {
meta_data.mask_offset = mask_offset.get().data<int>();
}
auto dispatch_by_template = [&](auto temp_args) -> void {
AppendAttentionKernel<type2value<decltype(temp_args)>::value>(
auto dispatch_by_template = [&](auto temp_args) -> std::vector<paddle::Tensor> {
return AppendAttentionKernel<type2value<decltype(temp_args)>::value>(
meta_data,
qkv,
key_cache,
@@ -540,7 +487,6 @@ std::vector<paddle::Tensor> AppendAttention(
decoder_num_blocks,
set_max_lengths,
max_len_kv,
fmha_out,
rotary_embs,
attn_mask,
qkv_bias,
@@ -553,11 +499,7 @@ std::vector<paddle::Tensor> AppendAttention(
cache_v_zp,
out_linear_shifts,
out_linear_smooths,
mask_offset,
kv_signal_data,
q_norm_weight,
k_norm_weight,
rms_norm_eps,
cache_quant_type_str,
use_neox_rotary_style,
rope_3d,
@@ -572,198 +514,35 @@ std::vector<paddle::Tensor> AppendAttention(
speculate_max_draft_token_num,
causal,
speculate_decoder);
};
};
phi::dtype::float16 fp16_dtype;
phi::dtype::bfloat16 bp16_dtype;
switch (dtype_id){
case phi::DataType::FLOAT16: {
dispatch_by_template(fp16_dtype);
return {fmha_out};
}
case phi::DataType::BFLOAT16: {
dispatch_by_template(bp16_dtype);
return {fmha_out};
}
default:
PD_THROW(
switch (qkv.dtype()) {
case paddle::DataType::FLOAT16: return dispatch_by_template(fp16_dtype);
case paddle::DataType::BFLOAT16: return dispatch_by_template(bp16_dtype);
case paddle::DataType::INT32: {
if (compute_dtype == "bf16") {
return dispatch_by_template(bp16_dtype);
} else if (compute_dtype == "fp16") {
return dispatch_by_template(fp16_dtype);
} else {
PD_THROW("Only supported attr of compute_dtype in ['fp16', 'bf16'].");
break;
}
}
default: {
PD_THROW(
"NOT supported data type. "
"Only float16 and bfloat16 are supported. ");
break;
break;
}
}
return {paddle::Tensor{}};
}
void AppendAttentionWithOutput(
const paddle::Tensor& qkv,
const paddle::Tensor& key_cache,
const paddle::Tensor& value_cache,
const paddle::Tensor& seq_lens_encoder,
const paddle::Tensor& seq_lens_decoder,
const paddle::Tensor& seq_lens_this_time,
const paddle::Tensor& batch_id_per_token,
const paddle::Tensor& cu_seqlens_q,
const paddle::Tensor& block_tables,
const paddle::Tensor& encoder_batch_ids,
const paddle::Tensor& encoder_tile_ids_per_batch,
const paddle::Tensor& encoder_num_blocks,
const paddle::Tensor& kv_batch_ids,
const paddle::Tensor& kv_tile_ids_per_batch,
const paddle::Tensor& kv_num_blocks,
const paddle::Tensor& decoder_batch_ids,
const paddle::Tensor& decoder_tile_ids_per_batch,
const paddle::Tensor& decoder_num_blocks,
const paddle::Tensor& set_max_lengths,
const paddle::Tensor& max_len_kv,
paddle::Tensor& fmha_out,
const paddle::optional<paddle::Tensor>& rotary_embs,
const paddle::optional<paddle::Tensor>& attn_mask,
const paddle::optional<paddle::Tensor>& qkv_bias,
const paddle::optional<paddle::Tensor>& qkv_out_scales,
const paddle::optional<paddle::Tensor>& cache_k_quant_scales,
const paddle::optional<paddle::Tensor>& cache_v_quant_scales,
const paddle::optional<paddle::Tensor>& cache_k_dequant_scales,
const paddle::optional<paddle::Tensor>& cache_v_dequant_scales,
const paddle::optional<paddle::Tensor>& cache_k_zp,
const paddle::optional<paddle::Tensor>& cache_v_zp,
const paddle::optional<paddle::Tensor>& out_linear_shifts,
const paddle::optional<paddle::Tensor>& out_linear_smooths,
const paddle::optional<paddle::Tensor>& mask_offset,
const paddle::optional<paddle::Tensor>& kv_signal_data,
const paddle::optional<paddle::Tensor>& q_norm_weight,
const paddle::optional<paddle::Tensor>& k_norm_weight,
const float rms_norm_eps,
const std::string& compute_dtype,
const std::string& cache_quant_type_str,
const bool use_neox_rotary_style,
const bool rope_3d,
const int max_input_length,
const float quant_max_bound,
const float quant_min_bound,
const float out_linear_in_scale,
const int encoder_block_shape_q,
const int decoder_block_shape_q,
const int max_partition_size,
const int encoder_max_partition_size,
const int speculate_max_draft_token_num,
const bool causal,
const bool speculate_decoder) {
AppendAttnMetaData meta_data;
const auto& qkv_dims = qkv.dims();
const auto& key_cache_dims = key_cache.dims();
meta_data.token_nums = qkv_dims[0];
meta_data.kv_num_heads = key_cache_dims[1];
meta_data.head_dims = key_cache_dims[3];
// TODO: trick method support c4, add attr head_dims in the future
if (cache_quant_type_str == "cache_int4_zp") {
meta_data.head_dims *= 2;
}
const int total_num_head =
qkv_dims[qkv_dims.size() - 1] / meta_data.head_dims;
meta_data.q_num_heads = total_num_head - 2 * meta_data.kv_num_heads;
meta_data.max_blocks_per_seq = block_tables.dims()[1];
meta_data.block_size = key_cache.dims()[2];
meta_data.batch_size = seq_lens_this_time.dims()[0];
if (mask_offset) {
meta_data.mask_offset = mask_offset.get().data<int>();
}
auto dispatch_by_template = [&](auto temp_args) -> void {
AppendAttentionKernel<type2value<decltype(temp_args)>::value>(
meta_data,
qkv,
key_cache,
value_cache,
seq_lens_encoder,
seq_lens_decoder,
seq_lens_this_time,
batch_id_per_token,
cu_seqlens_q,
block_tables,
encoder_batch_ids,
encoder_tile_ids_per_batch,
encoder_num_blocks,
kv_batch_ids,
kv_tile_ids_per_batch,
kv_num_blocks,
decoder_batch_ids,
decoder_tile_ids_per_batch,
decoder_num_blocks,
set_max_lengths,
max_len_kv,
fmha_out,
rotary_embs,
attn_mask,
qkv_bias,
qkv_out_scales,
cache_k_quant_scales,
cache_v_quant_scales,
cache_k_dequant_scales,
cache_v_dequant_scales,
cache_k_zp,
cache_v_zp,
out_linear_shifts,
out_linear_smooths,
mask_offset,
kv_signal_data,
q_norm_weight,
k_norm_weight,
rms_norm_eps,
cache_quant_type_str,
use_neox_rotary_style,
rope_3d,
max_input_length,
quant_max_bound,
quant_min_bound,
out_linear_in_scale,
encoder_block_shape_q,
decoder_block_shape_q,
max_partition_size,
encoder_max_partition_size,
speculate_max_draft_token_num,
causal,
speculate_decoder);
};
phi::dtype::float16 fp16_dtype;
phi::dtype::bfloat16 bp16_dtype;
switch (qkv.dtype()) {
case paddle::DataType::FLOAT16: {
dispatch_by_template(fp16_dtype);
break;
}
case paddle::DataType::BFLOAT16: {
dispatch_by_template(bp16_dtype);
break;
}
case paddle::DataType::INT32: {
if (compute_dtype == "bf16") {
dispatch_by_template(bp16_dtype);
break;
} else if (compute_dtype == "fp16") {
dispatch_by_template(fp16_dtype);
break;
} else {
PD_THROW("Only supported attr of compute_dtype in ['fp16', 'bf16'].");
break;
}
}
default: {
PD_THROW(
"NOT supported data type. "
"Only float16 and bfloat16 are supported. ");
break;
}
}
}
std::vector<std::vector<int64_t>> AppendAttentionInferShape(
const std::vector<int64_t>& qkv_shape,
const std::vector<int64_t>& key_cache_shape,
@@ -797,11 +576,7 @@ std::vector<std::vector<int64_t>> AppendAttentionInferShape(
const paddle::optional<std::vector<int64_t>>& cache_v_zp_shape,
const paddle::optional<std::vector<int64_t>>& out_linear_shifts_shape,
const paddle::optional<std::vector<int64_t>>& out_linear_smooths_shape,
const paddle::optional<std::vector<int64_t>>& mask_offset_shape,
const paddle::optional<std::vector<int64_t>>& kv_signal_data_shape,
const paddle::optional<std::vector<int64_t>>& q_norm_weight_shape,
const paddle::optional<std::vector<int64_t>>& k_norm_weight_shape,
const float rms_norm_eps,
const std::string& compute_dtype,
const std::string& cache_quant_type_str,
const bool use_neox_rotary_style,
@@ -825,7 +600,7 @@ std::vector<std::vector<int64_t>> AppendAttentionInferShape(
}
const int total_num_head = qkv_shape[qkv_shape.size() - 1] / head_dim;
const int num_heads = total_num_head - 2 * kv_num_heads;
return {{token_num, num_heads * head_dim}};
return {{token_num, num_heads * head_dim}, qkv_shape};
}
std::vector<paddle::DataType> AppendAttentionInferDtype(
@@ -861,11 +636,7 @@ std::vector<paddle::DataType> AppendAttentionInferDtype(
const paddle::optional<paddle::DataType>& cache_v_zp_dtype,
const paddle::optional<paddle::DataType>& out_linear_shifts_dtype,
const paddle::optional<paddle::DataType>& out_linear_smooths_dtype,
const paddle::optional<paddle::DataType>& mask_offset_dtype,
const paddle::optional<paddle::DataType>& kv_signal_data_dtype,
const paddle::optional<paddle::DataType>& q_norm_weight_dtype,
const paddle::optional<paddle::DataType>& k_norm_weight_dtype,
const float rms_norm_eps,
const std::string& compute_dtype,
const std::string& cache_quant_type_str,
const bool use_neox_rotary_style,
@@ -884,148 +655,32 @@ std::vector<paddle::DataType> AppendAttentionInferDtype(
if (compute_dtype == "bf16") {
if (out_linear_in_scale > 0.0) {
if (fabs(quant_max_bound - 127.0f) < 0.000001) {
return {paddle::DataType::INT8};
return {paddle::DataType::INT8, paddle::DataType::BFLOAT16};
} else if (fabs(quant_max_bound - 448.0f) < 0.000001) {
return {paddle::DataType::FLOAT8_E4M3FN};
return {paddle::DataType::FLOAT8_E4M3FN, paddle::DataType::BFLOAT16};
}else{
PD_THROW("Only supported attr of quant_max_bound in ['127.0', '448.0'].");
}
} else {
return {paddle::DataType::BFLOAT16};
return {paddle::DataType::BFLOAT16, paddle::DataType::BFLOAT16};
}
} else if (compute_dtype == "fp16") {
if (out_linear_in_scale > 0.0) {
if (fabs(quant_max_bound - 127.0f) < 0.000001) {
return {paddle::DataType::INT8};
return {paddle::DataType::INT8, paddle::DataType::FLOAT16};
} else if (fabs(quant_max_bound - 448.0f) < 0.000001) {
return {paddle::DataType::FLOAT8_E4M3FN};
return {paddle::DataType::FLOAT8_E4M3FN, paddle::DataType::FLOAT16};
}else{
PD_THROW("Only supported attr of quant_max_bound in ['127.0', '448.0'].");
}
} else {
return {paddle::DataType::FLOAT16};
return {paddle::DataType::FLOAT16, paddle::DataType::FLOAT16};
}
} else {
PD_THROW("Only supported attr of compute_dtype in ['fp16', 'bf16'].");
}
}
std::vector<std::vector<int64_t>> AppendAttentionWithOutputInferShape(
const std::vector<int64_t>& qkv_shape,
const std::vector<int64_t>& key_cache_shape,
const std::vector<int64_t>& value_cache_shape,
const std::vector<int64_t>& seq_lens_encoder_shape,
const std::vector<int64_t>& seq_lens_decoder_shape,
const std::vector<int64_t>& seq_lens_this_time_shape,
const std::vector<int64_t>& batch_id_per_token_shape,
const std::vector<int64_t>& cu_seqlens_q_shape,
const std::vector<int64_t>& block_tables_shape,
const std::vector<int64_t>& encoder_batch_ids_shape,
const std::vector<int64_t>& encoder_tile_ids_per_batch_shape,
const std::vector<int64_t>& encoder_num_blocks_shape,
const std::vector<int64_t>& kv_batch_ids_shape,
const std::vector<int64_t>& kv_tile_ids_per_batch_shape,
const std::vector<int64_t>& kv_num_blocks_shape,
const std::vector<int64_t>& decoder_batch_ids_shape,
const std::vector<int64_t>& decoder_tile_ids_per_batch_shape,
const std::vector<int64_t>& decoder_num_blocks_shape,
const std::vector<int64_t>& set_max_lengths_shape,
const std::vector<int64_t>& max_len_kv_shape,
const std::vector<int64_t>& fmha_out_shape,
const paddle::optional<std::vector<int64_t>>& rotary_embs_shape,
const paddle::optional<std::vector<int64_t>>& attn_mask_shape,
const paddle::optional<std::vector<int64_t>>& qkv_bias_shape,
const paddle::optional<std::vector<int64_t>>& qkv_out_scales_shape,
const paddle::optional<std::vector<int64_t>>& cache_k_quant_scales_shape,
const paddle::optional<std::vector<int64_t>>& cache_v_quant_scales_shape,
const paddle::optional<std::vector<int64_t>>& cache_k_dequant_scales_shape,
const paddle::optional<std::vector<int64_t>>& cache_v_dequant_scales_shape,
const paddle::optional<std::vector<int64_t>>& cache_k_zp_shape,
const paddle::optional<std::vector<int64_t>>& cache_v_zp_shape,
const paddle::optional<std::vector<int64_t>>& out_linear_shifts_shape,
const paddle::optional<std::vector<int64_t>>& out_linear_smooths_shape,
const paddle::optional<std::vector<int64_t>>& mask_offset_shape,
const paddle::optional<std::vector<int64_t>>& kv_signal_data_shape,
const paddle::optional<std::vector<int64_t>>& q_norm_weight_shape,
const paddle::optional<std::vector<int64_t>>& k_norm_weight_shape,
const float rms_norm_eps,
const std::string& compute_dtype,
const std::string& cache_quant_type_str,
const bool use_neox_rotary_style,
const bool rope_3d,
const int max_input_length,
const float quant_max_bound,
const float quant_min_bound,
const float out_linear_in_scale,
const int encoder_block_shape_q,
const int decoder_block_shape_q,
const int max_partition_size,
const int encoder_max_partition_size,
const int speculate_max_draft_token_num,
const bool causal,
const bool speculate_decoder) {
return {fmha_out_shape};
}
std::vector<paddle::DataType> AppendAttentionWithOutputInferDtype(
const paddle::DataType& qkv_dtype,
const paddle::DataType& key_cache_dtype,
const paddle::DataType& value_cache_dtype,
const paddle::DataType& seq_lens_encoder_dtype,
const paddle::DataType& seq_lens_decoder_dtype,
const paddle::DataType& seq_lens_this_time_dtype,
const paddle::DataType& batch_id_per_token_dtype,
const paddle::DataType& cu_seqlens_q_dtype,
const paddle::DataType& block_tables_dtype,
const paddle::DataType& encoder_batch_ids_dtype,
const paddle::DataType& encoder_tile_ids_per_batch_dtype,
const paddle::DataType& encoder_num_blocks_dtype,
const paddle::DataType& kv_batch_ids_dtype,
const paddle::DataType& kv_tile_ids_per_batch_dtype,
const paddle::DataType& kv_num_blocks_dtype,
const paddle::DataType& decoder_batch_ids_dtype,
const paddle::DataType& decoder_tile_ids_per_batch_dtype,
const paddle::DataType& decoder_num_blocks_dtype,
const paddle::DataType& set_max_lengths_dtype,
const paddle::DataType& max_len_kv_dtype,
const paddle::DataType& fmha_out_dtype,
const paddle::optional<paddle::DataType>& rotary_embs_dtype,
const paddle::optional<paddle::DataType>& attn_mask_dtype,
const paddle::optional<paddle::DataType>& qkv_bias_dtype,
const paddle::optional<paddle::DataType>& qkv_out_scales_dtype,
const paddle::optional<paddle::DataType>& cache_k_quant_scales_dtype,
const paddle::optional<paddle::DataType>& cache_v_quant_scales_dtype,
const paddle::optional<paddle::DataType>& cache_k_dequant_scales_dtype,
const paddle::optional<paddle::DataType>& cache_v_dequant_scales_dtype,
const paddle::optional<paddle::DataType>& cache_k_zp_dtype,
const paddle::optional<paddle::DataType>& cache_v_zp_dtype,
const paddle::optional<paddle::DataType>& out_linear_shifts_dtype,
const paddle::optional<paddle::DataType>& out_linear_smooths_dtype,
const paddle::optional<paddle::DataType>& mask_offset_dtype,
const paddle::optional<paddle::DataType>& kv_signal_data_dtype,
const paddle::optional<paddle::DataType>& q_norm_weight_dtype,
const paddle::optional<paddle::DataType>& k_norm_weight_dtype,
const float rms_norm_eps,
const std::string& compute_dtype,
const std::string& cache_quant_type_str,
const bool use_neox_rotary_style,
const bool rope_3d,
const int max_input_length,
const float quant_max_bound,
const float quant_min_bound,
const float out_linear_in_scale,
const int encoder_block_shape_q,
const int decoder_block_shape_q,
const int max_partition_size,
const int encoder_max_partition_size,
const int speculate_max_draft_token_num,
const bool causal,
const bool speculate_decoder) {
return {fmha_out_dtype};
}
PD_BUILD_STATIC_OP(append_attention)
.Inputs({"qkv",
"key_cache",
@@ -1059,15 +714,11 @@ PD_BUILD_STATIC_OP(append_attention)
paddle::Optional("cache_v_zp"),
paddle::Optional("out_linear_shifts"),
paddle::Optional("out_linear_smooths"),
paddle::Optional("mask_offset"),
paddle::Optional("kv_signal_data"),
paddle::Optional("q_norm_weight"),
paddle::Optional("k_norm_weight")})
.Outputs({"fmha_out", "key_cache_out", "value_cache_out"})
paddle::Optional("kv_signal_data")})
.Outputs({"fmha_out", "qkv_out", "key_cache_out", "value_cache_out"})
.SetInplaceMap({{"key_cache", "key_cache_out"},
{"value_cache", "value_cache_out"}})
.Attrs({"rms_norm_eps: float",
"compute_type: std::string",
.Attrs({"compute_type: std::string",
"cache_quant_type: std::string",
"use_neox_rotary_style: bool",
"rope_3d: bool",
@@ -1081,71 +732,7 @@ PD_BUILD_STATIC_OP(append_attention)
"encoder_max_partition_size: int",
"speculate_max_draft_token_num: int",
"causal: bool",
"speculate_decoder: bool",
})
"speculate_decoder: bool"})
.SetKernelFn(PD_KERNEL(AppendAttention))
.SetInferShapeFn(PD_INFER_SHAPE(AppendAttentionInferShape))
.SetInferDtypeFn(PD_INFER_DTYPE(AppendAttentionInferDtype));
PD_BUILD_STATIC_OP(append_attention_with_output)
.Inputs({"qkv",
"key_cache",
"value_cache",
"seq_lens_encoder",
"seq_lens_decoder",
"seq_lens_this_time",
"batch_id_per_token",
"cu_seqlens_q",
"block_tables",
"encoder_batch_ids",
"encoder_tile_ids_per_batch",
"encoder_num_blocks",
"kv_batch_ids",
"kv_tile_ids_per_batch",
"kv_num_blocks",
"decoder_batch_ids",
"decoder_tile_ids_per_batch",
"decoder_num_blocks",
"set_max_lengths",
"max_len_kv",
"fmha_out",
paddle::Optional("rotary_embs"),
paddle::Optional("attn_mask"),
paddle::Optional("qkv_bias"),
paddle::Optional("qkv_out_scales"),
paddle::Optional("cache_k_quant_scales"),
paddle::Optional("cache_v_quant_scales"),
paddle::Optional("cache_k_dequant_scales"),
paddle::Optional("cache_v_dequant_scales"),
paddle::Optional("cache_k_zp"),
paddle::Optional("cache_v_zp"),
paddle::Optional("out_linear_shifts"),
paddle::Optional("out_linear_smooths"),
paddle::Optional("mask_offset"),
paddle::Optional("kv_signal_data"),
paddle::Optional("q_norm_weight"),
paddle::Optional("k_norm_weight")})
.Outputs({"fmha_out_out", "qkv_out", "key_cache_out", "value_cache_out"})
.SetInplaceMap({{"fmha_out", "fmha_out_out"},
{"key_cache", "key_cache_out"},
{"value_cache", "value_cache_out"}})
.Attrs({"rms_norm_eps: float",
"compute_type: std::string",
"cache_quant_type: std::string",
"use_neox_rotary_style: bool",
"rope_3d: bool",
"max_input_length: int",
"quant_max_bound: float",
"quant_min_bound: float",
"out_linear_in_scale: float",
"encoder_block_shape_q: int",
"decoder_block_shape_q: int",
"max_partition_size: int",
"encoder_max_partition_size: int",
"speculate_max_draft_token_num: int",
"causal: bool",
"speculate_decoder: bool",
})
.SetKernelFn(PD_KERNEL(AppendAttentionWithOutput))
.SetInferShapeFn(PD_INFER_SHAPE(AppendAttentionWithOutputInferShape))
.SetInferDtypeFn(PD_INFER_DTYPE(AppendAttentionWithOutputInferDtype));

View File

@@ -43,7 +43,6 @@ __global__ void multi_query_append_attention_kernel(
const int *__restrict__ tile_ids_per_batch,
const int *__restrict__ cu_seqlens_q,
const int *__restrict__ block_table, // [bsz, block_num_per_seq]
const int *__restrict__ mask_offset,
const int max_seq_len,
const int max_dec_len,
const int max_block_num_per_seq,
@@ -52,7 +51,6 @@ __global__ void multi_query_append_attention_kernel(
const float quant_min_bound,
const float in_scale,
const uint32_t chunk_size,
const int num_blocks_x_cpu,
T *__restrict__ tmp_workspace, // split kv [token_num, num_chunks,
// num_heads, head_dim]
float *__restrict__ tmp_m, // [token_num, num_chunks, num_heads]
@@ -75,11 +73,6 @@ __global__ void multi_query_append_attention_kernel(
block_table_now = block_table + batch_id * max_block_num_per_seq;
//When cudagraph capture prefill, may launch more gridDim.x
if(btid >= static_cast<uint32_t>(num_blocks_x_cpu)){
return;
}
const uint32_t q_len = seq_lens[batch_id];
if (q_len <= 0) {
return;
@@ -148,7 +141,6 @@ __global__ void multi_query_append_attention_kernel(
} else {
o_base_ptr_int8 = out + o_offset;
}
const int *mask_offset_this_seq = mask_offset ? mask_offset + q_start_seq_id * 2 : nullptr;
smem_t qo_smem(smem);
uint32_t q_smem_offset_r = smem_t::get_permuted_offset<num_vecs_per_head>(
@@ -187,7 +179,7 @@ __global__ void multi_query_append_attention_kernel(
kv_len - q_len +
tile_id * num_rows_per_block / GROUP_SIZE,
chunk_start)))
: mask_offset ? 0 : chunk_len) /
: chunk_len) /
(num_frags_z * 16);
uint32_t k_smem_offset_r = smem_t::get_permuted_offset<num_vecs_per_head>(
8 * (tid / 16) + tid % 8, (tid % 16) / 8);
@@ -253,16 +245,12 @@ __global__ void multi_query_append_attention_kernel(
NUM_WARPS,
num_frags_x,
num_frags_y,
num_frags_z>(nullptr,
q_base_seq_id_this_block,
num_frags_z>(q_base_seq_id_this_block,
kv_idx_base,
q_len,
kv_len,
chunk_end,
-1,
s_frag,
mask_offset_this_seq);
s_frag);
}
// update m,d
@@ -418,8 +406,6 @@ __global__ void multi_query_append_attention_warp1_4_kernel(
const int *__restrict__ tile_ids_per_batch,
const int *__restrict__ cu_seqlens_q,
const int *__restrict__ block_table, // [bsz, block_num_per_seq]
const int *__restrict__ mask_offset,
const bool *__restrict__ attn_mask, // [bsz, max_q, max_q] for tree-mask
const int max_seq_len,
const int max_dec_len,
const int max_block_num_per_seq,
@@ -428,14 +414,12 @@ __global__ void multi_query_append_attention_warp1_4_kernel(
const float quant_min_bound,
const float in_scale,
const uint32_t chunk_size,
const int num_blocks_x_cpu,
T *__restrict__ tmp_workspace, // split kv [token_num, num_chunks,
// num_heads, head_dim]
float *__restrict__ tmp_m, // [token_num, num_chunks, num_heads]
float *__restrict__ tmp_d, // [token_num, num_chunks, num_heads]
OutT *__restrict__ out,
const int speculate_max_draft_token_num = 5,
const int32_t attn_mask_len = -1) {
const int speculate_max_draft_token_num = 5) {
constexpr uint32_t num_vecs_per_head = HEAD_DIM / num_elems_per_128b<T>();
static_assert(NUM_WARP_Q == 1, "NUM_WARP_Q must be 1");
static_assert(NUM_WARP_KV == 4, "NUM_WARP_KV must be 4");
@@ -452,11 +436,6 @@ __global__ void multi_query_append_attention_warp1_4_kernel(
const uint32_t num_rows_per_block = num_frags_x * 16;
const int *block_table_now = block_table + batch_id * max_block_num_per_seq;
//When cudagraph capture prefill, may launch more gridDim.x
if(btid >= static_cast<uint32_t>(num_blocks_x_cpu)){
return;
}
const uint32_t q_len = seq_lens[batch_id];
if (q_len <= 0) {
return;
@@ -523,7 +502,7 @@ __global__ void multi_query_append_attention_warp1_4_kernel(
tid % 8 * num_elems_per_128b<T>();
}
}
const int *mask_offset_this_seq = mask_offset ? mask_offset + q_start_seq_id * 2 : nullptr;
smem_t qo_smem(smem);
uint32_t q_smem_offset_r = smem_t::get_permuted_offset<num_vecs_per_head>(
@@ -561,9 +540,10 @@ __global__ void multi_query_append_attention_warp1_4_kernel(
const uint32_t mask_check_iteration =
(CAUSAL ? (min(chunk_len,
sub_if_greater_or_zero(
kv_len - q_len,
kv_len - q_len +
tile_id * num_rows_per_block / GROUP_SIZE,
chunk_start)))
: mask_offset ? 0 : chunk_len) /
: chunk_len) /
(NUM_WARP_KV * num_frags_z * 16);
uint32_t k_smem_offset_r = smem_t::get_permuted_offset<num_vecs_per_head>(
@@ -631,15 +611,12 @@ __global__ void multi_query_append_attention_warp1_4_kernel(
NUM_WARPS,
num_frags_x,
num_frags_y,
num_frags_z>(attn_mask ? attn_mask + batch_id * attn_mask_len *attn_mask_len : nullptr,
q_base_seq_id_this_block,
num_frags_z>(q_base_seq_id_this_block,
kv_idx_base + wid * num_frags_z * 16,
q_len,
kv_len,
chunk_end,
attn_mask_len,
s_frag,
mask_offset_this_seq);
s_frag);
}
// update m,d
@@ -905,7 +882,6 @@ void MultiQueryAppendAttention(
tile_ids_per_batch.data<int>(),
cu_seqlens_q.data<int>(),
block_table.data<int>(),
meta_data.mask_offset,
max_seq_len,
max_dec_len,
max_block_num_per_seq,
@@ -914,7 +890,6 @@ void MultiQueryAppendAttention(
quant_min_bound,
in_scale,
chunk_size,
num_blocks_x_cpu,
nullptr,
nullptr,
nullptr,
@@ -964,7 +939,6 @@ void MultiQueryAppendAttention(
tile_ids_per_batch.data<int>(),
cu_seqlens_q.data<int>(),
block_table.data<int>(),
meta_data.mask_offset,
max_seq_len,
max_dec_len,
max_block_num_per_seq,
@@ -973,7 +947,6 @@ void MultiQueryAppendAttention(
quant_min_bound,
in_scale,
chunk_size,
num_blocks_x_cpu,
reinterpret_cast<NV_TYPE *>(tmp_workspace->ptr()),
static_cast<float *>(tmp_m->ptr()),
static_cast<float *>(tmp_d->ptr()),
@@ -1088,18 +1061,12 @@ void MultiQueryAppendAttention(
if (!is_decoder) {
chunk_size = static_cast<uint32_t>(encoder_max_partition_size);
}
const int num_chunks = div_up(max_dec_len, chunk_size);
int32_t attn_mask_len;
if (attn_mask) {
attn_mask_len = attn_mask.get().shape()[1];
} else {
attn_mask_len = -1;
}
const int num_chunks = div_up(max_seq_len, chunk_size);
dim3 grids(num_blocks_x_cpu, num_chunks, kv_num_heads);
dim3 blocks(32, num_warps);
if (num_chunks <= 0) {
if (num_chunks <= 1) {
auto nosplit_kv_kernel =
multi_query_append_attention_warp1_4_kernel<NV_TYPE,
false,
@@ -1137,9 +1104,6 @@ void MultiQueryAppendAttention(
tile_ids_per_batch.data<int>(),
cu_seqlens_q.data<int>(),
block_table.data<int>(),
meta_data.mask_offset,
attn_mask ? const_cast<bool *>(attn_mask.get().data<bool>())
: nullptr,
max_seq_len,
max_dec_len,
max_block_num_per_seq,
@@ -1148,13 +1112,11 @@ void MultiQueryAppendAttention(
quant_min_bound,
in_scale,
chunk_size,
num_blocks_x_cpu,
nullptr,
nullptr,
nullptr,
reinterpret_cast<OUT_NV_TYPE *>(out->data<OutT>()),
speculate_max_draft_token_num,
attn_mask_len);
speculate_max_draft_token_num);
} else {
phi::Allocator::AllocationPtr tmp_workspace, tmp_m, tmp_d;
if (is_decoder) {
@@ -1199,8 +1161,8 @@ void MultiQueryAppendAttention(
reinterpret_cast<NV_TYPE *>(const_cast<T *>(cache_k.data<T>())),
reinterpret_cast<NV_TYPE *>(const_cast<T *>(cache_v.data<T>())),
shift_bias ? reinterpret_cast<NV_TYPE *>(
const_cast<T *>(shift_bias.get().data<T>()))
: nullptr,
const_cast<T *>(shift_bias.get().data<T>()))
: nullptr,
smooth_weight ? reinterpret_cast<NV_TYPE *>(
const_cast<T *>(smooth_weight.get().data<T>()))
: nullptr,
@@ -1210,9 +1172,6 @@ void MultiQueryAppendAttention(
tile_ids_per_batch.data<int>(),
cu_seqlens_q.data<int>(),
block_table.data<int>(),
meta_data.mask_offset,
attn_mask ? const_cast<bool *>(attn_mask.get().data<bool>())
: nullptr,
max_seq_len,
max_dec_len,
max_block_num_per_seq,
@@ -1221,13 +1180,11 @@ void MultiQueryAppendAttention(
quant_min_bound,
in_scale,
chunk_size,
num_blocks_x_cpu,
reinterpret_cast<NV_TYPE *>(tmp_workspace->ptr()),
static_cast<float *>(tmp_m->ptr()),
static_cast<float *>(tmp_d->ptr()),
reinterpret_cast<OUT_NV_TYPE *>(out->data<OutT>()),
speculate_max_draft_token_num,
attn_mask_len);
speculate_max_draft_token_num);
// merge
constexpr int vec_size = num_elems_per_128b<NV_TYPE>();
@@ -1251,8 +1208,8 @@ void MultiQueryAppendAttention(
seq_lens_encoder.data<int>(),
cu_seqlens_q.data<int>(),
shift_bias ? reinterpret_cast<NV_TYPE *>(
const_cast<T *>(shift_bias.get().data<T>()))
: nullptr,
const_cast<T *>(shift_bias.get().data<T>()))
: nullptr,
smooth_weight ? reinterpret_cast<NV_TYPE *>(const_cast<T *>(
smooth_weight.get().data<T>()))
: nullptr,
@@ -1269,14 +1226,14 @@ void MultiQueryAppendAttention(
constexpr int blockx = HEAD_DIM / vec_size;
constexpr int blocky = (128 + blockx - 1) / blockx;
dim3 grids_merge(min(sm_count * 4, token_num),
num_heads);
num_heads);
dim3 blocks_merge(blockx, blocky);
merge_multi_chunks_v2_kernel<NV_TYPE,
vec_size,
blocky,
HEAD_DIM,
OUT_NV_TYPE,
ENABLE_PREFILL>
vec_size,
blocky,
HEAD_DIM,
OUT_NV_TYPE,
ENABLE_PREFILL>
<<<grids_merge, blocks_merge, 0, stream>>>(
reinterpret_cast<NV_TYPE *>(tmp_workspace->ptr()),
static_cast<float *>(tmp_m->ptr()),
@@ -1287,8 +1244,8 @@ void MultiQueryAppendAttention(
batch_id_per_token.data<int>(),
cu_seqlens_q.data<int>(),
shift_bias ? reinterpret_cast<NV_TYPE *>(
const_cast<T *>(shift_bias.get().data<T>()))
: nullptr,
const_cast<T *>(shift_bias.get().data<T>()))
: nullptr,
smooth_weight ? reinterpret_cast<NV_TYPE *>(const_cast<T *>(
smooth_weight.get().data<T>()))
: nullptr,

View File

@@ -48,7 +48,6 @@ __global__ void multi_query_append_attention_c4_kernel(
const int *__restrict__ tile_ids_per_batch,
const int *__restrict__ cu_seqlens_q,
const int *__restrict__ block_table, // [bsz, block_num_per_seq]
const int *__restrict__ mask_offset,
const int max_seq_len,
const int max_dec_len,
const int max_block_num_per_seq,
@@ -57,7 +56,6 @@ __global__ void multi_query_append_attention_c4_kernel(
const float quant_min_bound,
const float in_scale,
const uint32_t chunk_size,
const int num_blocks_x_cpu,
T *__restrict__ tmp_workspace, // split kv [token_num, num_chunks,
// num_heads, head_dim]
float *__restrict__ tmp_m, // [token_num, num_chunks, num_heads]
@@ -86,11 +84,6 @@ __global__ void multi_query_append_attention_c4_kernel(
block_table_now = block_table + batch_id * max_block_num_per_seq;
//When cudagraph capture prefill, may launch more gridDim.x
if(btid >= static_cast<uint32_t>(num_blocks_x_cpu)){
return;
}
const uint32_t q_len = seq_lens[batch_id];
if (q_len <= 0) {
return;
@@ -179,7 +172,6 @@ __global__ void multi_query_append_attention_c4_kernel(
} else {
o_base_ptr_int8 = out + o_offset;
}
const int *mask_offset_this_seq = mask_offset ? mask_offset + q_start_seq_id * 2 : nullptr;
smem_t qo_smem(smem);
uint32_t q_smem_offset_r = smem_t::get_permuted_offset<num_vecs_per_head>(
@@ -256,7 +248,7 @@ __global__ void multi_query_append_attention_c4_kernel(
kv_len - q_len +
tile_id * num_rows_per_block / GROUP_SIZE,
chunk_start)))
: mask_offset ? 0 : chunk_len) /
: chunk_len) /
(num_frags_z * 16);
uint32_t k_smem_offset_r =
@@ -341,15 +333,12 @@ __global__ void multi_query_append_attention_c4_kernel(
NUM_WARPS,
num_frags_x,
num_frags_y,
num_frags_z>(nullptr,
q_base_seq_id_this_block,
num_frags_z>(q_base_seq_id_this_block,
kv_idx_base,
q_len,
kv_len,
chunk_end,
-1,
s_frag,
mask_offset_this_seq);
s_frag);
}
update_mdo_states<num_frags_x, num_frags_y, num_frags_z>(
@@ -516,8 +505,6 @@ __global__ void multi_query_append_attention_c4_warp1_4_kernel(
const int *__restrict__ tile_ids_per_batch,
const int *__restrict__ cu_seqlens_q,
const int *__restrict__ block_table, // [bsz, block_num_per_seq]
const int *__restrict__ mask_offset,
const bool *__restrict__ attn_mask, // [bsz, max_q, max_q] for tree-mask
const int max_seq_len,
const int max_dec_len,
const int max_block_num_per_seq,
@@ -526,14 +513,12 @@ __global__ void multi_query_append_attention_c4_warp1_4_kernel(
const float quant_min_bound,
const float in_scale,
const uint32_t chunk_size,
const int num_blocks_x_cpu,
T *__restrict__ tmp_workspace, // split kv [token_num, num_chunks,
// num_heads, head_dim]
float *__restrict__ tmp_m, // [token_num, num_chunks, num_heads]
float *__restrict__ tmp_d, // [token_num, num_chunks, num_heads]
OutT *__restrict__ out,
const int speculate_max_draft_token_num = 5,
const int32_t attn_mask_len = -1) {
const int speculate_max_draft_token_num = 5) {
constexpr uint32_t num_vecs_per_head = HEAD_DIM / num_elems_per_128b<T>();
constexpr uint32_t num_vecs_per_head_k =
HEAD_DIM / 2 / num_elems_per_128b<CacheT>();
@@ -556,11 +541,6 @@ __global__ void multi_query_append_attention_c4_warp1_4_kernel(
const uint32_t num_rows_per_block = num_frags_x * 16;
const int *block_table_now = block_table + batch_id * max_block_num_per_seq;
//When cudagraph capture prefill, may launch more gridDim.x
if(btid >= static_cast<uint32_t>(num_blocks_x_cpu)){
return;
}
const uint32_t q_len = seq_lens[batch_id];
if (q_len <= 0) {
return;
@@ -647,7 +627,7 @@ __global__ void multi_query_append_attention_c4_warp1_4_kernel(
tid % 8 * num_elems_per_128b<T>();
}
}
const int *mask_offset_this_seq = mask_offset ? mask_offset + q_start_seq_id * 2 : nullptr;
smem_t qo_smem(smem);
uint32_t q_smem_offset_r = smem_t::get_permuted_offset<num_vecs_per_head>(
@@ -723,9 +703,10 @@ __global__ void multi_query_append_attention_c4_warp1_4_kernel(
const uint32_t mask_check_iteration =
(CAUSAL ? (min(chunk_len,
sub_if_greater_or_zero(
kv_len - q_len,
kv_len - q_len +
tile_id * num_rows_per_block / GROUP_SIZE,
chunk_start)))
: mask_offset ? 0 : chunk_len) /
: chunk_len) /
(NUM_WARP_KV * num_frags_z * 16);
uint32_t k_smem_offset_r =
@@ -807,15 +788,12 @@ __global__ void multi_query_append_attention_c4_warp1_4_kernel(
NUM_WARPS,
num_frags_x,
num_frags_y,
num_frags_z>(attn_mask ? attn_mask + batch_id * attn_mask_len *attn_mask_len : nullptr,
q_base_seq_id_this_block,
num_frags_z>(q_base_seq_id_this_block,
kv_idx_base + wid * num_frags_z * 16,
q_len,
kv_len,
chunk_end,
attn_mask_len,
s_frag,
mask_offset_this_seq);
s_frag);
}
update_mdo_states<num_frags_x, num_frags_y, num_frags_z>(
@@ -1110,7 +1088,6 @@ void MultiQueryAppendC4Attention(
tile_ids_per_batch.data<int>(),
cu_seqlens_q.data<int>(),
block_table.data<int>(),
meta_data.mask_offset,
max_seq_len,
max_dec_len,
max_block_num_per_seq,
@@ -1119,7 +1096,6 @@ void MultiQueryAppendC4Attention(
quant_min_bound,
in_scale,
chunk_size,
num_blocks_x_cpu,
nullptr,
nullptr,
nullptr,
@@ -1175,7 +1151,6 @@ void MultiQueryAppendC4Attention(
tile_ids_per_batch.data<int>(),
cu_seqlens_q.data<int>(),
block_table.data<int>(),
meta_data.mask_offset,
max_seq_len,
max_dec_len,
max_block_num_per_seq,
@@ -1184,7 +1159,6 @@ void MultiQueryAppendC4Attention(
quant_min_bound,
in_scale,
chunk_size,
num_blocks_x_cpu,
reinterpret_cast<NV_TYPE *>(tmp_workspace->ptr()),
static_cast<float *>(tmp_m->ptr()),
static_cast<float *>(tmp_d->ptr()),
@@ -1311,18 +1285,10 @@ void MultiQueryAppendC4Attention(
if (!is_decoder) {
chunk_size = static_cast<uint32_t>(encoder_max_partition_size);
}
const int num_chunks = div_up(max_seq_len, chunk_size);
int32_t attn_mask_len;
if (attn_mask) {
attn_mask_len = attn_mask.get().shape()[1];
} else {
attn_mask_len = -1;
}
const int num_chunks = div_up(max_dec_len, chunk_size);
dim3 grids(num_blocks_x_cpu, num_chunks, kv_num_heads);
dim3 blocks(32, num_warps);
if (num_chunks <= 0) {
if (num_chunks <= 1) {
auto nosplit_kv_kernel =
multi_query_append_attention_c4_warp1_4_kernel<NV_TYPE,
uint8_t,
@@ -1368,9 +1334,6 @@ void MultiQueryAppendC4Attention(
tile_ids_per_batch.data<int>(),
cu_seqlens_q.data<int>(),
block_table.data<int>(),
meta_data.mask_offset,
attn_mask ? const_cast<bool *>(attn_mask.get().data<bool>())
: nullptr,
max_seq_len,
max_dec_len,
max_block_num_per_seq,
@@ -1379,13 +1342,11 @@ void MultiQueryAppendC4Attention(
quant_min_bound,
in_scale,
chunk_size,
num_blocks_x_cpu,
nullptr,
nullptr,
nullptr,
reinterpret_cast<OUT_NV_TYPE *>(out->data<OutT>()),
speculate_max_draft_token_num,
attn_mask_len);
speculate_max_draft_token_num);
} else {
phi::Allocator::AllocationPtr tmp_workspace, tmp_m, tmp_d;
if (is_decoder) {
@@ -1431,15 +1392,15 @@ void MultiQueryAppendC4Attention(
const_cast<uint8_t *>(cache_v.data<uint8_t>()),
reinterpret_cast<NV_TYPE *>(const_cast<T *>(cache_k_scale.data<T>())),
cache_k_zp ? reinterpret_cast<NV_TYPE *>(
const_cast<T *>(cache_k_zp.get().data<T>()))
: nullptr,
const_cast<T *>(cache_k_zp.get().data<T>()))
: nullptr,
reinterpret_cast<NV_TYPE *>(const_cast<T *>(cache_v_scale.data<T>())),
cache_v_zp ? reinterpret_cast<NV_TYPE *>(
const_cast<T *>(cache_v_zp.get().data<T>()))
: nullptr,
const_cast<T *>(cache_v_zp.get().data<T>()))
: nullptr,
shift_bias ? reinterpret_cast<NV_TYPE *>(
const_cast<T *>(shift_bias.get().data<T>()))
: nullptr,
const_cast<T *>(shift_bias.get().data<T>()))
: nullptr,
smooth_weight ? reinterpret_cast<NV_TYPE *>(
const_cast<T *>(smooth_weight.get().data<T>()))
: nullptr,
@@ -1449,9 +1410,6 @@ void MultiQueryAppendC4Attention(
tile_ids_per_batch.data<int>(),
cu_seqlens_q.data<int>(),
block_table.data<int>(),
meta_data.mask_offset,
attn_mask ? const_cast<bool *>(attn_mask.get().data<bool>())
: nullptr,
max_seq_len,
max_dec_len,
max_block_num_per_seq,
@@ -1460,13 +1418,11 @@ void MultiQueryAppendC4Attention(
quant_min_bound,
in_scale,
chunk_size,
num_blocks_x_cpu,
reinterpret_cast<NV_TYPE *>(tmp_workspace->ptr()),
static_cast<float *>(tmp_m->ptr()),
static_cast<float *>(tmp_d->ptr()),
reinterpret_cast<OUT_NV_TYPE *>(out->data<OutT>()),
speculate_max_draft_token_num,
attn_mask_len);
speculate_max_draft_token_num);
// merge
constexpr int vec_size = num_elems_per_128b<NV_TYPE>();
if (is_decoder) {
@@ -1489,8 +1445,8 @@ void MultiQueryAppendC4Attention(
seq_lens_encoder.data<int>(),
cu_seqlens_q.data<int>(),
shift_bias ? reinterpret_cast<NV_TYPE *>(
const_cast<T *>(shift_bias.get().data<T>()))
: nullptr,
const_cast<T *>(shift_bias.get().data<T>()))
: nullptr,
smooth_weight ? reinterpret_cast<NV_TYPE *>(const_cast<T *>(
smooth_weight.get().data<T>()))
: nullptr,
@@ -1507,14 +1463,14 @@ void MultiQueryAppendC4Attention(
constexpr int blockx = HEAD_DIM / vec_size;
constexpr int blocky = (128 + blockx - 1) / blockx;
dim3 grids_merge(min(sm_count * 4, token_num),
num_heads);
num_heads);
dim3 blocks_merge(blockx, blocky);
merge_multi_chunks_v2_kernel<NV_TYPE,
vec_size,
blocky,
HEAD_DIM,
OUT_NV_TYPE,
ENABLE_PREFILL>
vec_size,
blocky,
HEAD_DIM,
OUT_NV_TYPE,
ENABLE_PREFILL>
<<<grids_merge, blocks_merge, 0, stream>>>(
reinterpret_cast<NV_TYPE *>(tmp_workspace->ptr()),
static_cast<float *>(tmp_m->ptr()),
@@ -1525,8 +1481,8 @@ void MultiQueryAppendC4Attention(
batch_id_per_token.data<int>(),
cu_seqlens_q.data<int>(),
shift_bias ? reinterpret_cast<NV_TYPE *>(
const_cast<T *>(shift_bias.get().data<T>()))
: nullptr,
const_cast<T *>(shift_bias.get().data<T>()))
: nullptr,
smooth_weight ? reinterpret_cast<NV_TYPE *>(const_cast<T *>(
smooth_weight.get().data<T>()))
: nullptr,

View File

@@ -32,15 +32,14 @@ template <typename T,
typename OutT = T,
bool ENABLE_PREFILL = true,
bool is_scale_channel_wise = false,
bool IsFP8 = false,
bool IsDynamicC8 = false>
bool IsFP8=false>
__global__ void multi_query_append_attention_c8_kernel(
T *__restrict__ q, // [token_num, (num_heads + 2* kv_num_head) * head_dim]
CacheT *__restrict__ cache_k, // [max_block_num, num_heads, block_size,
// head_dim]
CacheT *__restrict__ cache_v,
const T *__restrict__ cache_k_scale, // [num_kv_heads] or [max_block_num, num_heads, block_size]
const T *__restrict__ cache_v_scale, // [num_kv_heads] or [max_block_num, num_heads, block_size]
const T *__restrict__ cache_k_scale, // [num_kv_heads]
const T *__restrict__ cache_v_scale, // [num_kv_heads]
const T *__restrict__ shift_bias, // [q_num_heads * HEAD_DIM]
const T *__restrict__ smooth_weight, // [q_num_heads * HEAD_DIM]
const int *__restrict__ seq_lens,
@@ -49,7 +48,6 @@ __global__ void multi_query_append_attention_c8_kernel(
const int *__restrict__ tile_ids_per_batch,
const int *__restrict__ cu_seqlens_q,
const int *__restrict__ block_table, // [bsz, block_num_per_seq]
const int *__restrict__ mask_offset,
const int max_seq_len,
const int max_dec_len,
const int max_block_num_per_seq,
@@ -58,7 +56,6 @@ __global__ void multi_query_append_attention_c8_kernel(
const float quant_min_bound,
const float in_scale,
const uint32_t chunk_size,
const int num_blocks_x_cpu,
T *__restrict__ tmp_workspace, // split kv [token_num, num_chunks,
// num_heads, head_dim]
float *__restrict__ tmp_m, // [token_num, num_chunks, num_heads]
@@ -88,40 +85,33 @@ __global__ void multi_query_append_attention_c8_kernel(
block_table_now = block_table + batch_id * max_block_num_per_seq;
//When cudagraph capture prefill, may launch more gridDim.x
if(btid >= static_cast<uint32_t>(num_blocks_x_cpu)){
return;
}
const uint32_t q_len = seq_lens[batch_id];
if (q_len <= 0) {
return;
}
T cache_k_scale_reg[IsDynamicC8 ? num_frags_z * 2 : num_frags_y * 4];
T cache_v_scale_reg[IsDynamicC8 ? num_frags_z * 4 : num_frags_y * 2];
if constexpr (!IsDynamicC8) {
if constexpr (is_scale_channel_wise) {
int scale_col_base = threadIdx.x % 4 * 2 + kv_head_idx * HEAD_DIM;
const T *cache_k_scale_cur_head = cache_k_scale + scale_col_base;
for (int i = 0; i < num_frags_y; ++i) {
const int scale_idx = i * 16;
cache_k_scale_reg[i * 4] = cache_k_scale_cur_head[scale_idx];
cache_k_scale_reg[i * 4 + 1] = cache_k_scale_cur_head[scale_idx + 1];
cache_k_scale_reg[i * 4 + 2] = cache_k_scale_cur_head[scale_idx + 8];
cache_k_scale_reg[i * 4 + 3] = cache_k_scale_cur_head[scale_idx + 9];
}
scale_col_base = threadIdx.x / 4 + kv_head_idx * HEAD_DIM;
const T *cache_v_scale_cur_head = cache_v_scale + scale_col_base;
for (int i = 0; i < num_frags_y; ++i) {
const int scale_idx = i * 16;
cache_v_scale_reg[i * 2] = cache_v_scale_cur_head[scale_idx];
cache_v_scale_reg[i * 2 + 1] = cache_v_scale_cur_head[scale_idx + 8];
}
} else {
cache_k_scale_reg[0] = cache_k_scale[kv_head_idx];
cache_v_scale_reg[0] = cache_v_scale[kv_head_idx];
T cache_k_scale_reg[num_frags_y * 4];
T cache_v_scale_reg[num_frags_y * 2];
if (is_scale_channel_wise) {
int scale_col_base = threadIdx.x % 4 * 2 + kv_head_idx * HEAD_DIM;
const T *cache_k_scale_cur_head = cache_k_scale + scale_col_base;
for (int i = 0; i < num_frags_y; ++i) {
const int scale_idx = i * 16;
cache_k_scale_reg[i * 4] = cache_k_scale_cur_head[scale_idx];
cache_k_scale_reg[i * 4 + 1] = cache_k_scale_cur_head[scale_idx + 1];
cache_k_scale_reg[i * 4 + 2] = cache_k_scale_cur_head[scale_idx + 8];
cache_k_scale_reg[i * 4 + 3] = cache_k_scale_cur_head[scale_idx + 9];
}
scale_col_base = threadIdx.x / 4 + kv_head_idx * HEAD_DIM;
const T *cache_v_scale_cur_head = cache_v_scale + scale_col_base;
for (int i = 0; i < num_frags_y; ++i) {
const int scale_idx = i * 16;
cache_v_scale_reg[i * 2] = cache_v_scale_cur_head[scale_idx];
cache_v_scale_reg[i * 2 + 1] = cache_v_scale_cur_head[scale_idx + 8];
}
} else {
cache_k_scale_reg[0] = cache_k_scale[kv_head_idx];
cache_v_scale_reg[0] = cache_v_scale[kv_head_idx];
}
const uint32_t q_end =
@@ -189,7 +179,6 @@ __global__ void multi_query_append_attention_c8_kernel(
} else {
o_base_ptr_int8 = out + o_offset;
}
const int *mask_offset_this_seq = mask_offset ? mask_offset + q_start_seq_id * 2 : nullptr;
smem_t qo_smem(smem);
uint32_t q_smem_offset_r = smem_t::get_permuted_offset<num_vecs_per_head>(
@@ -210,13 +199,6 @@ __global__ void multi_query_append_attention_c8_kernel(
smem_t k_smem(smem + NUM_WARPS * num_frags_x * 16 * HEAD_DIM * sizeof(T)),
v_smem(smem + NUM_WARPS * num_frags_x * 16 * HEAD_DIM * sizeof(T) +
num_frags_z * 16 * HEAD_DIM * sizeof(CacheT));
T* k_smem_scale = nullptr;
T* v_smem_scale = nullptr;
if constexpr (IsDynamicC8) {
k_smem_scale = reinterpret_cast<T*>(smem + NUM_WARPS * num_frags_x * 16 * HEAD_DIM * sizeof(T) +
num_frags_z * 16 * HEAD_DIM * sizeof(CacheT) * 2);
v_smem_scale = k_smem_scale + num_frags_z * 16;
}
const uint32_t num_iterations = div_up(
@@ -234,7 +216,7 @@ __global__ void multi_query_append_attention_c8_kernel(
kv_len - q_len +
tile_id * num_rows_per_block / GROUP_SIZE,
chunk_start)))
: mask_offset ? 0 : chunk_len) /
: chunk_len) /
(num_frags_z * 16);
uint32_t k_smem_offset_r =
@@ -298,22 +280,10 @@ __global__ void multi_query_append_attention_c8_kernel(
#pragma unroll 1
for (uint32_t iter = 0; iter < num_iterations; ++iter) {
if constexpr (IsDynamicC8) {
produce_k_dynamic_scale<BLOCK_SIZE, num_frags_z, NUM_WARP_Q, T>(
k_smem_scale,
cache_k_scale_reg,
block_table_now,
cache_k_scale,
kv_idx_base,
kv_num_heads,
kv_head_idx,
chunk_end
);
}
wait_group<1>();
__syncthreads();
// s = qk
compute_qk_c8<num_frags_x, num_frags_y, num_frags_z, T, CacheT, is_scale_channel_wise, IsFP8, IsDynamicC8>(
compute_qk_c8<num_frags_x, num_frags_y, num_frags_z, T, CacheT, is_scale_channel_wise, IsFP8>(
&qo_smem,
&q_smem_offset_r,
&k_smem,
@@ -330,15 +300,12 @@ __global__ void multi_query_append_attention_c8_kernel(
NUM_WARPS,
num_frags_x,
num_frags_y,
num_frags_z>(nullptr,
q_base_seq_id_this_block,
num_frags_z>(q_base_seq_id_this_block,
kv_idx_base,
q_len,
kv_len,
chunk_end,
-1,
s_frag,
mask_offset_this_seq);
s_frag);
}
// update m,d
@@ -346,7 +313,6 @@ __global__ void multi_query_append_attention_c8_kernel(
s_frag, o_frag, m_frag, d_frag);
__syncthreads();
const int ori_kv_idx_base = kv_idx_base;
kv_idx_base += num_frags_z * 16;
produce_k_blockwise_c8<SharedMemFillMode::kNoFill,
NUM_WARPS,
@@ -365,18 +331,6 @@ __global__ void multi_query_append_attention_c8_kernel(
chunk_end,
const_k_offset);
commit_group();
if constexpr (IsDynamicC8) {
produce_v_dynamic_scale<BLOCK_SIZE, num_frags_z, NUM_WARP_Q, T>(
v_smem_scale,
cache_v_scale_reg,
block_table_now,
cache_v_scale,
ori_kv_idx_base,
kv_num_heads,
kv_head_idx,
chunk_end
);
}
wait_group<1>();
__syncthreads();
@@ -387,9 +341,7 @@ __global__ void multi_query_append_attention_c8_kernel(
BLOCK_SIZE,
T,
CacheT,
is_scale_channel_wise,
IsFP8,
IsDynamicC8>(
is_scale_channel_wise, IsFP8>(
&v_smem, &v_smem_offset_r, s_frag, o_frag, d_frag, cache_v_scale_reg);
__syncthreads();
@@ -506,15 +458,14 @@ template <typename T,
typename OutT = T,
bool ENABLE_PREFILL = true,
bool is_scale_channel_wise=false,
bool IsFP8 = false,
bool IsDynamicC8 = false>
bool IsFP8=false>
__global__ void multi_query_append_attention_c8_warp1_4_kernel(
T *__restrict__ q, // [token_num, (num_heads + 2* kv_num_head) * head_dim]
CacheT *__restrict__ cache_k, // [max_block_num, num_heads, block_size,
// head_dim]
CacheT *__restrict__ cache_v,
const T *__restrict__ cache_k_scale, // [num_kv_heads] or [max_block_num, num_heads, block_size]
const T *__restrict__ cache_v_scale, // [num_kv_heads] or [max_block_num, num_heads, block_size]
const T *__restrict__ cache_k_scale, // [num_kv_heads, head_dim]
const T *__restrict__ cache_v_scale, // [num_kv_heads, head_dim]
const T *__restrict__ shift_bias, // [q_num_heads * HEAD_DIM]
const T *__restrict__ smooth_weight, // [q_num_heads * HEAD_DIM]
const int *__restrict__ seq_lens,
@@ -523,8 +474,6 @@ __global__ void multi_query_append_attention_c8_warp1_4_kernel(
const int *__restrict__ tile_ids_per_batch,
const int *__restrict__ cu_seqlens_q,
const int *__restrict__ block_table, // [bsz, block_num_per_seq]
const int *__restrict__ mask_offset,
const bool *__restrict__ attn_mask, // [bsz, max_q, max_q] for tree-mask
const int max_seq_len,
const int max_dec_len,
const int max_block_num_per_seq,
@@ -533,14 +482,12 @@ __global__ void multi_query_append_attention_c8_warp1_4_kernel(
const float quant_min_bound,
const float in_scale,
const uint32_t chunk_size,
const int num_blocks_x_cpu,
T *__restrict__ tmp_workspace, // split kv [token_num, num_chunks,
// num_heads, head_dim]
float *__restrict__ tmp_m, // [token_num, num_chunks, num_heads]
float *__restrict__ tmp_d, // [token_num, num_chunks, num_heads]
OutT *__restrict__ out,
const int speculate_max_draft_token_num = 5,
const int32_t attn_mask_len = -1) {
const int speculate_max_draft_token_num = 5) {
constexpr uint32_t num_vecs_per_head = HEAD_DIM / num_elems_per_128b<T>();
constexpr uint32_t num_vecs_per_head_k =
HEAD_DIM / num_elems_per_128b<CacheT>();
@@ -563,39 +510,32 @@ __global__ void multi_query_append_attention_c8_warp1_4_kernel(
const uint32_t num_rows_per_block = num_frags_x * 16;
const int *block_table_now = block_table + batch_id * max_block_num_per_seq;
//When cudagraph capture prefill, may launch more gridDim.x
if(btid >= static_cast<uint32_t>(num_blocks_x_cpu)){
return;
}
const uint32_t q_len = seq_lens[batch_id];
if (q_len <= 0) {
return;
}
T cache_k_scale_reg[IsDynamicC8 ? num_frags_z * 2 : num_frags_y * 4];
T cache_v_scale_reg[IsDynamicC8 ? num_frags_z * 4 : num_frags_y * 2];
if constexpr (!IsDynamicC8) {
if constexpr (is_scale_channel_wise) {
int scale_col_base = threadIdx.x % 4 * 2 + kv_head_idx * HEAD_DIM;
const T *cache_k_scale_cur_head = cache_k_scale + scale_col_base;
for (int i = 0; i < num_frags_y; ++i) {
const int scale_idx = i * 16;
cache_k_scale_reg[i * 4] = cache_k_scale_cur_head[scale_idx];
cache_k_scale_reg[i * 4 + 1] = cache_k_scale_cur_head[scale_idx + 1];
cache_k_scale_reg[i * 4 + 2] = cache_k_scale_cur_head[scale_idx + 8];
cache_k_scale_reg[i * 4 + 3] = cache_k_scale_cur_head[scale_idx + 9];
}
scale_col_base = threadIdx.x / 4 + kv_head_idx * HEAD_DIM;
const T *cache_v_scale_cur_head = cache_v_scale + scale_col_base;
for (int i = 0; i < num_frags_y; ++i) {
const int scale_idx = i * 16;
cache_v_scale_reg[i * 2] = cache_v_scale_cur_head[scale_idx];
cache_v_scale_reg[i * 2 + 1] = cache_v_scale_cur_head[scale_idx + 8];
}
} else {
cache_k_scale_reg[0] = cache_k_scale[kv_head_idx];
cache_v_scale_reg[0] = cache_v_scale[kv_head_idx];
T cache_k_scale_reg[num_frags_y * 4];
T cache_v_scale_reg[num_frags_y * 2];
if (is_scale_channel_wise) {
int scale_col_base = threadIdx.x % 4 * 2 + kv_head_idx * HEAD_DIM;
const T *cache_k_scale_cur_head = cache_k_scale + scale_col_base;
for (int i = 0; i < num_frags_y; ++i) {
const int scale_idx = i * 16;
cache_k_scale_reg[i * 4] = cache_k_scale_cur_head[scale_idx];
cache_k_scale_reg[i * 4 + 1] = cache_k_scale_cur_head[scale_idx + 1];
cache_k_scale_reg[i * 4 + 2] = cache_k_scale_cur_head[scale_idx + 8];
cache_k_scale_reg[i * 4 + 3] = cache_k_scale_cur_head[scale_idx + 9];
}
scale_col_base = threadIdx.x / 4 + kv_head_idx * HEAD_DIM;
const T *cache_v_scale_cur_head = cache_v_scale + scale_col_base;
for (int i = 0; i < num_frags_y; ++i) {
const int scale_idx = i * 16;
cache_v_scale_reg[i * 2] = cache_v_scale_cur_head[scale_idx];
cache_v_scale_reg[i * 2 + 1] = cache_v_scale_cur_head[scale_idx + 8];
}
} else {
cache_k_scale_reg[0] = cache_k_scale[kv_head_idx];
cache_v_scale_reg[0] = cache_v_scale[kv_head_idx];
}
const uint32_t q_end =
min(q_len, div_up((tile_id + 1) * num_rows_per_block, GROUP_SIZE));
@@ -661,7 +601,7 @@ __global__ void multi_query_append_attention_c8_warp1_4_kernel(
tid % 8 * num_elems_per_128b<T>();
}
}
const int *mask_offset_this_seq = mask_offset ? mask_offset + q_start_seq_id * 2 : nullptr;
smem_t qo_smem(smem);
uint32_t q_smem_offset_r = smem_t::get_permuted_offset<num_vecs_per_head>(
@@ -686,13 +626,6 @@ __global__ void multi_query_append_attention_c8_warp1_4_kernel(
smem_t k_smem(smem + num_frags_x * 16 * HEAD_DIM * sizeof(T)),
v_smem(smem + num_frags_x * 16 * HEAD_DIM * sizeof(T) +
NUM_WARP_KV * num_frags_z * 16 * HEAD_DIM * sizeof(CacheT));
T* k_smem_scale = nullptr;
T* v_smem_scale = nullptr;
if constexpr (IsDynamicC8) {
k_smem_scale = reinterpret_cast<T*>(smem + num_frags_x * 16 * HEAD_DIM * sizeof(T) +
NUM_WARP_KV * num_frags_z * 16 * HEAD_DIM * sizeof(CacheT) * 2);
v_smem_scale = k_smem_scale + NUM_WARP_KV * num_frags_z * 16;
}
const uint32_t num_iterations = div_up(
CAUSAL
@@ -709,7 +642,7 @@ __global__ void multi_query_append_attention_c8_warp1_4_kernel(
kv_len - q_len +
tile_id * num_rows_per_block / GROUP_SIZE,
chunk_start)))
: mask_offset ? 0 : chunk_len) /
: chunk_len) /
(NUM_WARP_KV * num_frags_z * 16);
uint32_t k_smem_offset_r =
@@ -775,23 +708,11 @@ __global__ void multi_query_append_attention_c8_warp1_4_kernel(
commit_group();
#pragma unroll 1
for (uint32_t iter = 0; iter < num_iterations; ++iter) {
if constexpr (IsDynamicC8) {
produce_k_dynamic_scale<BLOCK_SIZE, num_frags_z, NUM_WARP_Q, T>(
k_smem_scale,
cache_k_scale_reg,
block_table_now,
cache_k_scale,
kv_idx_base,
kv_num_heads,
kv_head_idx,
chunk_end
);
}
wait_group<1>();
__syncthreads();
// s = qk
compute_qk_c8<num_frags_x, num_frags_y, num_frags_z, T, CacheT, is_scale_channel_wise, IsFP8, IsDynamicC8>(
compute_qk_c8<num_frags_x, num_frags_y, num_frags_z, T, CacheT, is_scale_channel_wise, IsFP8>(
&qo_smem,
&q_smem_offset_r,
&k_smem,
@@ -807,16 +728,12 @@ __global__ void multi_query_append_attention_c8_warp1_4_kernel(
NUM_WARPS,
num_frags_x,
num_frags_y,
num_frags_z>(attn_mask ? attn_mask + batch_id * attn_mask_len *attn_mask_len : nullptr,
q_base_seq_id_this_block,
num_frags_z>(q_base_seq_id_this_block,
kv_idx_base + wid * num_frags_z * 16,
q_len,
kv_len,
chunk_end,
attn_mask_len,
s_frag,
mask_offset_this_seq);
s_frag);
}
// update m,d
@@ -824,7 +741,6 @@ __global__ void multi_query_append_attention_c8_warp1_4_kernel(
s_frag, o_frag, m_frag, d_frag);
__syncthreads();
const uint32_t ori_kv_idx_base = kv_idx_base;
kv_idx_base += NUM_WARP_KV * num_frags_z * 16;
produce_k_blockwise_c8<SharedMemFillMode::kNoFill,
NUM_WARPS,
@@ -843,18 +759,6 @@ __global__ void multi_query_append_attention_c8_warp1_4_kernel(
chunk_end,
const_k_offset);
commit_group();
if constexpr (IsDynamicC8) {
produce_v_dynamic_scale<BLOCK_SIZE, num_frags_z, NUM_WARP_Q, T>(
v_smem_scale,
cache_v_scale_reg,
block_table_now,
cache_v_scale,
ori_kv_idx_base,
kv_num_heads,
kv_head_idx,
chunk_end
);
}
wait_group<1>();
__syncthreads();
@@ -865,9 +769,7 @@ __global__ void multi_query_append_attention_c8_warp1_4_kernel(
BLOCK_SIZE,
T,
CacheT,
is_scale_channel_wise,
IsFP8,
IsDynamicC8>(
is_scale_channel_wise, IsFP8>(
&v_smem, &v_smem_offset_r, s_frag, o_frag, d_frag, cache_v_scale_reg);
__syncthreads();
@@ -981,8 +883,7 @@ template <typename T,
uint32_t NUM_WARP_Q,
typename OutT = T,
bool ENABLE_PREFILL = true,
bool IsFP8 = false,
bool IsDynamicC8 = false>
bool IsFP8=false>
void MultiQueryAppendC8Attention(
const AppendAttnMetaData &meta_data,
const paddle::Tensor &qkv,
@@ -1040,8 +941,7 @@ void MultiQueryAppendC8Attention(
constexpr uint32_t num_frags_z = BLOCK_SIZE / 16;
constexpr uint32_t smem_size =
num_warps * num_frags_x * 16 * HEAD_DIM * sizeof(T) +
num_frags_z * 16 * HEAD_DIM * sizeof(uint8_t) * 2 +
num_frags_z * 16 * sizeof(T) * 2;
num_frags_z * 16 * HEAD_DIM * sizeof(uint8_t) * 2;
auto split_kv_kernel =
multi_query_append_attention_c8_kernel<NV_TYPE,
uint8_t,
@@ -1058,9 +958,7 @@ void MultiQueryAppendC8Attention(
num_frags_y,
OUT_NV_TYPE,
ENABLE_PREFILL,
false,
IsFP8,
IsDynamicC8>;
false, IsFP8>;
if (is_scale_channel_wise) {
split_kv_kernel =
multi_query_append_attention_c8_kernel<NV_TYPE,
@@ -1078,9 +976,7 @@ void MultiQueryAppendC8Attention(
num_frags_y,
OUT_NV_TYPE,
ENABLE_PREFILL,
true,
IsFP8,
IsDynamicC8>;
true, IsFP8>;
}
if (smem_size >= 48 * 1024) {
cudaFuncSetAttribute(split_kv_kernel,
@@ -1114,9 +1010,7 @@ void MultiQueryAppendC8Attention(
num_frags_y,
OUT_NV_TYPE,
ENABLE_PREFILL,
false,
IsFP8,
IsDynamicC8>;
false, IsFP8>;
if (is_scale_channel_wise) {
nosplit_kv_kernel =
multi_query_append_attention_c8_kernel<NV_TYPE,
@@ -1134,9 +1028,7 @@ void MultiQueryAppendC8Attention(
num_frags_y,
OUT_NV_TYPE,
ENABLE_PREFILL,
true,
IsFP8,
IsDynamicC8>;
true, IsFP8>;
}
if (smem_size >= 48 * 1024) {
cudaFuncSetAttribute(nosplit_kv_kernel,
@@ -1162,7 +1054,6 @@ void MultiQueryAppendC8Attention(
tile_ids_per_batch.data<int>(),
cu_seqlens_q.data<int>(),
block_table.data<int>(),
meta_data.mask_offset,
max_seq_len,
max_dec_len,
max_block_num_per_seq,
@@ -1171,7 +1062,6 @@ void MultiQueryAppendC8Attention(
quant_min_bound,
in_scale,
chunk_size,
num_blocks_x_cpu,
nullptr,
nullptr,
nullptr,
@@ -1221,7 +1111,6 @@ void MultiQueryAppendC8Attention(
tile_ids_per_batch.data<int>(),
cu_seqlens_q.data<int>(),
block_table.data<int>(),
meta_data.mask_offset,
max_seq_len,
max_dec_len,
max_block_num_per_seq,
@@ -1230,7 +1119,6 @@ void MultiQueryAppendC8Attention(
quant_min_bound,
in_scale,
chunk_size,
num_blocks_x_cpu,
reinterpret_cast<NV_TYPE *>(tmp_workspace->ptr()),
static_cast<float *>(tmp_m->ptr()),
static_cast<float *>(tmp_d->ptr()),
@@ -1316,8 +1204,7 @@ void MultiQueryAppendC8Attention(
constexpr uint32_t num_frags_z = BLOCK_SIZE / 16 / NUM_WARP_KV * 2;
constexpr uint32_t smem_size =
num_frags_x * 16 * HEAD_DIM * sizeof(T) +
NUM_WARP_KV * num_frags_z * 16 * HEAD_DIM * sizeof(uint8_t) * 2 +
NUM_WARP_KV * num_frags_z * 16 * sizeof(T) * 2;
NUM_WARP_KV * num_frags_z * 16 * HEAD_DIM * sizeof(uint8_t) * 2;
auto split_kv_kernel =
multi_query_append_attention_c8_warp1_4_kernel<NV_TYPE,
uint8_t,
@@ -1334,9 +1221,7 @@ void MultiQueryAppendC8Attention(
num_frags_y,
OUT_NV_TYPE,
ENABLE_PREFILL,
false,
IsFP8,
IsDynamicC8>;
false, IsFP8>;
if (is_scale_channel_wise) {
split_kv_kernel =
multi_query_append_attention_c8_warp1_4_kernel<NV_TYPE,
@@ -1354,9 +1239,7 @@ void MultiQueryAppendC8Attention(
num_frags_y,
OUT_NV_TYPE,
ENABLE_PREFILL,
true,
IsFP8,
IsDynamicC8>;
true, IsFP8>;
}
if (smem_size >= 48 * 1024) {
cudaFuncSetAttribute(split_kv_kernel,
@@ -1371,17 +1254,10 @@ void MultiQueryAppendC8Attention(
chunk_size = static_cast<uint32_t>(encoder_max_partition_size);
}
const int num_chunks = div_up(max_seq_len, chunk_size);
int32_t attn_mask_len;
if (attn_mask) {
attn_mask_len = attn_mask.get().shape()[1];
} else {
attn_mask_len = -1;
}
const int num_chunks = div_up(max_dec_len, chunk_size);
dim3 grids(num_blocks_x_cpu, num_chunks, kv_num_heads);
dim3 blocks(32, num_warps);
if (num_chunks <= 0) {
if (num_chunks <= 1) {
auto nosplit_kv_kernel =
multi_query_append_attention_c8_warp1_4_kernel<NV_TYPE,
uint8_t,
@@ -1398,9 +1274,7 @@ void MultiQueryAppendC8Attention(
num_frags_y,
OUT_NV_TYPE,
ENABLE_PREFILL,
false,
IsFP8,
IsDynamicC8>;
false, IsFP8>;
if (is_scale_channel_wise) {
nosplit_kv_kernel =
multi_query_append_attention_c8_warp1_4_kernel<NV_TYPE,
@@ -1418,9 +1292,7 @@ void MultiQueryAppendC8Attention(
num_frags_y,
OUT_NV_TYPE,
ENABLE_PREFILL,
true,
IsFP8,
IsDynamicC8>;
true, IsFP8>;
}
if (smem_size >= 48 * 1024) {
cudaFuncSetAttribute(nosplit_kv_kernel,
@@ -1446,9 +1318,6 @@ void MultiQueryAppendC8Attention(
tile_ids_per_batch.data<int>(),
cu_seqlens_q.data<int>(),
block_table.data<int>(),
meta_data.mask_offset,
attn_mask ? const_cast<bool *>(attn_mask.get().data<bool>())
: nullptr,
max_seq_len,
max_dec_len,
max_block_num_per_seq,
@@ -1457,13 +1326,11 @@ void MultiQueryAppendC8Attention(
quant_min_bound,
in_scale,
chunk_size,
num_blocks_x_cpu,
nullptr,
nullptr,
nullptr,
reinterpret_cast<OUT_NV_TYPE *>(out->data<OutT>()),
speculate_max_draft_token_num,
attn_mask_len);
speculate_max_draft_token_num);
} else {
phi::Allocator::AllocationPtr tmp_workspace, tmp_m, tmp_d;
if (is_decoder) {
@@ -1510,8 +1377,8 @@ void MultiQueryAppendC8Attention(
reinterpret_cast<NV_TYPE *>(const_cast<T *>(cache_k_scale.data<T>())),
reinterpret_cast<NV_TYPE *>(const_cast<T *>(cache_v_scale.data<T>())),
shift_bias ? reinterpret_cast<NV_TYPE *>(
const_cast<T *>(shift_bias.get().data<T>()))
: nullptr,
const_cast<T *>(shift_bias.get().data<T>()))
: nullptr,
smooth_weight ? reinterpret_cast<NV_TYPE *>(
const_cast<T *>(smooth_weight.get().data<T>()))
: nullptr,
@@ -1521,9 +1388,6 @@ void MultiQueryAppendC8Attention(
tile_ids_per_batch.data<int>(),
cu_seqlens_q.data<int>(),
block_table.data<int>(),
meta_data.mask_offset,
attn_mask ? const_cast<bool *>(attn_mask.get().data<bool>())
: nullptr,
max_seq_len,
max_dec_len,
max_block_num_per_seq,
@@ -1532,13 +1396,11 @@ void MultiQueryAppendC8Attention(
quant_min_bound,
in_scale,
chunk_size,
num_blocks_x_cpu,
reinterpret_cast<NV_TYPE *>(tmp_workspace->ptr()),
static_cast<float *>(tmp_m->ptr()),
static_cast<float *>(tmp_d->ptr()),
reinterpret_cast<OUT_NV_TYPE *>(out->data<OutT>()),
speculate_max_draft_token_num,
attn_mask_len);
speculate_max_draft_token_num);
// merge
constexpr int vec_size = num_elems_per_128b<NV_TYPE>();
if (is_decoder) {
@@ -1556,8 +1418,8 @@ void MultiQueryAppendC8Attention(
seq_lens_encoder.data<int>(),
cu_seqlens_q.data<int>(),
shift_bias ? reinterpret_cast<NV_TYPE *>(
const_cast<T *>(shift_bias.get().data<T>()))
: nullptr,
const_cast<T *>(shift_bias.get().data<T>()))
: nullptr,
smooth_weight ? reinterpret_cast<NV_TYPE *>(const_cast<T *>(
smooth_weight.get().data<T>()))
: nullptr,
@@ -1574,14 +1436,14 @@ void MultiQueryAppendC8Attention(
constexpr int blockx = HEAD_DIM / vec_size;
constexpr int blocky = (128 + blockx - 1) / blockx;
dim3 grids_merge(min(sm_count * 4, token_num),
num_heads);
num_heads);
dim3 blocks_merge(blockx, blocky);
merge_multi_chunks_v2_kernel<NV_TYPE,
vec_size,
blocky,
HEAD_DIM,
OUT_NV_TYPE,
ENABLE_PREFILL>
vec_size,
blocky,
HEAD_DIM,
OUT_NV_TYPE,
ENABLE_PREFILL>
<<<grids_merge, blocks_merge, 0, stream>>>(
reinterpret_cast<NV_TYPE *>(tmp_workspace->ptr()),
static_cast<float *>(tmp_m->ptr()),
@@ -1592,8 +1454,8 @@ void MultiQueryAppendC8Attention(
batch_id_per_token.data<int>(),
cu_seqlens_q.data<int>(),
shift_bias ? reinterpret_cast<NV_TYPE *>(
const_cast<T *>(shift_bias.get().data<T>()))
: nullptr,
const_cast<T *>(shift_bias.get().data<T>()))
: nullptr,
smooth_weight ? reinterpret_cast<NV_TYPE *>(const_cast<T *>(
smooth_weight.get().data<T>()))
: nullptr,
@@ -1655,7 +1517,6 @@ void CascadeAppendAttentionC8Kernel(
const bool causal,
const bool is_decoder,
const bool enable_prefill,
const std::string& cache_quant_type_str,
cudaStream_t& stream,
paddle::Tensor* out) {
const auto token_num = meta_data.token_nums;
@@ -1664,7 +1525,6 @@ void CascadeAppendAttentionC8Kernel(
const auto num_heads = meta_data.q_num_heads;
const auto group_size = meta_data.q_num_heads / meta_data.kv_num_heads;
const auto head_dim = meta_data.head_dims;
bool is_dynamic_cfp8 = cache_quant_type_str == "block_wise_fp8";
DISPATCH_CAUSAL(
causal,
@@ -1683,46 +1543,43 @@ void CascadeAppendAttentionC8Kernel(
BLOCK_SIZE,
{DISPATCH_BLOCKSHAPE_Q(
block_shape_q, BLOCK_SHAPE_Q, NUM_WARP_Q, {
DISPATCH_DyCfp8(is_dynamic_cfp8, IsDynamicC8, {
MultiQueryAppendC8Attention<T,
GROUP_SIZE,
HEAD_DIM,
BLOCK_SIZE,
CAUSAL,
BLOCK_SHAPE_Q,
NUM_WARP_Q,
OutT,
ENABLE_PREFILL,
IsFP8,
IsDynamicC8>(
meta_data,
qkv,
cache_k,
cache_v,
attn_mask,
cache_k_scale.get(),
cache_v_scale.get(),
shift_bias,
smooth_weight,
seq_lens_q,
seq_lens_kv,
seq_lens_encoder,
batch_id_per_token,
cu_seqlens_q,
block_table,
batch_ids,
tile_ids_per_batch,
num_blocks,
max_seq_len,
max_dec_len,
quant_max_bound,
quant_min_bound,
in_scale,
max_partition_size,
encoder_max_partition_size,
speculate_max_draft_token_num,
is_decoder,
stream,
out);
})})})})})})})
MultiQueryAppendC8Attention<T,
GROUP_SIZE,
HEAD_DIM,
BLOCK_SIZE,
CAUSAL,
BLOCK_SHAPE_Q,
NUM_WARP_Q,
OutT,
ENABLE_PREFILL, IsFP8>(
meta_data,
qkv,
cache_k,
cache_v,
attn_mask,
cache_k_scale.get(),
cache_v_scale.get(),
shift_bias,
smooth_weight,
seq_lens_q,
seq_lens_kv,
seq_lens_encoder,
batch_id_per_token,
cu_seqlens_q,
block_table,
batch_ids,
tile_ids_per_batch,
num_blocks,
max_seq_len,
max_dec_len,
quant_max_bound,
quant_min_bound,
in_scale,
max_partition_size,
encoder_max_partition_size,
speculate_max_draft_token_num,
is_decoder,
stream,
out);
})})})})})})
}

View File

@@ -384,113 +384,6 @@ __device__ __forceinline__ void produce_v_blockwise_c8(
}
}
template<uint32_t block_size,
uint32_t num_frags_z,
uint32_t NUM_WARP_Q,
typename T>
__device__ __forceinline__ void produce_k_dynamic_scale(
T* k_smem_scale,
T* cache_k_reg,
const int* block_table_now,
const T* cache_k_scale,
const uint32_t kv_idx,
const uint32_t kv_num_heads,
const uint32_t kv_head_idx,
const uint32_t chunk_end
) {
const uint32_t tx = threadIdx.x, ty = threadIdx.y;
if constexpr (NUM_WARP_Q == 4) {
// 4 warps shared block_size
const uint32_t tid = ty * 32 + tx;
int block_id = __ldg(&block_table_now[kv_idx / block_size]);
if (block_id < 0) block_id = 0;
const T* cache_k_scale_now = cache_k_scale + block_id * kv_num_heads * block_size + kv_head_idx * block_size;
if (tid < block_size) {
k_smem_scale[tid] = cache_k_scale_now[tid];
}
__syncthreads();
const uint32_t row_id = tx / 4;
for (uint32_t fz = 0; fz < num_frags_z; fz++) {
cache_k_reg[fz * 2] = k_smem_scale[fz * 16 + row_id];
cache_k_reg[fz * 2 + 1] = k_smem_scale[fz * 16 + row_id + 8];
}
} else {
// 1 warp 32 tokens
const uint32_t kv_idx_now = kv_idx + block_size * ty / 2;
int block_id = __ldg(&block_table_now[kv_idx_now / block_size]);
if (block_id < 0) block_id = 0;
const T* cache_k_scale_now = cache_k_scale + block_id * kv_num_heads * block_size + kv_head_idx * block_size;
const int kv_idx_this_thread = kv_idx + ty * 32 + tx;
if (kv_idx_this_thread < chunk_end) {
k_smem_scale[ty * 32 + tx] = cache_k_scale_now[(ty % 2) * 32 + tx];
} else {
k_smem_scale[ty * 32 + tx] = 0;
}
__syncwarp();
const uint32_t row_id = tx / 4;
for (uint32_t fz = 0; fz < num_frags_z; fz++) {
cache_k_reg[fz * 2] = k_smem_scale[ty * 32 + fz * 16 + row_id];
cache_k_reg[fz * 2 + 1] = k_smem_scale[ty * 32 + fz * 16 + row_id + 8];
}
}
}
template<uint32_t block_size,
uint32_t num_frags_z,
uint32_t NUM_WARP_Q,
typename T>
__device__ __forceinline__ void produce_v_dynamic_scale(
T* v_smem_scale,
T* cache_v_reg,
const int* block_table_now,
const T* cache_v_scale,
const uint32_t kv_idx,
const uint32_t kv_num_heads,
const uint32_t kv_head_idx,
const uint32_t chunk_end
) {
const uint32_t tx = threadIdx.x, ty = threadIdx.y;
if constexpr (NUM_WARP_Q == 4) {
// 4 warps shared block_size
const uint32_t tid = ty * 32 + tx;
int block_id = __ldg(&block_table_now[kv_idx / block_size]);
if (block_id < 0) block_id = 0;
const T* cache_v_scale_now = cache_v_scale + block_id * kv_num_heads * block_size + kv_head_idx * block_size;
if (tid < block_size) {
v_smem_scale[tid] = cache_v_scale_now[tid];
}
__syncthreads();
const uint32_t row_id = tx % 4 * 2;
for (uint32_t fz = 0; fz < num_frags_z; fz++) {
cache_v_reg[fz * 4] = v_smem_scale[fz * 16 + row_id];
cache_v_reg[fz * 4 + 1] = v_smem_scale[fz * 16 + row_id + 1];
cache_v_reg[fz * 4 + 2] = v_smem_scale[fz * 16 + row_id + 8];
cache_v_reg[fz * 4 + 3] = v_smem_scale[fz * 16 + row_id + 9];
}
} else {
// 1 warp 32 tokens
const uint32_t kv_idx_now = kv_idx + block_size * ty / 2;
int block_id = __ldg(&block_table_now[kv_idx_now / block_size]);
if (block_id < 0) block_id = 0;
const T* cache_v_scale_now = cache_v_scale + block_id * kv_num_heads * block_size + kv_head_idx * block_size;
const int kv_idx_this_thread = kv_idx + ty * 32 + tx;
if (kv_idx_this_thread < chunk_end) {
v_smem_scale[ty * 32 + tx] = cache_v_scale_now[(ty % 2) * 32 + tx];
} else {
v_smem_scale[ty * 32 + tx] = 0;
}
__syncwarp();
const uint32_t row_id = tx % 4 * 2;
for (uint32_t fz = 0; fz < num_frags_z; fz++) {
cache_v_reg[fz * 4] = v_smem_scale[ty * 32 + fz * 16 + row_id];
cache_v_reg[fz * 4 + 1] = v_smem_scale[ty * 32 + fz * 16 + row_id + 1];
cache_v_reg[fz * 4 + 2] = v_smem_scale[ty * 32 + fz * 16 + row_id + 8];
cache_v_reg[fz * 4 + 3] = v_smem_scale[ty * 32 + fz * 16 + row_id + 9];
}
}
}
template <SharedMemFillMode fill_mode,
uint32_t num_warps,
uint32_t block_size,
@@ -923,8 +816,7 @@ template <uint32_t num_frags_x,
typename T,
typename CacheT,
bool is_scale_channel_wise = false,
bool IsFP8 = false,
bool IsDynamicC8 = false>
bool IsFP8=false>
__device__ __forceinline__ void compute_qk_c8(smem_t* q_smem,
uint32_t* q_smem_offset_r,
smem_t* k_smem,
@@ -968,27 +860,20 @@ __device__ __forceinline__ void compute_qk_c8(smem_t* q_smem,
convert_c8<T,IsFP8>(b_frag_dq_T, b_frag[fy * 2]);
convert_c8<T,IsFP8>(b_frag_dq_T + 4, b_frag[fy * 2 + 1]);
// scale zp
if constexpr (!IsDynamicC8) {
if constexpr (is_scale_channel_wise) {
const int scale_col = (ky * 2 + fy) * 4;
b_frag_dq_T[0] *= cache_k_scale[scale_col];
b_frag_dq_T[1] *= cache_k_scale[scale_col + 1];
b_frag_dq_T[2] *= cache_k_scale[scale_col + 2];
b_frag_dq_T[3] *= cache_k_scale[scale_col + 3];
b_frag_dq_T[4] *= cache_k_scale[scale_col];
b_frag_dq_T[5] *= cache_k_scale[scale_col + 1];
b_frag_dq_T[6] *= cache_k_scale[scale_col + 2];
b_frag_dq_T[7] *= cache_k_scale[scale_col + 3];
} else {
#pragma unroll
for (uint32_t b_i = 0; b_i < 8; ++b_i) {
b_frag_dq_T[b_i] *= cache_k_scale[0];
}
}
if constexpr (is_scale_channel_wise) {
const int scale_col = (ky * 2 + fy) * 4;
b_frag_dq_T[0] *= cache_k_scale[scale_col];
b_frag_dq_T[1] *= cache_k_scale[scale_col + 1];
b_frag_dq_T[2] *= cache_k_scale[scale_col + 2];
b_frag_dq_T[3] *= cache_k_scale[scale_col + 3];
b_frag_dq_T[4] *= cache_k_scale[scale_col];
b_frag_dq_T[5] *= cache_k_scale[scale_col + 1];
b_frag_dq_T[6] *= cache_k_scale[scale_col + 2];
b_frag_dq_T[7] *= cache_k_scale[scale_col + 3];
} else {
#pragma unroll
for (uint32_t b_i = 0; b_i < 8; ++b_i) {
b_frag_dq_T[b_i] *= cache_k_scale[fz * 2 + b_i / 4];
b_frag_dq_T[b_i] *= cache_k_scale[0];
}
}
#pragma unroll
@@ -1020,15 +905,12 @@ template <typename T,
uint32_t num_frags_y,
uint32_t num_frags_z,
bool IS_SYSTEM = false>
__device__ __forceinline__ void mask_s(const bool* attn_mask,
const uint32_t qo_idx_base,
__device__ __forceinline__ void mask_s(const uint32_t qo_idx_base,
const uint32_t kv_idx_base,
const uint32_t qo_len,
const uint32_t kv_len,
const uint32_t chunk_end,
const int32_t attn_mask_len,
float (*s_frag)[num_frags_z][8],
const int *mask_offset = nullptr) {
float (*s_frag)[num_frags_z][8]) {
const uint32_t tx = threadIdx.x;
#pragma unroll
for (uint32_t fx = 0; fx < num_frags_x; ++fx) {
@@ -1042,21 +924,10 @@ __device__ __forceinline__ void mask_s(const bool* attn_mask,
group_size,
kv_idx = kv_idx_base + fz * 16 + 2 * (tx % 4) +
8 * (reg_id / 4) + reg_id % 2;
bool out_of_boundary;
if (mask_offset) {
out_of_boundary = q_idx < qo_len ? (kv_idx >= mask_offset[q_idx * 2 + 1] || kv_idx < mask_offset[q_idx * 2]) : true;
} else {
out_of_boundary =
(causal
? (kv_idx > kv_len + q_idx - qo_len || (kv_idx >= chunk_end))
: kv_idx >= chunk_end);
if (attn_mask != nullptr && kv_idx > kv_len - qo_len && kv_idx < chunk_end && attn_mask_len > 0 && q_idx < static_cast<uint32_t>(attn_mask_len)) {
const int32_t mask_idx = q_idx * attn_mask_len + kv_idx - kv_len + qo_len;
bool mask = attn_mask[mask_idx];
out_of_boundary |= mask;
}
}
const bool out_of_boundary =
(causal
? (kv_idx > kv_len + q_idx - qo_len || (kv_idx >= chunk_end))
: kv_idx >= chunk_end);
if constexpr (std::is_same<T, half>::value) {
s_frag[fx][fz][reg_id] =
out_of_boundary ? -5e4f : s_frag[fx][fz][reg_id];
@@ -1064,7 +935,6 @@ __device__ __forceinline__ void mask_s(const bool* attn_mask,
s_frag[fx][fz][reg_id] =
out_of_boundary ? -3.0e+30f : s_frag[fx][fz][reg_id];
}
// printf("tid: %d. qk[%u,%u] = %f, mask: %d \n ", threadIdx.x, kv_idx, q_idx, static_cast<float>(s_frag[fx][fz][reg_id]), int(out_of_boundary));
} else {
const uint32_t q_idx = qo_idx_base,
kv_idx = kv_idx_base + fz * 16 + 2 * (tx % 4) +
@@ -1208,9 +1078,7 @@ template <uint32_t num_frags_x,
uint32_t block_size,
typename T,
typename CacheT,
bool is_scale_channel_wise = false,
bool IsFP8 = false,
bool IsDynamicC8 = false>
bool is_scale_channel_wise = false, bool IsFP8=false>
__device__ __forceinline__ void compute_sfm_v_c8(
smem_t* v_smem,
uint32_t* v_smem_offset_r,
@@ -1252,28 +1120,16 @@ __device__ __forceinline__ void compute_sfm_v_c8(
convert_c8<T,IsFP8>(b_frag_dq_T, b_frag[fz * 2]);
convert_c8<T,IsFP8>(b_frag_dq_T + 4, b_frag[fz * 2 + 1]);
// scale zp
if constexpr (!IsDynamicC8) {
if constexpr (is_scale_channel_wise) {
if constexpr (is_scale_channel_wise) {
#pragma unroll
for (uint32_t b_i = 0; b_i < 8; ++b_i) {
b_frag_dq_T[b_i] *= cache_v_scale[b_i / 4 + fy * 2];
}
} else {
#pragma unroll
for (uint32_t b_i = 0; b_i < 8; ++b_i) {
b_frag_dq_T[b_i] *= cache_v_scale[0];
}
for (uint32_t b_i = 0; b_i < 8; ++b_i) {
b_frag_dq_T[b_i] *= cache_v_scale[b_i / 4 + fy * 2];
}
} else {
const int scale_col = (kz * 2 + fz) * 4;
b_frag_dq_T[0] *= cache_v_scale[scale_col];
b_frag_dq_T[1] *= cache_v_scale[scale_col + 1];
b_frag_dq_T[2] *= cache_v_scale[scale_col + 2];
b_frag_dq_T[3] *= cache_v_scale[scale_col + 3];
b_frag_dq_T[4] *= cache_v_scale[scale_col];
b_frag_dq_T[5] *= cache_v_scale[scale_col + 1];
b_frag_dq_T[6] *= cache_v_scale[scale_col + 2];
b_frag_dq_T[7] *= cache_v_scale[scale_col + 3];
#pragma unroll
for (uint32_t b_i = 0; b_i < 8; ++b_i) {
b_frag_dq_T[b_i] *= cache_v_scale[0];
}
}
#pragma unroll
for (uint32_t fx = 0; fx < num_frags_x; ++fx) { // m: num_frags_x * 16
@@ -1300,9 +1156,7 @@ template <uint32_t num_frags_x,
uint32_t block_size,
typename T,
typename CacheT,
bool is_scale_channel_wise = false,
bool IsFP8 = false,
bool IsDynamicC8 = false>
bool is_scale_channel_wise = false, bool IsFP8=false>
__device__ __forceinline__ void compute_sfm_v_c8_iter_sq_bvec(
smem_t* v_smem,
uint32_t* v_smem_offset_r,
@@ -1346,28 +1200,16 @@ __device__ __forceinline__ void compute_sfm_v_c8_iter_sq_bvec(
convert_c8<T,IsFP8>(b_frag_dq_T, b_frag[fz * 2]);
convert_c8<T,IsFP8>(b_frag_dq_T + 4, b_frag[fz * 2 + 1]);
// scale zp
if constexpr (!IsDynamicC8) {
if constexpr (is_scale_channel_wise) {
if constexpr (is_scale_channel_wise) {
#pragma unroll
for (uint32_t b_i = 0; b_i < 8; ++b_i) {
b_frag_dq_T[b_i] *= cache_v_scale[b_i / 4 + fy * 2];
}
} else {
#pragma unroll
for (uint32_t b_i = 0; b_i < 8; ++b_i) {
b_frag_dq_T[b_i] *= cache_v_scale[0];
}
for (uint32_t b_i = 0; b_i < 8; ++b_i) {
b_frag_dq_T[b_i] *= cache_v_scale[b_i / 4 + fy * 2];
}
} else {
const int scale_col = (kz * 2 + fz) * 4;
b_frag_dq_T[0] *= cache_v_scale[scale_col];
b_frag_dq_T[1] *= cache_v_scale[scale_col + 1];
b_frag_dq_T[2] *= cache_v_scale[scale_col + 2];
b_frag_dq_T[3] *= cache_v_scale[scale_col + 3];
b_frag_dq_T[4] *= cache_v_scale[scale_col];
b_frag_dq_T[5] *= cache_v_scale[scale_col + 1];
b_frag_dq_T[6] *= cache_v_scale[scale_col + 2];
b_frag_dq_T[7] *= cache_v_scale[scale_col + 3];
#pragma unroll
for (uint32_t b_i = 0; b_i < 8; ++b_i) {
b_frag_dq_T[b_i] *= cache_v_scale[0];
}
}
#pragma unroll
for (uint32_t fx = 0; fx < num_frags_x; ++fx) { // m: num_frags_x * 16

View File

@@ -103,7 +103,6 @@ void CascadeAppendAttentionC8Kernel(
const bool causal,
const bool is_decoder,
const bool enable_prefill,
const std::string& cache_quant_type_str,
cudaStream_t& stream,
paddle::Tensor* out);
@@ -265,10 +264,9 @@ void CascadeAppendAttentionKernel(
causal,
is_decoder,
enable_prefill,
cache_quant_type_str,
stream,
out);
} else if (cache_quant_type_str == "cache_fp8" or cache_quant_type_str == "block_wise_fp8") {
} else if (cache_quant_type_str == "cache_fp8") {
CascadeAppendAttentionC8Kernel<T, OutT, true>(meta_data,
qkv,
cache_k,
@@ -301,7 +299,6 @@ void CascadeAppendAttentionKernel(
causal,
is_decoder,
enable_prefill,
cache_quant_type_str,
stream,
out);
} else if (cache_quant_type_str == "cache_int4_zp") {

View File

@@ -15,73 +15,13 @@
#include "decoder_write_cache_with_rope_kernel.h"
#include "utils.cuh"
template <typename T, typename QKV_TYPE>
void append_decode_cache_rope_qk_norm(const QKV_TYPE* qkv,
T* key_cache,
T* value_cache,
T* qkv_out,
const int* block_tables,
const int* cu_seqlens_q,
const int* seq_lens,
const int* seq_lens_encoder,
const float* cos_emb,
const float* sin_emb,
const float* qkv_out_scales,
const T* qkv_biases,
const int max_seq_len,
const int max_blocks_per_seq,
const int num_heads,
const int kv_num_heads,
const int dim_head,
const int block_size,
const int bsz,
const cudaStream_t& stream,
const bool use_neox_style,
const bool rope_3d,
const float* q_norm_weight,
const float* k_norm_weight,
const float rms_norm_eps) {
const uint32_t elem_nums =
use_neox_style ? bsz * (num_heads + 2 * kv_num_heads) * dim_head / 2
: bsz * (num_heads + 2 * kv_num_heads) * dim_head;
constexpr int HEAD_DIM = 128;
constexpr int PackSize = HEAD_DIM / kWarpSize;
const int pack_num = elem_nums / PackSize;
const int blocksize = 128;
int grid_size = 1;
GetNumBlocks<128>(pack_num, &grid_size);
dim3 block_dim(kWarpSize, blocksize / kWarpSize, 1);
append_decode_cache_T_rope_qk_norm_kernel<T, PackSize>
<<<grid_size, block_dim, 0, stream>>>(reinterpret_cast<const T*>(qkv),
key_cache,
value_cache,
qkv_out,
block_tables,
cu_seqlens_q,
seq_lens,
seq_lens_encoder,
cos_emb,
sin_emb,
max_seq_len,
max_blocks_per_seq,
num_heads,
dim_head,
block_size,
elem_nums,
kv_num_heads,
rope_3d,
q_norm_weight,
k_norm_weight,
rms_norm_eps);
}
template <typename T, typename QKV_TYPE>
void append_decode_cache_rope(const QKV_TYPE* qkv,
T* key_cache,
T* value_cache,
T* qkv_out,
const int* block_tables,
const int* batch_id_per_token,
const int* cu_seqlens_q,
const int* seq_lens,
const int* seq_lens_encoder,
@@ -94,7 +34,6 @@ void append_decode_cache_rope(const QKV_TYPE* qkv,
const int num_heads,
const int kv_num_heads,
const int dim_head,
const int rotary_dim,
const int block_size,
const int bsz,
const cudaStream_t& stream,
@@ -118,6 +57,7 @@ void append_decode_cache_rope(const QKV_TYPE* qkv,
value_cache,
qkv_out,
block_tables,
batch_id_per_token,
cu_seqlens_q,
seq_lens,
seq_lens_encoder,
@@ -131,37 +71,15 @@ void append_decode_cache_rope(const QKV_TYPE* qkv,
dim_head,
block_size,
elem_nums,
kv_num_heads,
rope_3d);
kv_num_heads);
} else {
if (rotary_dim < dim_head){
append_decode_cache_T_neox_partial_rope_kernel<T, PackSize>
<<<grid_size, blocksize, 0, stream>>>(reinterpret_cast<const T*>(qkv),
key_cache,
value_cache,
qkv_out,
block_tables,
cu_seqlens_q,
seq_lens,
seq_lens_encoder,
cos_emb,
sin_emb,
max_seq_len,
max_blocks_per_seq,
num_heads,
dim_head,
rotary_dim,
block_size,
elem_nums,
kv_num_heads,
rope_3d);
}else{
append_decode_cache_T_neox_rope_kernel<T, PackSize>
append_decode_cache_T_neox_rope_kernel<T, PackSize>
<<<grid_size, blocksize, 0, stream>>>(reinterpret_cast<const T*>(qkv),
key_cache,
value_cache,
qkv_out,
block_tables,
batch_id_per_token,
cu_seqlens_q,
seq_lens,
seq_lens_encoder,
@@ -173,9 +91,7 @@ void append_decode_cache_rope(const QKV_TYPE* qkv,
dim_head,
block_size,
elem_nums,
kv_num_heads,
rope_3d);
}
kv_num_heads);
}
} else {
if (qkv_out_scales) {
@@ -186,6 +102,7 @@ void append_decode_cache_rope(const QKV_TYPE* qkv,
value_cache,
qkv_out,
block_tables,
batch_id_per_token,
cu_seqlens_q,
seq_lens,
seq_lens_encoder,
@@ -208,6 +125,7 @@ void append_decode_cache_rope(const QKV_TYPE* qkv,
value_cache,
qkv_out,
block_tables,
batch_id_per_token,
cu_seqlens_q,
seq_lens,
seq_lens_encoder,
@@ -231,6 +149,7 @@ void append_decode_cache_int8_rope(const QKV_TYPE* qkv,
uint8_t* value_cache,
T* qkv_out,
const int* block_tables,
const int* batch_id_per_token,
const int* cu_seqlens_q,
const int* seq_lens,
const int* seq_lens_encoder,
@@ -263,6 +182,7 @@ void append_decode_cache_int8_rope(const QKV_TYPE* qkv,
value_cache,
qkv_out,
block_tables,
batch_id_per_token,
cu_seqlens_q,
seq_lens,
seq_lens_encoder,
@@ -278,8 +198,7 @@ void append_decode_cache_int8_rope(const QKV_TYPE* qkv,
block_size,
127.0f,
-127.0f,
kv_num_heads,
rope_3d);
kv_num_heads);
} else {
append_decode_cache_int8_neox_rope_kernel<T, 4>
<<<grids, num_warps * 32, 0, stream>>>(
@@ -288,6 +207,7 @@ void append_decode_cache_int8_rope(const QKV_TYPE* qkv,
value_cache,
qkv_out,
block_tables,
batch_id_per_token,
cu_seqlens_q,
seq_lens,
seq_lens_encoder,
@@ -301,8 +221,7 @@ void append_decode_cache_int8_rope(const QKV_TYPE* qkv,
block_size,
127.0f,
-127.0f,
kv_num_heads,
rope_3d);
kv_num_heads);
}
} else {
if (qkv_out_scales) {
@@ -313,6 +232,7 @@ void append_decode_cache_int8_rope(const QKV_TYPE* qkv,
value_cache,
qkv_out,
block_tables,
batch_id_per_token,
cu_seqlens_q,
seq_lens,
seq_lens_encoder,
@@ -328,8 +248,7 @@ void append_decode_cache_int8_rope(const QKV_TYPE* qkv,
block_size,
127.0f,
-127.0f,
kv_num_heads,
rope_3d);
kv_num_heads);
} else {
append_decode_cache_int8_rope_kernel<T, 4, 0, 128, is_scale_channel_wise, IsFP8>
<<<grids, num_warps * 32, 0, stream>>>(
@@ -338,6 +257,7 @@ void append_decode_cache_int8_rope(const QKV_TYPE* qkv,
value_cache,
qkv_out,
block_tables,
batch_id_per_token,
cu_seqlens_q,
seq_lens,
seq_lens_encoder,
@@ -351,8 +271,7 @@ void append_decode_cache_int8_rope(const QKV_TYPE* qkv,
block_size,
127.0f,
-127.0f,
kv_num_heads,
rope_3d);
kv_num_heads);
}
}
}
@@ -363,6 +282,7 @@ void append_decode_cache_int4_rope(const QKV_TYPE* qkv,
uint8_t* value_cache,
T* qkv_out,
const int* block_tables,
const int* batch_id_per_token,
const int* cu_seqlens_q,
const int* seq_lens,
const int* seq_lens_encoder,
@@ -397,6 +317,7 @@ void append_decode_cache_int4_rope(const QKV_TYPE* qkv,
value_cache,
qkv_out,
block_tables,
batch_id_per_token,
cu_seqlens_q,
seq_lens,
seq_lens_encoder,
@@ -414,8 +335,7 @@ void append_decode_cache_int4_rope(const QKV_TYPE* qkv,
block_size,
7.0f,
-8.0f,
kv_num_heads,
rope_3d);
kv_num_heads);
} else {
append_decode_cache_int4_neox_rope_kernel<T, 4>
<<<grids, num_warps * 32, 0, stream>>>(
@@ -424,6 +344,7 @@ void append_decode_cache_int4_rope(const QKV_TYPE* qkv,
value_cache,
qkv_out,
block_tables,
batch_id_per_token,
cu_seqlens_q,
seq_lens,
seq_lens_encoder,
@@ -439,8 +360,7 @@ void append_decode_cache_int4_rope(const QKV_TYPE* qkv,
block_size,
7.0f,
-8.0f,
kv_num_heads,
rope_3d);
kv_num_heads);
}
} else {
if (qkv_out_scales) {
@@ -451,6 +371,7 @@ void append_decode_cache_int4_rope(const QKV_TYPE* qkv,
value_cache,
qkv_out,
block_tables,
batch_id_per_token,
cu_seqlens_q,
seq_lens,
seq_lens_encoder,
@@ -468,8 +389,7 @@ void append_decode_cache_int4_rope(const QKV_TYPE* qkv,
block_size,
7.0f,
-8.0f,
kv_num_heads,
rope_3d);
kv_num_heads);
} else {
append_decode_cache_int4_rope_kernel<T, 4>
<<<grids, num_warps * 32, 0, stream>>>(
@@ -478,6 +398,7 @@ void append_decode_cache_int4_rope(const QKV_TYPE* qkv,
value_cache,
qkv_out,
block_tables,
batch_id_per_token,
cu_seqlens_q,
seq_lens,
seq_lens_encoder,
@@ -493,8 +414,7 @@ void append_decode_cache_int4_rope(const QKV_TYPE* qkv,
block_size,
7.0f,
-8.0f,
kv_num_heads,
rope_3d);
kv_num_heads);
}
}
}
@@ -504,6 +424,7 @@ void DecoderWriteCacheWithRoPEKernel(
const paddle::Tensor& qkv,
const paddle::Tensor& seq_lens,
const paddle::Tensor& seq_lens_encoder,
const paddle::Tensor& batch_id_per_token,
const paddle::Tensor& cu_seqlens_q,
const paddle::Tensor& block_tables,
const paddle::optional<paddle::Tensor>& rotary_embs,
@@ -520,10 +441,7 @@ void DecoderWriteCacheWithRoPEKernel(
cudaStream_t& stream,
paddle::Tensor* qkv_out,
paddle::Tensor* key_cache_out,
paddle::Tensor* value_cache_out,
const paddle::optional<paddle::Tensor>& q_norm_weight,
const paddle::optional<paddle::Tensor>& k_norm_weight,
const float rms_norm_eps) {
paddle::Tensor* value_cache_out) {
typedef cascade_attn_type_traits<T> traits_;
typedef cascade_attn_type_traits<QKV_TYPE> qkt_nv_type_;
typedef typename traits_::type DataType_;
@@ -540,30 +458,85 @@ void DecoderWriteCacheWithRoPEKernel(
const float* cos_emb =
rotary_embs ? rotary_embs.get().data<float>() : nullptr;
const float* sin_emb;
int rotary_dim = dim_head;
if (rotary_embs) {
sin_emb =
use_neox_rotary_style
? rotary_embs.get().data<float>() + max_seq_len * dim_head
: rotary_embs.get().data<float>() + max_seq_len * dim_head / 2;
rotary_dim = rotary_embs.get().dims()[rotary_embs.get().dims().size()-1] * 2;
if(rotary_dim < dim_head){
if (!use_neox_rotary_style || qkv_out_scales || q_norm_weight || k_norm_weight|| cache_quant_type_str != "none"){
PADDLE_THROW(phi::errors::Fatal(
"partial_rotary_factor < 1.0 only supports neox_rotary_style=True, qkv_out_scales is None, q_norm_weight/k_norm_weight) is None, and cache_quant_type_str is 'none'."));
}
sin_emb = rotary_embs.get().data<float>() + max_seq_len * rotary_dim / 2;
}
}
if (q_norm_weight && k_norm_weight) {
if (cache_quant_type_str == "none") {
append_decode_cache_rope_qk_norm(
if (cache_quant_type_str == "none") {
append_decode_cache_rope(
reinterpret_cast<const QKV_TYPE*>(qkv_ptr),
reinterpret_cast<DataType_*>(key_cache_out->data<T>()),
reinterpret_cast<DataType_*>(value_cache_out->data<T>()),
reinterpret_cast<DataType_*>(qkv_out->data<T>()),
block_tables.data<int>(),
batch_id_per_token.data<int>(),
cu_seqlens_q.data<int>(),
seq_lens.data<int>(),
seq_lens_encoder.data<int>(),
cos_emb,
sin_emb,
qkv_out_scales ? qkv_out_scales.get().data<float>() : nullptr,
qkv_biases ? reinterpret_cast<DataType_*>(
const_cast<T*>(qkv_biases.get().data<T>()))
: nullptr,
max_seq_len,
max_blocks_per_seq,
num_heads,
kv_num_heads,
dim_head,
block_size,
bsz,
stream,
use_neox_rotary_style,
rope_3d);
} else if (cache_quant_type_str == "cache_int8") {
bool is_scale_channel_wise = false;
if (cache_k_scale && cache_k_scale.get().dims()[0] == dim_head * kv_num_heads) {
is_scale_channel_wise = true;
}
if (is_scale_channel_wise) {
append_decode_cache_int8_rope<DataType_, QKV_TYPE, true>(
reinterpret_cast<const QKV_TYPE*>(qkv_ptr),
key_cache_out->data<uint8_t>(),
value_cache_out->data<uint8_t>(),
reinterpret_cast<DataType_*>(qkv_out->data<T>()),
block_tables.data<int>(),
batch_id_per_token.data<int>(),
cu_seqlens_q.data<int>(),
seq_lens.data<int>(),
seq_lens_encoder.data<int>(),
cos_emb,
sin_emb,
qkv_out_scales ? qkv_out_scales.get().data<float>() : nullptr,
qkv_biases ? reinterpret_cast<DataType_*>(
const_cast<T*>(qkv_biases.get().data<T>()))
: nullptr,
cache_k_scale ? reinterpret_cast<DataType_*>(
const_cast<T*>(cache_k_scale.get().data<T>()))
: nullptr,
cache_v_scale ? reinterpret_cast<DataType_*>(
const_cast<T*>(cache_v_scale.get().data<T>()))
: nullptr,
max_seq_len,
max_blocks_per_seq,
num_heads,
kv_num_heads,
dim_head,
block_size,
bsz,
stream,
use_neox_rotary_style,
rope_3d);
} else {
append_decode_cache_int8_rope<DataType_, QKV_TYPE, false>(
reinterpret_cast<const QKV_TYPE*>(qkv_ptr),
reinterpret_cast<DataType_*>(key_cache_out->data<T>()),
reinterpret_cast<DataType_*>(value_cache_out->data<T>()),
key_cache_out->data<uint8_t>(),
value_cache_out->data<uint8_t>(),
reinterpret_cast<DataType_*>(qkv_out->data<T>()),
block_tables.data<int>(),
batch_id_per_token.data<int>(),
cu_seqlens_q.data<int>(),
seq_lens.data<int>(),
seq_lens_encoder.data<int>(),
@@ -573,6 +546,12 @@ void DecoderWriteCacheWithRoPEKernel(
qkv_biases ? reinterpret_cast<DataType_*>(
const_cast<T*>(qkv_biases.get().data<T>()))
: nullptr,
cache_k_scale ? reinterpret_cast<DataType_*>(
const_cast<T*>(cache_k_scale.get().data<T>()))
: nullptr,
cache_v_scale ? reinterpret_cast<DataType_*>(
const_cast<T*>(cache_v_scale.get().data<T>()))
: nullptr,
max_seq_len,
max_blocks_per_seq,
num_heads,
@@ -582,247 +561,84 @@ void DecoderWriteCacheWithRoPEKernel(
bsz,
stream,
use_neox_rotary_style,
rope_3d,
q_norm_weight ? q_norm_weight.get().data<float>() : nullptr,
k_norm_weight ? k_norm_weight.get().data<float>() : nullptr,
rms_norm_eps);
} else if (cache_quant_type_str == "block_wise_fp8") {
constexpr int num_warps = 4;
const int all_warps =
((num_heads + 2 * kv_num_heads) + num_warps - 1) / num_warps * num_warps;
dim3 grids(bsz, all_warps / num_warps);
append_decode_cache_int8_rope_qk_norm_kernel<DataType_, 4, 0, 128, false, true>
<<<grids, num_warps * 32, 0, stream>>>(
reinterpret_cast<const DataType_*>(qkv_ptr),
key_cache_out->data<uint8_t>(),
value_cache_out->data<uint8_t>(),
reinterpret_cast<DataType_*>(qkv_out->data<T>()),
block_tables.data<int>(),
cu_seqlens_q.data<int>(),
seq_lens.data<int>(),
seq_lens_encoder.data<int>(),
cos_emb,
sin_emb,
const_cast<DataType_*>(reinterpret_cast<const DataType_*>(cache_k_scale.get().data<T>())),
const_cast<DataType_*>(reinterpret_cast<const DataType_*>((cache_v_scale.get().data<T>()))),
q_norm_weight.get().data<float>(),
k_norm_weight.get().data<float>(),
max_seq_len,
max_blocks_per_seq,
num_heads,
block_size,
127.0f,
-127.0f,
kv_num_heads,
rope_3d,
rms_norm_eps);
} else {
PD_THROW(
"append_decode_cache_rope_qk_norm just supports cache_quant_type none/block_wise_fp8");
rope_3d);
}
} else if (cache_quant_type_str == "cache_fp8") {
append_decode_cache_int8_rope<DataType_, QKV_TYPE, false, true>(
reinterpret_cast<const QKV_TYPE*>(qkv_ptr),
key_cache_out->data<uint8_t>(),
value_cache_out->data<uint8_t>(),
reinterpret_cast<DataType_*>(qkv_out->data<T>()),
block_tables.data<int>(),
batch_id_per_token.data<int>(),
cu_seqlens_q.data<int>(),
seq_lens.data<int>(),
seq_lens_encoder.data<int>(),
cos_emb,
sin_emb,
qkv_out_scales ? qkv_out_scales.get().data<float>() : nullptr,
qkv_biases ? reinterpret_cast<DataType_*>(
const_cast<T*>(qkv_biases.get().data<T>()))
: nullptr,
cache_k_scale ? reinterpret_cast<DataType_*>(
const_cast<T*>(cache_k_scale.get().data<T>()))
: nullptr,
cache_v_scale ? reinterpret_cast<DataType_*>(
const_cast<T*>(cache_v_scale.get().data<T>()))
: nullptr,
max_seq_len,
max_blocks_per_seq,
num_heads,
kv_num_heads,
dim_head,
block_size,
bsz,
stream,
use_neox_rotary_style,
rope_3d);
} else if (cache_quant_type_str == "cache_int4_zp") {
append_decode_cache_int4_rope(
reinterpret_cast<const QKV_TYPE*>(qkv_ptr),
key_cache_out->data<uint8_t>(),
value_cache_out->data<uint8_t>(),
reinterpret_cast<DataType_*>(const_cast<T*>(qkv_out->data<T>())),
block_tables.data<int>(),
batch_id_per_token.data<int>(),
cu_seqlens_q.data<int>(),
seq_lens.data<int>(),
seq_lens_encoder.data<int>(),
cos_emb,
sin_emb,
qkv_out_scales ? qkv_out_scales.get().data<float>() : nullptr,
qkv_biases ? reinterpret_cast<DataType_*>(
const_cast<T*>(qkv_biases.get().data<T>()))
: nullptr,
cache_k_scale ? reinterpret_cast<DataType_*>(
const_cast<T*>(cache_k_scale.get().data<T>()))
: nullptr,
cache_v_scale ? reinterpret_cast<DataType_*>(
const_cast<T*>(cache_v_scale.get().data<T>()))
: nullptr,
cache_k_zp ? reinterpret_cast<DataType_*>(
const_cast<T*>(cache_k_zp.get().data<T>()))
: nullptr,
cache_v_zp ? reinterpret_cast<DataType_*>(
const_cast<T*>(cache_v_zp.get().data<T>()))
: nullptr,
max_seq_len,
max_blocks_per_seq,
num_heads,
kv_num_heads,
dim_head,
block_size,
bsz,
stream,
use_neox_rotary_style,
rope_3d);
} else {
if (cache_quant_type_str == "none") {
append_decode_cache_rope(
reinterpret_cast<const QKV_TYPE*>(qkv_ptr),
reinterpret_cast<DataType_*>(key_cache_out->data<T>()),
reinterpret_cast<DataType_*>(value_cache_out->data<T>()),
reinterpret_cast<DataType_*>(qkv_out->data<T>()),
block_tables.data<int>(),
cu_seqlens_q.data<int>(),
seq_lens.data<int>(),
seq_lens_encoder.data<int>(),
cos_emb,
sin_emb,
qkv_out_scales ? qkv_out_scales.get().data<float>() : nullptr,
qkv_biases ? reinterpret_cast<DataType_*>(
const_cast<T*>(qkv_biases.get().data<T>()))
: nullptr,
max_seq_len,
max_blocks_per_seq,
num_heads,
kv_num_heads,
dim_head,
rotary_dim,
block_size,
bsz,
stream,
use_neox_rotary_style,
rope_3d);
} else if (cache_quant_type_str == "cache_int8") {
bool is_scale_channel_wise = false;
if (cache_k_scale && cache_k_scale.get().dims()[0] == dim_head * kv_num_heads) {
is_scale_channel_wise = true;
}
if (is_scale_channel_wise) {
append_decode_cache_int8_rope<DataType_, QKV_TYPE, true>(
reinterpret_cast<const QKV_TYPE*>(qkv_ptr),
key_cache_out->data<uint8_t>(),
value_cache_out->data<uint8_t>(),
reinterpret_cast<DataType_*>(qkv_out->data<T>()),
block_tables.data<int>(),
cu_seqlens_q.data<int>(),
seq_lens.data<int>(),
seq_lens_encoder.data<int>(),
cos_emb,
sin_emb,
qkv_out_scales ? qkv_out_scales.get().data<float>() : nullptr,
qkv_biases ? reinterpret_cast<DataType_*>(
const_cast<T*>(qkv_biases.get().data<T>()))
: nullptr,
cache_k_scale ? reinterpret_cast<DataType_*>(
const_cast<T*>(cache_k_scale.get().data<T>()))
: nullptr,
cache_v_scale ? reinterpret_cast<DataType_*>(
const_cast<T*>(cache_v_scale.get().data<T>()))
: nullptr,
max_seq_len,
max_blocks_per_seq,
num_heads,
kv_num_heads,
dim_head,
block_size,
bsz,
stream,
use_neox_rotary_style,
rope_3d);
} else {
append_decode_cache_int8_rope<DataType_, QKV_TYPE, false>(
reinterpret_cast<const QKV_TYPE*>(qkv_ptr),
key_cache_out->data<uint8_t>(),
value_cache_out->data<uint8_t>(),
reinterpret_cast<DataType_*>(qkv_out->data<T>()),
block_tables.data<int>(),
cu_seqlens_q.data<int>(),
seq_lens.data<int>(),
seq_lens_encoder.data<int>(),
cos_emb,
sin_emb,
qkv_out_scales ? qkv_out_scales.get().data<float>() : nullptr,
qkv_biases ? reinterpret_cast<DataType_*>(
const_cast<T*>(qkv_biases.get().data<T>()))
: nullptr,
cache_k_scale ? reinterpret_cast<DataType_*>(
const_cast<T*>(cache_k_scale.get().data<T>()))
: nullptr,
cache_v_scale ? reinterpret_cast<DataType_*>(
const_cast<T*>(cache_v_scale.get().data<T>()))
: nullptr,
max_seq_len,
max_blocks_per_seq,
num_heads,
kv_num_heads,
dim_head,
block_size,
bsz,
stream,
use_neox_rotary_style,
rope_3d);
}
} else if (cache_quant_type_str == "cache_fp8") {
append_decode_cache_int8_rope<DataType_, QKV_TYPE, false, true>(
reinterpret_cast<const QKV_TYPE*>(qkv_ptr),
key_cache_out->data<uint8_t>(),
value_cache_out->data<uint8_t>(),
reinterpret_cast<DataType_*>(qkv_out->data<T>()),
block_tables.data<int>(),
cu_seqlens_q.data<int>(),
seq_lens.data<int>(),
seq_lens_encoder.data<int>(),
cos_emb,
sin_emb,
qkv_out_scales ? qkv_out_scales.get().data<float>() : nullptr,
qkv_biases ? reinterpret_cast<DataType_*>(
const_cast<T*>(qkv_biases.get().data<T>()))
: nullptr,
cache_k_scale ? reinterpret_cast<DataType_*>(
const_cast<T*>(cache_k_scale.get().data<T>()))
: nullptr,
cache_v_scale ? reinterpret_cast<DataType_*>(
const_cast<T*>(cache_v_scale.get().data<T>()))
: nullptr,
max_seq_len,
max_blocks_per_seq,
num_heads,
kv_num_heads,
dim_head,
block_size,
bsz,
stream,
use_neox_rotary_style,
rope_3d);
} else if (cache_quant_type_str == "block_wise_fp8") {
constexpr int num_warps = 4;
const int all_warps =
((num_heads + 2 * kv_num_heads) + num_warps - 1) / num_warps * num_warps;
dim3 grids(bsz, all_warps / num_warps);
append_decode_cache_int8_rope_qk_norm_kernel<DataType_, 4, 0, 128, false, true>
<<<grids, num_warps * 32, 0, stream>>>(
reinterpret_cast<const DataType_*>(qkv_ptr),
key_cache_out->data<uint8_t>(),
value_cache_out->data<uint8_t>(),
reinterpret_cast<DataType_*>(qkv_out->data<T>()),
block_tables.data<int>(),
cu_seqlens_q.data<int>(),
seq_lens.data<int>(),
seq_lens_encoder.data<int>(),
cos_emb,
sin_emb,
const_cast<DataType_*>(reinterpret_cast<const DataType_*>(cache_k_scale.get().data<T>())),
const_cast<DataType_*>(reinterpret_cast<const DataType_*>((cache_v_scale.get().data<T>()))),
nullptr,
nullptr,
max_seq_len,
max_blocks_per_seq,
num_heads,
block_size,
127.0f,
-127.0f,
kv_num_heads,
rope_3d,
rms_norm_eps);
} else if (cache_quant_type_str == "cache_int4_zp") {
append_decode_cache_int4_rope(
reinterpret_cast<const QKV_TYPE*>(qkv_ptr),
key_cache_out->data<uint8_t>(),
value_cache_out->data<uint8_t>(),
reinterpret_cast<DataType_*>(const_cast<T*>(qkv_out->data<T>())),
block_tables.data<int>(),
cu_seqlens_q.data<int>(),
seq_lens.data<int>(),
seq_lens_encoder.data<int>(),
cos_emb,
sin_emb,
qkv_out_scales ? qkv_out_scales.get().data<float>() : nullptr,
qkv_biases ? reinterpret_cast<DataType_*>(
const_cast<T*>(qkv_biases.get().data<T>()))
: nullptr,
cache_k_scale ? reinterpret_cast<DataType_*>(
const_cast<T*>(cache_k_scale.get().data<T>()))
: nullptr,
cache_v_scale ? reinterpret_cast<DataType_*>(
const_cast<T*>(cache_v_scale.get().data<T>()))
: nullptr,
cache_k_zp ? reinterpret_cast<DataType_*>(
const_cast<T*>(cache_k_zp.get().data<T>()))
: nullptr,
cache_v_zp ? reinterpret_cast<DataType_*>(
const_cast<T*>(cache_v_zp.get().data<T>()))
: nullptr,
max_seq_len,
max_blocks_per_seq,
num_heads,
kv_num_heads,
dim_head,
block_size,
bsz,
stream,
use_neox_rotary_style,
rope_3d);
} else {
PD_THROW(
"cache_quant_type_str should be one of [none, cache_int8, cache_fp8 "
"cache_int4_zp]");
}
PD_THROW(
"cache_quant_type_str should be one of [none, cache_int8, cache_fp8 "
"cache_int4_zp]");
}
}
@@ -834,6 +650,7 @@ template void DecoderWriteCacheWithRoPEKernel<paddle::bfloat16, int>(
// kv_num_heads, head_dim] if GQA)
const paddle::Tensor& seq_lens,
const paddle::Tensor& seq_lens_encoder,
const paddle::Tensor& batch_id_per_token,
const paddle::Tensor& cu_seqlens_q,
const paddle::Tensor& block_tables,
const paddle::optional<paddle::Tensor>& rotary_embs,
@@ -850,10 +667,7 @@ template void DecoderWriteCacheWithRoPEKernel<paddle::bfloat16, int>(
cudaStream_t& stream,
paddle::Tensor* qkv_out,
paddle::Tensor* key_cache_out,
paddle::Tensor* value_cache_out,
const paddle::optional<paddle::Tensor>& q_norm_weight,
const paddle::optional<paddle::Tensor>& k_norm_weight,
const float rms_norm_eps);
paddle::Tensor* value_cache_out);
template void
DecoderWriteCacheWithRoPEKernel<paddle::bfloat16, paddle::bfloat16>(
@@ -863,6 +677,7 @@ DecoderWriteCacheWithRoPEKernel<paddle::bfloat16, paddle::bfloat16>(
// kv_num_heads, head_dim] if GQA)
const paddle::Tensor& seq_lens,
const paddle::Tensor& seq_lens_encoder,
const paddle::Tensor& batch_id_per_token,
const paddle::Tensor& cu_seqlens_q,
const paddle::Tensor& block_tables,
const paddle::optional<paddle::Tensor>& rotary_embs,
@@ -879,10 +694,7 @@ DecoderWriteCacheWithRoPEKernel<paddle::bfloat16, paddle::bfloat16>(
cudaStream_t& stream,
paddle::Tensor* qkv_out,
paddle::Tensor* key_cache_out,
paddle::Tensor* value_cache_out,
const paddle::optional<paddle::Tensor>& q_norm_weight,
const paddle::optional<paddle::Tensor>& k_norm_weight,
const float rms_norm_eps);
paddle::Tensor* value_cache_out);
template void DecoderWriteCacheWithRoPEKernel<paddle::float16, int>(
const AppendAttnMetaData& meta_data,
@@ -891,6 +703,7 @@ template void DecoderWriteCacheWithRoPEKernel<paddle::float16, int>(
// kv_num_heads, head_dim] if GQA)
const paddle::Tensor& seq_lens,
const paddle::Tensor& seq_lens_encoder,
const paddle::Tensor& batch_id_per_token,
const paddle::Tensor& cu_seqlens_q,
const paddle::Tensor& block_tables,
const paddle::optional<paddle::Tensor>& rotary_embs,
@@ -907,10 +720,7 @@ template void DecoderWriteCacheWithRoPEKernel<paddle::float16, int>(
cudaStream_t& stream,
paddle::Tensor* qkv_out,
paddle::Tensor* key_cache_out,
paddle::Tensor* value_cache_out,
const paddle::optional<paddle::Tensor>& q_norm_weight,
const paddle::optional<paddle::Tensor>& k_norm_weight,
const float rms_norm_eps);
paddle::Tensor* value_cache_out);
template void DecoderWriteCacheWithRoPEKernel<paddle::float16, paddle::float16>(
const AppendAttnMetaData& meta_data,
@@ -919,6 +729,7 @@ template void DecoderWriteCacheWithRoPEKernel<paddle::float16, paddle::float16>(
// kv_num_heads, head_dim] if GQA)
const paddle::Tensor& seq_lens,
const paddle::Tensor& seq_lens_encoder,
const paddle::Tensor& batch_id_per_token,
const paddle::Tensor& cu_seqlens_q,
const paddle::Tensor& block_tables,
const paddle::optional<paddle::Tensor>& rotary_embs,
@@ -935,7 +746,4 @@ template void DecoderWriteCacheWithRoPEKernel<paddle::float16, paddle::float16>(
cudaStream_t& stream,
paddle::Tensor* qkv_out,
paddle::Tensor* key_cache_out,
paddle::Tensor* value_cache_out,
const paddle::optional<paddle::Tensor>& q_norm_weight,
const paddle::optional<paddle::Tensor>& k_norm_weight,
const float rms_norm_eps);
paddle::Tensor* value_cache_out);

View File

@@ -23,6 +23,7 @@ void DecoderWriteCacheWithRoPEKernel(
// kv_num_heads, head_dim] if GQA)
const paddle::Tensor& seq_lens,
const paddle::Tensor& seq_lens_encoder,
const paddle::Tensor& batch_id_per_token,
const paddle::Tensor& cu_seqlens_q,
const paddle::Tensor& block_tables,
const paddle::optional<paddle::Tensor>& rotary_embs,
@@ -39,6 +40,4 @@ void DecoderWriteCacheWithRoPEKernel(
cudaStream_t& stream,
paddle::Tensor* qkv_out,
paddle::Tensor* key_cache_out,
paddle::Tensor* value_cache_out,
const paddle::optional<paddle::Tensor>& q_norm_weight,
const paddle::optional<paddle::Tensor>& k_norm_weight, const float rms_norm_eps);
paddle::Tensor* value_cache_out);

View File

@@ -46,32 +46,38 @@ void EncoderWriteCacheWithRopeKernel(
cudaStream_t& stream,
paddle::Tensor* qkv_out,
paddle::Tensor* key_cache_out,
paddle::Tensor* value_cache_out,
const paddle::optional<paddle::Tensor>& q_norm_weight,
const paddle::optional<paddle::Tensor>& k_norm_weight,
const float rms_norm_eps) {
paddle::Tensor* value_cache_out) {
auto token_num = meta_data.token_nums;
auto num_heads = meta_data.q_num_heads;
auto kv_num_heads = meta_data.kv_num_heads;
auto head_dim = meta_data.head_dims;
bool is_scale_channel_wise = false;
int rotary_dim = head_dim;
if (cache_k_scale && cache_k_scale.get().dims()[0] == head_dim * kv_num_heads) {
is_scale_channel_wise = true;
}
if (rotary_embs){
rotary_dim = rotary_embs.get().dims()[rotary_embs.get().dims().size()-1] * 2;
if(rotary_dim < head_dim){
if (!use_neox_style || q_norm_weight || k_norm_weight || num_heads == kv_num_heads || is_scale_channel_wise){
PADDLE_THROW(phi::errors::Fatal(
"partial_rotary_factor < 1.0 only supports use_neox_rotary_style=True, q_norm_weight/k_norm_weight) is None, GQA and is_scale_channel_wise=false."));
}
}
}
if (q_norm_weight && k_norm_weight) {
if (num_heads != kv_num_heads && !is_scale_channel_wise && !use_neox_style) {
gqa_rotary_qk_norm_variable(
if (num_heads == kv_num_heads) {
rotary_qk_variable(
qkv_out->data<T>(),
qkv.data<QKV_TYPE>(),
qkv_out_scales ? qkv_out_scales.get().data<float>() : nullptr,
qkv_biases ? qkv_biases.get().data<T>() : nullptr,
rotary_embs.get().data<float>(),
batch_id_per_token.data<int>(),
cu_seqlens_q.data<int>(),
seq_lens_encoder.data<int>(),
seq_lens_decoder.data<int>(),
token_num,
num_heads,
max_seq_len,
rotary_embs.get().dims()[2],
head_dim,
stream,
use_neox_style,
rope_3d);
} else {
if (!is_scale_channel_wise) {
gqa_rotary_qk_variable(
qkv_out->data<T>(),
qkv.data<QKV_TYPE>(),
qkv_out_scales ? qkv_out_scales.get().data<float>() : nullptr,
@@ -89,81 +95,31 @@ void EncoderWriteCacheWithRopeKernel(
head_dim,
stream,
use_neox_style,
rope_3d,
q_norm_weight ? q_norm_weight.get().data<float>() : nullptr,
k_norm_weight ? k_norm_weight.get().data<float>() : nullptr,
rms_norm_eps);
rope_3d);
} else {
PD_THROW(
"gqa_rotary_qk_norm_variable only support gqa mode. channel wise scale and neox style are not supported");
gqa_rotary_qk_quant_variable(
qkv_out->data<T>(),
qkv.data<QKV_TYPE>(),
qkv_out_scales ? qkv_out_scales.get().data<float>() : nullptr,
qkv_biases ? qkv_biases.get().data<T>() : nullptr,
cache_k_scale ? cache_k_scale.get().data<T>() : nullptr,
cache_v_scale ? cache_v_scale.get().data<T>() : nullptr,
rotary_embs.get().data<float>(),
batch_id_per_token.data<int>(),
cu_seqlens_q.data<int>(),
seq_lens_encoder.data<int>(),
seq_lens_decoder.data<int>(),
token_num,
num_heads,
kv_num_heads,
max_seq_len,
rotary_embs.get().dims()[2],
head_dim,
stream,
use_neox_style,
rope_3d);
}
} else {
if (num_heads == kv_num_heads) {
rotary_qk_variable(
qkv_out->data<T>(),
qkv.data<QKV_TYPE>(),
qkv_out_scales ? qkv_out_scales.get().data<float>() : nullptr,
qkv_biases ? qkv_biases.get().data<T>() : nullptr,
rotary_embs.get().data<float>(),
batch_id_per_token.data<int>(),
cu_seqlens_q.data<int>(),
seq_lens_encoder.data<int>(),
seq_lens_decoder.data<int>(),
token_num,
num_heads,
max_seq_len,
rotary_embs.get().dims()[2],
head_dim,
stream,
use_neox_style,
rope_3d);
} else {
if (!is_scale_channel_wise) {
gqa_rotary_qk_variable(
qkv_out->data<T>(),
qkv.data<QKV_TYPE>(),
qkv_out_scales ? qkv_out_scales.get().data<float>() : nullptr,
qkv_biases ? qkv_biases.get().data<T>() : nullptr,
rotary_embs.get().data<float>(),
batch_id_per_token.data<int>(),
cu_seqlens_q.data<int>(),
seq_lens_encoder.data<int>(),
seq_lens_decoder.data<int>(),
token_num,
num_heads,
kv_num_heads,
max_seq_len,
rope_3d ? rotary_embs.get().dims()[3] : rotary_embs.get().dims()[2],
head_dim,
rotary_dim,
stream,
use_neox_style,
rope_3d);
} else {
gqa_rotary_qk_quant_variable(
qkv_out->data<T>(),
qkv.data<QKV_TYPE>(),
qkv_out_scales ? qkv_out_scales.get().data<float>() : nullptr,
qkv_biases ? qkv_biases.get().data<T>() : nullptr,
cache_k_scale ? cache_k_scale.get().data<T>() : nullptr,
cache_v_scale ? cache_v_scale.get().data<T>() : nullptr,
rotary_embs.get().data<float>(),
batch_id_per_token.data<int>(),
cu_seqlens_q.data<int>(),
seq_lens_encoder.data<int>(),
seq_lens_decoder.data<int>(),
token_num,
num_heads,
kv_num_heads,
max_seq_len,
rotary_embs.get().dims()[2],
head_dim,
stream,
use_neox_style,
rope_3d);
}
}
}
const uint32_t block_size = meta_data.block_size;
if (cache_quant_type_str == "none") {
@@ -178,7 +134,7 @@ void EncoderWriteCacheWithRopeKernel(
stream,
key_cache_out,
value_cache_out);
} else if (cache_quant_type_str == "cache_int8" or cache_quant_type_str == "cache_fp8" or cache_quant_type_str == "block_wise_fp8") {
} else if (cache_quant_type_str == "cache_int8" or cache_quant_type_str == "cache_fp8") {
DISPATCH_HEAD_DIM(
head_dim, HEAD_DIM, {DISPATCH_BLOCK_SIZE(block_size, BLOCK_SIZE, {
CascadeAppendWriteCacheKVC8QKV<T, HEAD_DIM, BLOCK_SIZE>(
@@ -198,7 +154,7 @@ void EncoderWriteCacheWithRopeKernel(
num_blocks,
max_seq_len,
is_scale_channel_wise,
cache_quant_type_str,
cache_quant_type_str == "cache_fp8",
stream,
key_cache_out,
value_cache_out);

View File

@@ -11,11 +11,10 @@
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "cute/tensor.hpp"
#include "helper.h"
#include "paddle/extension.h"
#include "paddle/phi/core/memory/memcpy.h"
#include "utils.cuh"
template <int THREADBLOCK_SIZE>
__global__ void
@@ -117,93 +116,6 @@ void GetMaxLen(const paddle::Tensor &seq_lens_tensor,
max_len_tensor.data<int>(), batch_size);
}
template <uint32_t config_size>
__global__ void search_chunk_size_for_mla(
const int *__restrict__ seq_lens_q,
const int *__restrict__ seq_lens_encoder,
const int *__restrict__ seq_lens_decoder,
int *__restrict__ num_blocks_x,
int *__restrict__ res_chunk_size,
const int bsz,
const int set_chunk_size,
const int block_size,
const int sm_cout) {
const uint32_t conf_id = threadIdx.x;
int gridx = 0;
if (set_chunk_size > 0 && conf_id == 0) {
for (uint32_t bid = 0; bid < bsz; bid++) {
int seq_len = seq_lens_q[bid];
int seq_len_encoder = seq_lens_encoder[bid];
int seq_len_decoder = seq_lens_decoder[bid] + seq_len;
if (seq_len == 0 || seq_len_encoder > 0) continue;
int loop_times;
loop_times = cute::ceil_div(seq_len_decoder, set_chunk_size);
gridx += loop_times;
}
*num_blocks_x = gridx;
*res_chunk_size = set_chunk_size;
} else if (conf_id < config_size) {
__shared__ int gridx_shared[config_size];
// chunk_size is a multiple of 64
const int chunk_size = block_size << conf_id;
for (uint32_t bid = 0; bid < bsz; bid++) {
int seq_len = seq_lens_q[bid];
int seq_len_encoder = seq_lens_encoder[bid];
int seq_len_decoder = seq_lens_decoder[bid] + seq_len;
if (seq_len == 0 || seq_len_encoder > 0) continue;
int loop_times;
loop_times = cute::ceil_div(seq_len_decoder, chunk_size);
gridx += loop_times;
}
gridx_shared[conf_id] = gridx;
__syncthreads();
if (threadIdx.x == 0) {
uint32_t res_id = 0;
uint32_t max_last_wave_block = 0;
for (uint32_t i = 1; i < config_size; i++) {
uint32_t last_wave_block = gridx_shared[i] % sm_cout;
if (last_wave_block >= max_last_wave_block) {
res_id = i;
max_last_wave_block = last_wave_block;
}
}
*num_blocks_x = gridx_shared[res_id];
*res_chunk_size = block_size << res_id;
}
}
}
__global__ void split_block_for_mla(const int *__restrict__ seq_lens_q,
const int *__restrict__ seq_lens_encoder,
const int *__restrict__ seq_lens_decoder,
int *__restrict__ batch_ids,
int *__restrict__ tile_ids_per_batch,
const int bsz,
const int chunk_size) {
if (threadIdx.x == 0) {
int index = 0;
for (uint32_t bid = 0; bid < bsz; bid++) {
int seq_len = seq_lens_q[bid];
int seq_len_encoder = seq_lens_encoder[bid];
int seq_len_decoder = seq_lens_decoder[bid] + seq_len;
if (seq_len == 0) continue;
int loop_times;
loop_times = cute::ceil_div(seq_len_decoder, chunk_size);
if (seq_len_encoder > 0) {
loop_times = 0;
}
for (uint32_t tile_id = 0; tile_id < loop_times; tile_id++) {
batch_ids[index] = bid;
tile_ids_per_batch[index++] = tile_id;
}
}
}
}
__global__ void split_q_block(const int *__restrict__ seq_lens_q,
const int *__restrict__ seq_lens_encoder,
int *__restrict__ batch_ids,
@@ -279,38 +191,26 @@ get_max_len_kv_ernel(int *max_seq_lens_out, const int *seq_lens_this_time,
}
}
void GetBlockShapeAndSplitKVBlock(
std::vector<paddle::Tensor> GetBlockShapeAndSplitKVBlock(
const paddle::Tensor &seq_lens_encoder,
const paddle::Tensor &seq_lens_decoder,
const paddle::Tensor &seq_lens_this_time,
paddle::Tensor &decoder_batch_ids, // Inplace
paddle::Tensor &decoder_tile_ids_per_batch, // Inplace
paddle::Tensor &decoder_num_blocks_cpu, // Inplace, Pinned Memory
paddle::Tensor &decoder_num_blocks_device, // Inplace
paddle::Tensor &decoder_chunk_size_device, // Inplace
paddle::Tensor &max_len_tensor_cpu, // Inplace, CPU
paddle::Tensor &encoder_batch_ids, // Inplace
paddle::Tensor &encoder_tile_ids_per_batch, // Inplace
paddle::Tensor &encoder_num_blocks_x_cpu, // Inplace, CPU
paddle::Tensor &kv_batch_ids, // Inplace
paddle::Tensor &kv_tile_ids_per_batch, // Inplace
paddle::Tensor &kv_num_blocks_x_cpu, // Inplace, CPU
paddle::Tensor &max_len_kv_cpu, // Inplace, CPU
const int encoder_block_shape_q,
const int decoder_block_shape_q,
const int group_size,
const int block_size,
const int decoder_step_token_num)
{
const int encoder_block_shape_q, const int decoder_block_shape_q,
const int group_size, const int block_size,
const int decoder_step_token_num) {
auto stream = seq_lens_encoder.stream();
int bsz = seq_lens_this_time.shape()[0];
paddle::Tensor max_len_tensor_gpu = GetEmptyTensor({max_len_tensor_cpu.shape()[0]}, paddle::DataType::INT32, seq_lens_this_time.place());
auto max_len_tensor =
GetEmptyTensor({8}, paddle::DataType::INT32, seq_lens_encoder.place());
GetMaxLen(seq_lens_decoder, seq_lens_this_time, seq_lens_encoder,
max_len_tensor_gpu, bsz);
max_len_tensor_cpu.copy_(max_len_tensor_gpu, max_len_tensor_cpu.place(), false);
max_len_tensor, bsz);
auto max_len_cpu_ptr = max_len_tensor_cpu.data<int>();
// max_len_this_time, max_enc_len_this_time, max_dec_len_this_time,
// max_enc_dec_len_this_time, max_just_dec_len_this_time,
// max_just_dec_merged_len_this_time, max_system_len,
// max_just_dec_len_without_system
auto max_len_cpu = max_len_tensor.copy_to(paddle::CPUPlace(), false);
auto max_len_cpu_ptr = max_len_cpu.data<int>();
int max_len_this_time = max_len_cpu_ptr[0];
int max_enc_len_this_time = max_len_cpu_ptr[1];
int max_dec_len_this_time = max_len_cpu_ptr[2];
@@ -320,120 +220,34 @@ void GetBlockShapeAndSplitKVBlock(
int max_system_len = max_len_cpu_ptr[6];
int max_just_dec_len_without_system = max_len_cpu_ptr[7];
paddle::Tensor encoder_batch_ids;
paddle::Tensor encoder_tile_ids_per_batch;
paddle::Tensor encoder_num_blocks_x_cpu; /*cpu*/
paddle::Tensor kv_batch_ids;
paddle::Tensor kv_tile_ids_per_batch;
paddle::Tensor kv_num_blocks_x_cpu; /*cpu*/
paddle::Tensor decoder_batch_ids;
paddle::Tensor decoder_tile_ids_per_batch;
paddle::Tensor decoder_num_blocks_x_cpu; /*cpu*/
paddle::Tensor max_len_kv_cpu; /*cpu*/
auto max_len_kv =
GetEmptyTensor({1}, paddle::DataType::INT32, seq_lens_decoder.place());
get_max_len_kv_ernel<128><<<1, 128, 0, stream>>>(
max_len_kv.data<int>(), seq_lens_this_time.data<int>(),
seq_lens_decoder.data<int>(), bsz);
max_len_kv_cpu = max_len_kv.copy_to(paddle::CPUPlace(), false);
max_len_kv_cpu.copy_(max_len_kv, max_len_kv_cpu.place(), false);
// decoder
if (max_dec_len_this_time > 0) {
const bool mla_use_tensorcore = GetMlaUseTensorcore();
if (mla_use_tensorcore && group_size <= 64) {
const int set_chunk_size = get_mla_dec_chunk_size(bsz);
PADDLE_ENFORCE_GPU_SUCCESS(cudaMemsetAsync(
decoder_chunk_size_device.data<int>(), 64, sizeof(int32_t), stream));
PADDLE_ENFORCE_GPU_SUCCESS(cudaMemsetAsync(
decoder_num_blocks_device.data<int>(), 0, sizeof(int32_t), stream));
int device;
cudaGetDevice(&device);
int sm_cout;
cudaDeviceGetAttribute(&sm_cout, cudaDevAttrMultiProcessorCount, device);
constexpr int config_size =
12; // search space for chunk size:[64, 128, 256, ... 131072]
search_chunk_size_for_mla<config_size>
<<<1, 32, 0, stream>>>(seq_lens_this_time.data<int>(),
seq_lens_encoder.data<int>(),
seq_lens_decoder.data<int>(),
decoder_num_blocks_device.data<int>(),
decoder_chunk_size_device.data<int>(),
bsz,
set_chunk_size,
block_size,
sm_cout);
decoder_num_blocks_cpu.copy_(
decoder_num_blocks_device, decoder_num_blocks_cpu.place(), false);
auto decoder_chunk_size_cpu =
decoder_chunk_size_device.copy_to(paddle::CPUPlace(), false);
const int chunk_size = decoder_chunk_size_cpu.data<int>()[0];
// NOTE: (changwenbin) When using auto_chunk,
// decode_max_tile_size must take into account the maximum case, where * 1024 can cover 128K.
// const uint32_t decoder_batch_shape = seq_lens_decoder.dims()[0] * 1024;
const uint32_t decoder_max_tile_size_per_bs_q =
div_up((decoder_step_token_num * group_size), decoder_block_shape_q);
const uint32_t decoder_batch_shape =
bsz * 1024 * decoder_max_tile_size_per_bs_q;
PADDLE_ENFORCE_GPU_SUCCESS(
cudaMemsetAsync(decoder_batch_ids.data<int>(),
0,
decoder_batch_shape * sizeof(int32_t),
stream));
PADDLE_ENFORCE_GPU_SUCCESS(
cudaMemsetAsync(decoder_tile_ids_per_batch.data<int>(),
0,
decoder_batch_shape * sizeof(int32_t),
stream));
split_block_for_mla<<<1, 32, 0, stream>>>(
seq_lens_this_time.data<int>(),
seq_lens_encoder.data<int>(),
seq_lens_decoder.data<int>(),
decoder_batch_ids.data<int>(),
decoder_tile_ids_per_batch.data<int>(),
bsz,
chunk_size);
} else {
// Note:(changwenbin)In order to adapt to cudagraph, the maximum value should be taken here
const uint32_t decoder_max_tile_size_per_bs_q = div_up((decoder_step_token_num * group_size), decoder_block_shape_q);
const uint32_t decoder_batch_shape = bsz * 1024 * decoder_max_tile_size_per_bs_q;
PADDLE_ENFORCE_GPU_SUCCESS(cudaMemsetAsync(decoder_batch_ids.data<int>(), 0, decoder_batch_shape * sizeof(int32_t), stream));
PADDLE_ENFORCE_GPU_SUCCESS(cudaMemsetAsync(decoder_tile_ids_per_batch.data<int>(), 0, decoder_batch_shape * sizeof(int32_t), stream));
PADDLE_ENFORCE_GPU_SUCCESS(cudaMemsetAsync(decoder_num_blocks_device.data<int>(), 0, sizeof(int32_t), stream));
split_q_block<<<1, 32, 0, stream>>>(
seq_lens_this_time.data<int>(),
seq_lens_encoder.data<int>(),
decoder_batch_ids.data<int>(),
decoder_tile_ids_per_batch.data<int>(),
decoder_num_blocks_device.data<int>(),
bsz,
decoder_block_shape_q,
group_size);
decoder_num_blocks_cpu.copy_(
decoder_num_blocks_device, decoder_num_blocks_cpu.place(), false);
PADDLE_ENFORCE_GPU_SUCCESS(cudaMemsetAsync(
decoder_chunk_size_device.data<int>(), 64, sizeof(int32_t), stream));
}
} else {
PADDLE_ENFORCE_GPU_SUCCESS(cudaMemsetAsync(
decoder_chunk_size_device.data<int>(), 64, sizeof(int32_t), stream));
PADDLE_ENFORCE_GPU_SUCCESS(cudaMemsetAsync(
decoder_num_blocks_device.data<int>(), 0, sizeof(int32_t), stream));
decoder_num_blocks_cpu.copy_(
decoder_num_blocks_device, decoder_num_blocks_cpu.place(), false);
}
// encoder
if (max_enc_len_this_time > 0) {
const uint32_t max_tile_size_per_bs_kv = div_up(max_enc_dec_len_this_time, block_size);
const uint32_t kv_batch_shape = bsz * max_tile_size_per_bs_kv;
PADDLE_ENFORCE_GPU_SUCCESS(cudaMemsetAsync(kv_batch_ids.data<int>(), 0, kv_batch_shape * sizeof(int32_t), stream));
PADDLE_ENFORCE_GPU_SUCCESS(cudaMemsetAsync(kv_tile_ids_per_batch.data<int>(), 0, kv_batch_shape * sizeof(int32_t), stream));
const uint32_t max_tile_size_per_bs_kv =
div_up(max_enc_dec_len_this_time, block_size);
kv_batch_ids =
GetEmptyTensor({bsz * max_tile_size_per_bs_kv}, paddle::DataType::INT32,
seq_lens_encoder.place());
kv_tile_ids_per_batch =
GetEmptyTensor({bsz * max_tile_size_per_bs_kv}, paddle::DataType::INT32,
seq_lens_encoder.place());
auto kv_num_blocks_x =
GetEmptyTensor({1}, paddle::DataType::INT32, seq_lens_encoder.place());
@@ -444,12 +258,16 @@ void GetBlockShapeAndSplitKVBlock(
kv_tile_ids_per_batch.data<int>(), kv_num_blocks_x.data<int>(), bsz,
block_size, block_size);
kv_num_blocks_x_cpu.copy_(kv_num_blocks_x, kv_num_blocks_x_cpu.place(), false);
// Clear buffer
const uint32_t encoder_max_tile_size_per_bs_q = div_up((max_enc_dec_len_this_time * group_size), encoder_block_shape_q);
const uint32_t encoder_batch_shape = bsz * encoder_max_tile_size_per_bs_q;
PADDLE_ENFORCE_GPU_SUCCESS(cudaMemsetAsync(encoder_batch_ids.data<int>(), 0, encoder_batch_shape * sizeof(int32_t), stream));
PADDLE_ENFORCE_GPU_SUCCESS(cudaMemsetAsync(encoder_tile_ids_per_batch.data<int>(), 0, encoder_batch_shape * sizeof(int32_t), stream));
kv_num_blocks_x_cpu = kv_num_blocks_x.copy_to(paddle::CPUPlace(), false);
const uint32_t encoder_max_tile_size_per_bs_q =
div_up((max_enc_dec_len_this_time * group_size), encoder_block_shape_q);
encoder_batch_ids =
GetEmptyTensor({bsz * encoder_max_tile_size_per_bs_q},
paddle::DataType::INT32, seq_lens_encoder.place());
encoder_tile_ids_per_batch =
GetEmptyTensor({bsz * encoder_max_tile_size_per_bs_q},
paddle::DataType::INT32, seq_lens_encoder.place());
auto encoder_num_blocks_x =
GetEmptyTensor({1}, paddle::DataType::INT32, seq_lens_encoder.place());
split_q_block<<<1, 32, 0, stream>>>(seq_lens_encoder.data<int>(), nullptr,
@@ -457,38 +275,108 @@ void GetBlockShapeAndSplitKVBlock(
encoder_tile_ids_per_batch.data<int>(),
encoder_num_blocks_x.data<int>(), bsz,
encoder_block_shape_q, group_size);
encoder_num_blocks_x_cpu.copy_(encoder_num_blocks_x, encoder_num_blocks_x_cpu.place(), false);
encoder_num_blocks_x_cpu =
encoder_num_blocks_x.copy_to(paddle::CPUPlace(), false);
} else {
encoder_batch_ids =
GetEmptyTensor({0}, paddle::DataType::INT32, seq_lens_encoder.place());
encoder_tile_ids_per_batch =
GetEmptyTensor({0}, paddle::DataType::INT32, seq_lens_encoder.place());
encoder_num_blocks_x_cpu =
GetEmptyTensor({0}, paddle::DataType::INT32, paddle::CPUPlace());
kv_batch_ids =
GetEmptyTensor({0}, paddle::DataType::INT32, seq_lens_encoder.place());
kv_tile_ids_per_batch =
GetEmptyTensor({0}, paddle::DataType::INT32, seq_lens_encoder.place());
kv_num_blocks_x_cpu =
GetEmptyTensor({0}, paddle::DataType::INT32, seq_lens_encoder.place());
}
if (max_just_dec_len_this_time > 0) {
const uint32_t decoder_max_tile_size_per_bs_q =
div_up((decoder_step_token_num * group_size), decoder_block_shape_q);
decoder_batch_ids =
GetEmptyTensor({bsz * decoder_max_tile_size_per_bs_q},
paddle::DataType::INT32, seq_lens_encoder.place());
decoder_tile_ids_per_batch =
GetEmptyTensor({bsz * decoder_max_tile_size_per_bs_q},
paddle::DataType::INT32, seq_lens_encoder.place());
auto decoder_num_blocks_x =
GetEmptyTensor({1}, paddle::DataType::INT32, seq_lens_encoder.place());
split_q_block<<<1, 32, 0, stream>>>(
seq_lens_this_time.data<int>(), seq_lens_encoder.data<int>(),
decoder_batch_ids.data<int>(), decoder_tile_ids_per_batch.data<int>(),
decoder_num_blocks_x.data<int>(), bsz, decoder_block_shape_q,
group_size);
decoder_num_blocks_x_cpu =
decoder_num_blocks_x.copy_to(paddle::CPUPlace(), false);
} else {
decoder_batch_ids =
GetEmptyTensor({0}, paddle::DataType::INT32, seq_lens_encoder.place());
decoder_tile_ids_per_batch =
GetEmptyTensor({0}, paddle::DataType::INT32, seq_lens_encoder.place());
decoder_num_blocks_x_cpu =
GetEmptyTensor({0}, paddle::DataType::INT32, paddle::CPUPlace());
}
return {encoder_batch_ids,
encoder_tile_ids_per_batch,
encoder_num_blocks_x_cpu, /*cpu*/
kv_batch_ids,
kv_tile_ids_per_batch,
kv_num_blocks_x_cpu, /*cpu*/
decoder_batch_ids,
decoder_tile_ids_per_batch,
decoder_num_blocks_x_cpu, /*cpu*/
max_len_kv_cpu /*cpu*/,
max_len_cpu};
}
std::vector<paddle::DataType> GetBlockShapeAndSplitKVBlockInferDtype(
const paddle::DataType &seq_lens_encoder_dtype,
const paddle::DataType &seq_lens_decoder_dtype,
const paddle::DataType &seq_lens_this_time_dtype) {
return {
paddle::DataType::INT32, paddle::DataType::INT32, paddle::DataType::INT32,
paddle::DataType::INT32, paddle::DataType::INT32, paddle::DataType::INT32,
paddle::DataType::INT32, paddle::DataType::INT32, paddle::DataType::INT32,
paddle::DataType::INT32, paddle::DataType::INT32};
}
std::vector<std::vector<int64_t>> GetBlockShapeAndSplitKVBlockInferShape(
const std::vector<int64_t> &seq_lens_encoder_shape,
const std::vector<int64_t> &seq_lens_decoder_shape,
const std::vector<int64_t> &seq_lens_this_time_shape) {
std::vector<int64_t> dynamic_shape = {-1};
return {dynamic_shape,
dynamic_shape,
{1},
dynamic_shape,
dynamic_shape,
{1},
dynamic_shape,
dynamic_shape,
{1},
{1},
{8}};
}
PD_BUILD_STATIC_OP(get_block_shape_and_split_kv_block)
.Inputs({
"seq_lens_encoder",
"seq_lens_decoder",
"seq_lens_this_time",
"decoder_batch_ids",
"decoder_tile_ids_per_batch",
"decoder_num_blocks_cpu",
"decoder_num_blocks_device",
"decoder_chunk_size_device",
"max_len_tensor_cpu",
"encoder_batch_ids",
"encoder_tile_ids_per_batch",
"encoder_num_blocks_x_cpu",
"kv_batch_ids",
"kv_tile_ids_per_batch",
"kv_num_blocks_x_cpu",
"max_len_kv_cpu"
})
.Outputs({
})
.Attrs({
"encoder_block_shape_q: int",
"decoder_block_shape_q: int",
"group_size: int",
"block_size: int",
"decoder_step_token_num: int"
})
.SetKernelFn(PD_KERNEL(GetBlockShapeAndSplitKVBlock));
.Inputs({"seq_lens_encoder", "seq_lens_decoder", "seq_lens_this_time"})
.Outputs({paddle::Optional("encoder_batch_ids"),
paddle::Optional("encoder_tile_ids_per_batch"),
paddle::Optional("encoder_num_blocks"),
paddle::Optional("kv_batch_ids"),
paddle::Optional("kv_tile_ids_per_batch"),
paddle::Optional("kv_num_blocks"),
paddle::Optional("decoder_batch_ids"),
paddle::Optional("decoder_tile_ids_per_batch"),
paddle::Optional("decoder_num_blocks"),
paddle::Optional("max_len_kv"), "set_max_lengths"})
.Attrs({"encoder_block_shape_q: int", "decoder_block_shape_q: int",
"group_size: int", "block_size: int",
"decoder_step_token_num: int"})
.SetKernelFn(PD_KERNEL(GetBlockShapeAndSplitKVBlock))
.SetInferShapeFn(PD_INFER_SHAPE(GetBlockShapeAndSplitKVBlockInferShape))
.SetInferDtypeFn(PD_INFER_DTYPE(GetBlockShapeAndSplitKVBlockInferDtype));

View File

@@ -37,8 +37,7 @@ __global__ void GQAVariableLengthRotarySplitKernel(
const int q_num_head,
const int kv_num_head,
const int seq_len,
const int last_dim,
const bool rope_3d) {
const int last_dim) {
using LoadT = AlignedVector<T, VecSize>;
constexpr int HalfVecSize = VecSize / 2;
using LoadEmbT = AlignedVector<float, HalfVecSize>;
@@ -63,7 +62,6 @@ __global__ void GQAVariableLengthRotarySplitKernel(
const int kv_write_idx = cu_seqlens_k[ori_bi] + ori_seq_id;
const int64_t emb_idx = ori_seq_id * half_lastdim + h_bias / 2;
int64_t new_emb_idx = rope_3d ? emb_idx + ori_bi * last_dim * seq_len : emb_idx;
const int64_t base_idx =
token_idx * (q_num_head + 2 * kv_num_head) * last_dim + hi * last_dim +
h_bias;
@@ -82,8 +80,8 @@ __global__ void GQAVariableLengthRotarySplitKernel(
Load<T, VecSize>(&qkv[base_idx], &src_vec);
// do rope
if (hi < q_num_head + kv_num_head) {
Load<float, HalfVecSize>(&cos_emb[new_emb_idx], &cos_emb_vec);
Load<float, HalfVecSize>(&sin_emb[new_emb_idx], &sin_emb_vec);
Load<float, HalfVecSize>(&cos_emb[emb_idx], &cos_emb_vec);
Load<float, HalfVecSize>(&sin_emb[emb_idx], &sin_emb_vec);
#pragma unroll
for (int i = 0; i < HalfVecSize; i++) {
const float input_left = static_cast<float>(src_vec[2 * i]);
@@ -120,7 +118,6 @@ void gqa_rotary_qk_split_variable(
const int seq_len,
const int input_output_len,
const int dim_head,
const bool rope_3d,
const cudaStream_t &stream) {
int64_t elem_nums = token_num * (num_heads + 2 * kv_num_heads) * dim_head;
constexpr int PackSize = 16 / sizeof(T);
@@ -149,8 +146,7 @@ void gqa_rotary_qk_split_variable(
num_heads,
kv_num_heads,
seq_len,
dim_head,
rope_3d);
dim_head);
}
template <typename T,
@@ -217,7 +213,7 @@ __global__ void append_cache_kv_c16(
// load k_smem 64 rows 128 cols
for (int fz = 0; fz < 4; fz++) { // 4 rows pre warp once, 16 rows all 4 warps once, need 4 iter
for (int fy = 0; fy < 2; fy++) { // 8 * 128b = 64 * bf16 once, need 2 iter
for (int fy = 0; fy < 2; fy++) { // 8 * 128b = 64 * bf16 noce, need 2 iter
k_smem.load_128b_async<SharedMemFillMode::kNoFill>(
k_smem_offset_w, cur_cache_k + k_read_idx, end_idx > 0);
k_smem_offset_w =
@@ -235,7 +231,7 @@ __global__ void append_cache_kv_c16(
// deal k_smem 64 rows 128 cols
for (int fz = 0; fz < 1; fz++) { // 16 rows pre warp once, 64 rows all 4 warps once, need 1 iter
uint32_t row_idx = wid * 16 + tid / 4;
for (int fy = 0; fy < 8; fy++) { // 2 * 128b = 16 * bf16 once, need 8 iter
for (int fy = 0; fy < 8; fy++) { // 2 * 128b = 16 * bf16 noce, need 8 iter
uint32_t col_idx = fy * 16 + tid % 4 * 2;
k_smem.ldmatrix_m8n8x4(k_smem_offset_r, kv_frag);
// layout
@@ -278,7 +274,7 @@ __global__ void append_cache_kv_c16(
// load v_smem 64 rows 128 cols
for (int fz = 0; fz < 4; fz++) { // // 4 rows pre warp once, 16 rows all 4 warps once, need 4 iter
for (int fy = 0; fy < 2; fy++) { // 8 * 128b = 64 * bf16 once, need 2 iter
for (int fy = 0; fy < 2; fy++) { // 8 * 128b = 64 * bf16 noce, need 2 iter
v_smem.load_128b_async<SharedMemFillMode::kNoFill>(
v_smem_offset_w, cur_cache_v + v_read_idx, end_idx > 0);
v_smem_offset_w =
@@ -296,7 +292,7 @@ __global__ void append_cache_kv_c16(
// deal v_smem 64 rows 128 cols
for (int fz = 0; fz < 1; fz++) { // 16 rows pre warp once, 64 rows all 4 warps once, need 1 iter
uint32_t row_idx = wid * 16 + tid / 4;
for (int fy = 0; fy < 8; fy++) { // 2 * 128b = 16 * bf16 once, need 8 iter
for (int fy = 0; fy < 8; fy++) { // 2 * 128b = 16 * bf16 noce, need 8 iter
uint32_t col_idx = fy * 16 + tid % 4 * 2;
v_smem.ldmatrix_m8n8x4(v_smem_offset_r, kv_frag);
// layout
@@ -400,7 +396,7 @@ __global__ void append_cache_kv_c8(
// load v_smem 64 rows, 128 cols
for (int fz = 0; fz < 4; fz++) { // 4 rows pre warp once, 16 rows all 4 warps once, need 4 iter
for (int fy = 0; fy < 1; fy++) { // 8 * 128b = 128 * uint8 once, need 1 iter
for (int fy = 0; fy < 1; fy++) { // 8 * 128b = 128 * uint8 noce, need 1 iter
k_smem.load_128b_async<SharedMemFillMode::kNoFill>(
k_smem_offset_w, cur_cache_k + k_read_idx, end_idx > 0);
k_smem_offset_w =
@@ -418,7 +414,7 @@ __global__ void append_cache_kv_c8(
// deal k_smem 64 rows, 128 cols
for (int fz = 0; fz < 1; fz++) { // 16 rows pre warp once, 64 rows all 4 warps once, need 1 iter
uint32_t row_idx = wid * 16 + tid / 4;
for (int fy = 0; fy < 4; fy++) { // 2 * 128b = 32 * uint8 once, need 4 iter
for (int fy = 0; fy < 4; fy++) { // 2 * 128b = 32 * uint8 noce, need 4 iter
uint32_t col_idx = fy * 32 + tid % 4 * 2;
k_smem.ldmatrix_m8n8x4(k_smem_offset_r, k_frag);
// layout
@@ -466,7 +462,7 @@ __global__ void append_cache_kv_c8(
tid % 4 * num_elems_per_128b<CacheT>();
// load v_smem 128 rows 64 cols
for (int fy = 0; fy < 4; fy++) { // 8 rows pre warp once, 32 rows all 4 warps once, need 4 iter
for (int fz = 0; fz < 1; fz++) { // 4 * 128b = 64 * uint8 once, need 1 iter
for (int fz = 0; fz < 1; fz++) { // 4 * 128b = 64 * uint8 noce, need 1 iter
v_smem.load_128b_async<SharedMemFillMode::kNoFill>(
v_smem_offset_w, cur_cache_v + v_read_idx, end_idx > 0);
v_smem_offset_w =
@@ -485,7 +481,7 @@ __global__ void append_cache_kv_c8(
// deal v_smem 128 rows 64 cols
for (int fy = 0; fy < 2; fy++) { // 16 rows pre warp once, 64 rows all 4 warps once, need 2 iter
uint32_t dim_idx = fy * NUM_WARPS * 16 + wid * 16 + tid / 4;
for (int fz = 0; fz < 2; fz++) { // 2 * 128b = 32 * uint8 once, need 2 iter
for (int fz = 0; fz < 2; fz++) { // 2 * 128b = 32 * uint8 noce, need 2 iter
uint32_t kv_idx = fz * 32 + tid % 4 * 2;
v_smem.ldmatrix_m8n8x4(v_smem_offset_r, v_frag);
// layout
@@ -590,9 +586,9 @@ __global__ void append_cache_kv_c4(
#pragma unroll
for (uint32_t i = wid * 32 + tid; i < HEAD_DIM; i += 128) {
cache_k_scale_smem[i] = cache_k_scale_now[i];
cache_k_zero_point_smem[i] = cache_k_zp_now[i] + static_cast<T>(136.f);
cache_k_zero_point_smem[i] = cache_k_zp_now[i] - static_cast<T>(136.f);
cache_v_scale_smem[i] = cache_v_scale_now[i];
cache_v_zero_point_smem[i] = cache_v_zp_now[i] + static_cast<T>(136.f);
cache_v_zero_point_smem[i] = cache_v_zp_now[i] - static_cast<T>(136.f);
}
smem_t k_smem(smem);
@@ -614,7 +610,7 @@ __global__ void append_cache_kv_c4(
// load k_smem 64 rows 128 cols
for (int fz = 0; fz < 2; fz++) { // 4 rows pre warp once, 16 rows all 4 warps once, need 4 iter
for (int fy = 0; fy < 1; fy++) { // 4 * 128b = 128 * int4 once, need 1 iter
for (int fy = 0; fy < 1; fy++) { // 4 * 128b = 128 * int4 noce, need 1 iter
k_smem.load_128b_async<SharedMemFillMode::kNoFill>(
k_smem_offset_w, cur_cache_k + k_read_idx, end_idx > 0);
k_smem_offset_w =
@@ -632,7 +628,7 @@ __global__ void append_cache_kv_c4(
// deal k_smem 64 rows 128 cols
for (int fz = 0; fz < 1; fz++) { // 16 rows pre warp once, 64 rows all 4 warps once, need 1 iter
uint32_t row_idx = wid * 16 + tid / 4;
for (int fy = 0; fy < 2; fy++) { // 2 * 128b = 64 * int4 once, need 2 iter
for (int fy = 0; fy < 2; fy++) { // 2 * 128b = 64 * int4 noce, need 2 iter
uint32_t col_idx = fy * 64 + tid % 4 * 2;
k_smem.ldmatrix_m8n8x4(k_smem_offset_r, k_frag);
@@ -644,25 +640,25 @@ __global__ void append_cache_kv_c4(
convert_int4(frag_dq_T + 8, k_frag[2 * i + 1]);
if (row_idx < end_idx) {
k_tile_ptr0[0] = (frag_dq_T[0] - cache_k_zero_point_smem[col_idx]) * cache_k_scale_smem[col_idx];
k_tile_ptr0[1] = (frag_dq_T[1] - cache_k_zero_point_smem[col_idx + 1]) * cache_k_scale_smem[col_idx + 1];
k_tile_ptr0[8] = (frag_dq_T[2] - cache_k_zero_point_smem[col_idx + 8]) * cache_k_scale_smem[col_idx + 8];
k_tile_ptr0[9] = (frag_dq_T[3] - cache_k_zero_point_smem[col_idx + 9]) * cache_k_scale_smem[col_idx + 9];
k_tile_ptr0[16] = (frag_dq_T[8] - cache_k_zero_point_smem[col_idx + 16]) * cache_k_scale_smem[col_idx + 16];
k_tile_ptr0[17] = (frag_dq_T[9] - cache_k_zero_point_smem[col_idx + 17]) * cache_k_scale_smem[col_idx + 17];
k_tile_ptr0[24] = (frag_dq_T[10] - cache_k_zero_point_smem[col_idx + 24]) * cache_k_scale_smem[col_idx + 24];
k_tile_ptr0[25] = (frag_dq_T[11] - cache_k_zero_point_smem[col_idx + 25]) * cache_k_scale_smem[col_idx + 25];
k_tile_ptr0[0] = frag_dq_T[0] * cache_k_scale_smem[col_idx] + cache_k_zero_point_smem[col_idx];
k_tile_ptr0[1] = frag_dq_T[1] * cache_k_scale_smem[col_idx + 1] + cache_k_zero_point_smem[col_idx + 1];
k_tile_ptr0[8] = frag_dq_T[2] * cache_k_scale_smem[col_idx + 8] + cache_k_zero_point_smem[col_idx + 8];
k_tile_ptr0[9] = frag_dq_T[3] * cache_k_scale_smem[col_idx + 9] + cache_k_zero_point_smem[col_idx + 9];
k_tile_ptr0[16] = frag_dq_T[8] * cache_k_scale_smem[col_idx + 16] + cache_k_zero_point_smem[col_idx + 16];
k_tile_ptr0[17] = frag_dq_T[9] * cache_k_scale_smem[col_idx + 17] + cache_k_zero_point_smem[col_idx + 17];
k_tile_ptr0[24] = frag_dq_T[10] * cache_k_scale_smem[col_idx + 24] + cache_k_zero_point_smem[col_idx + 24];
k_tile_ptr0[25] = frag_dq_T[11] * cache_k_scale_smem[col_idx + 25] + cache_k_zero_point_smem[col_idx + 25];
}
if (row_idx + 8 < end_idx) {
k_tile_ptr1[0] = (frag_dq_T[4] - cache_k_zero_point_smem[col_idx]) * cache_k_scale_smem[col_idx];
k_tile_ptr1[1] = (frag_dq_T[5] - cache_k_zero_point_smem[col_idx + 1]) * cache_k_scale_smem[col_idx + 1];
k_tile_ptr1[8] = (frag_dq_T[6] - cache_k_zero_point_smem[col_idx + 8]) * cache_k_scale_smem[col_idx + 8];
k_tile_ptr1[9] = (frag_dq_T[7] - cache_k_zero_point_smem[col_idx + 9]) * cache_k_scale_smem[col_idx + 9];
k_tile_ptr1[16] = (frag_dq_T[12] - cache_k_zero_point_smem[col_idx + 16]) * cache_k_scale_smem[col_idx + 16];
k_tile_ptr1[17] = (frag_dq_T[13] - cache_k_zero_point_smem[col_idx + 17]) * cache_k_scale_smem[col_idx + 17];
k_tile_ptr1[24] = (frag_dq_T[14] - cache_k_zero_point_smem[col_idx + 24]) * cache_k_scale_smem[col_idx + 24];
k_tile_ptr1[25] = (frag_dq_T[15] - cache_k_zero_point_smem[col_idx + 25]) * cache_k_scale_smem[col_idx + 25];
k_tile_ptr1[0] = frag_dq_T[4] * cache_k_scale_smem[col_idx] + cache_k_zero_point_smem[col_idx];
k_tile_ptr1[1] = frag_dq_T[5] * cache_k_scale_smem[col_idx + 1] + cache_k_zero_point_smem[col_idx + 1];
k_tile_ptr1[8] = frag_dq_T[6] * cache_k_scale_smem[col_idx + 8] + cache_k_zero_point_smem[col_idx + 8];
k_tile_ptr1[9] = frag_dq_T[7] * cache_k_scale_smem[col_idx + 9] + cache_k_zero_point_smem[col_idx + 9];
k_tile_ptr1[16] = frag_dq_T[12] * cache_k_scale_smem[col_idx + 16] + cache_k_zero_point_smem[col_idx + 16];
k_tile_ptr1[17] = frag_dq_T[13] * cache_k_scale_smem[col_idx + 17] + cache_k_zero_point_smem[col_idx + 17];
k_tile_ptr1[24] = frag_dq_T[14] * cache_k_scale_smem[col_idx + 24] + cache_k_zero_point_smem[col_idx + 24];
k_tile_ptr1[25] = frag_dq_T[15] * cache_k_scale_smem[col_idx + 25] + cache_k_zero_point_smem[col_idx + 25];
}
col_idx += 32;
}
@@ -685,7 +681,7 @@ __global__ void append_cache_kv_c4(
tid % 2 * num_elems_per_128b<CacheT>();
// load v_smem 128 rows 64 rows
for (int fy = 0; fy < 2; fy++) { // 16 rows pre warp once, 64 rows all 4 warps once, need 2 iter
for (int fz = 0; fz < 1; fz++) { // 2 * 128b = 64 * int4 once, need 1 iter
for (int fz = 0; fz < 1; fz++) { // 2 * 128b = 64 * int4 noce, need 1 iter
v_smem.load_128b_async<SharedMemFillMode::kNoFill>(
v_smem_offset_w, cur_cache_v + v_read_idx, end_idx > 0);
v_smem_offset_w =
@@ -704,7 +700,7 @@ __global__ void append_cache_kv_c4(
// deal v_smem 128 rows 64 cols
for (int fy = 0; fy < 2; fy++) { // 16 rows pre warp once, 64 rows all 4 warps once, need 2 iter
uint32_t dim_idx = fy * NUM_WARPS * 16 + wid * 16 + tid / 4;
for (int fz = 0; fz < 1; fz++) { // 2 * 128b = 64 * int4 once, need 1 iter
for (int fz = 0; fz < 1; fz++) { // 2 * 128b = 64 * int4 noce, need 1 iter
uint32_t kv_idx = fz * 64 + tid % 4 * 2;
v_smem.ldmatrix_m8n8x4(v_smem_offset_r, v_frag);
// layout
@@ -715,36 +711,36 @@ __global__ void append_cache_kv_c4(
convert_int4(frag_dq_T, v_frag[2 * i]);
convert_int4(frag_dq_T + 8, v_frag[2 * i + 1]);
if (kv_idx < end_idx) {
v_tile_ptr0[0] = (frag_dq_T[0] - cache_v_zero_point_smem[dim_idx]) * cache_v_scale_smem[dim_idx];
v_tile_ptr1[0] = (frag_dq_T[4] - cache_v_zero_point_smem[dim_idx + 8]) * cache_v_scale_smem[dim_idx + 8];
v_tile_ptr0[0] = frag_dq_T[0] * cache_v_scale_smem[dim_idx] + cache_v_zero_point_smem[dim_idx];
v_tile_ptr1[0] = frag_dq_T[4] * cache_v_scale_smem[dim_idx + 8] + cache_v_zero_point_smem[dim_idx + 8];
}
if (kv_idx + 1 < end_idx) {
v_tile_ptr0[kv_t_stride] = (frag_dq_T[1] - cache_v_zero_point_smem[dim_idx]) * cache_v_scale_smem[dim_idx];
v_tile_ptr1[kv_t_stride] = (frag_dq_T[5] - cache_v_zero_point_smem[dim_idx + 8]) * cache_v_scale_smem[dim_idx + 8];
v_tile_ptr0[kv_t_stride] = frag_dq_T[1] * cache_v_scale_smem[dim_idx] + cache_v_zero_point_smem[dim_idx];
v_tile_ptr1[kv_t_stride] = frag_dq_T[5] * cache_v_scale_smem[dim_idx + 8] + cache_v_zero_point_smem[dim_idx + 8];
}
if (kv_idx + 8 < end_idx) {
v_tile_ptr0[8 * kv_t_stride] = (frag_dq_T[2] - cache_v_zero_point_smem[dim_idx]) * cache_v_scale_smem[dim_idx];
v_tile_ptr1[8 * kv_t_stride] = (frag_dq_T[6] - cache_v_zero_point_smem[dim_idx + 8]) * cache_v_scale_smem[dim_idx + 8];
v_tile_ptr0[8 * kv_t_stride] = frag_dq_T[2] * cache_v_scale_smem[dim_idx] + cache_v_zero_point_smem[dim_idx];
v_tile_ptr1[8 * kv_t_stride] = frag_dq_T[6] * cache_v_scale_smem[dim_idx + 8] + cache_v_zero_point_smem[dim_idx + 8];
}
if (kv_idx + 9 < end_idx) {
v_tile_ptr0[9 * kv_t_stride] = (frag_dq_T[3] - cache_v_zero_point_smem[dim_idx]) * cache_v_scale_smem[dim_idx];
v_tile_ptr1[9 * kv_t_stride] = (frag_dq_T[7] - cache_v_zero_point_smem[dim_idx + 8]) * cache_v_scale_smem[dim_idx + 8];
v_tile_ptr0[9 * kv_t_stride] = frag_dq_T[3] * cache_v_scale_smem[dim_idx] + cache_v_zero_point_smem[dim_idx];
v_tile_ptr1[9 * kv_t_stride] = frag_dq_T[7] * cache_v_scale_smem[dim_idx + 8] + cache_v_zero_point_smem[dim_idx + 8];
}
if (kv_idx + 16 < end_idx) {
v_tile_ptr0[16 * kv_t_stride] = (frag_dq_T[8] - cache_v_zero_point_smem[dim_idx]) * cache_v_scale_smem[dim_idx];
v_tile_ptr1[16 * kv_t_stride] = (frag_dq_T[12] - cache_v_zero_point_smem[dim_idx + 8]) * cache_v_scale_smem[dim_idx + 8];
v_tile_ptr0[16 * kv_t_stride] = frag_dq_T[8] * cache_v_scale_smem[dim_idx] + cache_v_zero_point_smem[dim_idx];
v_tile_ptr1[16 * kv_t_stride] = frag_dq_T[12] * cache_v_scale_smem[dim_idx + 8] + cache_v_zero_point_smem[dim_idx + 8];
}
if (kv_idx + 17 < end_idx) {
v_tile_ptr0[17 * kv_t_stride] = (frag_dq_T[9] - cache_v_zero_point_smem[dim_idx]) * cache_v_scale_smem[dim_idx];
v_tile_ptr1[17 * kv_t_stride] = (frag_dq_T[13] - cache_v_zero_point_smem[dim_idx + 8]) * cache_v_scale_smem[dim_idx + 8];
v_tile_ptr0[17 * kv_t_stride] = frag_dq_T[9] * cache_v_scale_smem[dim_idx] + cache_v_zero_point_smem[dim_idx];
v_tile_ptr1[17 * kv_t_stride] = frag_dq_T[13] * cache_v_scale_smem[dim_idx + 8] + cache_v_zero_point_smem[dim_idx + 8];
}
if (kv_idx + 24 < end_idx) {
v_tile_ptr0[24 * kv_t_stride] = (frag_dq_T[10] - cache_v_zero_point_smem[dim_idx]) * cache_v_scale_smem[dim_idx];
v_tile_ptr1[24 * kv_t_stride] = (frag_dq_T[14] - cache_v_zero_point_smem[dim_idx + 8]) * cache_v_scale_smem[dim_idx + 8];
v_tile_ptr0[24 * kv_t_stride] = frag_dq_T[10] * cache_v_scale_smem[dim_idx] + cache_v_zero_point_smem[dim_idx];
v_tile_ptr1[24 * kv_t_stride] = frag_dq_T[14] * cache_v_scale_smem[dim_idx + 8] + cache_v_zero_point_smem[dim_idx + 8];
}
if (kv_idx + 25 < end_idx) {
v_tile_ptr0[25 * kv_t_stride] = (frag_dq_T[11] - cache_v_zero_point_smem[dim_idx]) * cache_v_scale_smem[dim_idx];
v_tile_ptr1[25 * kv_t_stride] = (frag_dq_T[15] - cache_v_zero_point_smem[dim_idx + 8]) * cache_v_scale_smem[dim_idx + 8];
v_tile_ptr0[25 * kv_t_stride] = frag_dq_T[11] * cache_v_scale_smem[dim_idx] + cache_v_zero_point_smem[dim_idx];
v_tile_ptr1[25 * kv_t_stride] = frag_dq_T[15] * cache_v_scale_smem[dim_idx + 8] + cache_v_zero_point_smem[dim_idx + 8];
}
kv_idx += 32;
}
@@ -894,8 +890,7 @@ std::vector<paddle::Tensor> GQARopeWriteCacheKernel(
const paddle::optional<paddle::Tensor>& kv_signal_data,
const int kv_token_num,
const int max_seq_len,
const std::string& cache_quant_type,
const bool rope_3d) {
const std::string& cache_quant_type) {
typedef PDTraits<paddle::DataType::BFLOAT16> traits_;
typedef typename traits_::DataType DataType_;
typedef typename traits_::data_t data_t;
@@ -958,34 +953,9 @@ std::vector<paddle::Tensor> GQARopeWriteCacheKernel(
num_heads,
kv_num_heads,
max_seq_len,
rope_3d ? rotary_embs.dims()[3] : rotary_embs.dims()[2],
rotary_embs.dims()[2],
head_dim,
rope_3d,
stream);
if (token_num < kv_token_num) {
AppendCacheKV<data_t, 128, 64>(
key_cache,
value_cache,
cache_k_dequant_scales.get(),
cache_v_dequant_scales.get(),
cache_k_zp.get(),
cache_v_zp.get(),
seq_lens_this_time,
seq_lens_decoder,
cu_seqlens_k,
block_tables,
cache_batch_ids,
cache_tile_ids,
cache_num_blocks,
max_blocks_per_seq,
kv_num_heads,
cache_quant_type,
&k,
&v,
stream
);
}
// write cache
if (cache_quant_type == "none") {
CascadeAppendWriteCacheKVQKV<data_t>(
@@ -1000,7 +970,7 @@ std::vector<paddle::Tensor> GQARopeWriteCacheKernel(
stream,
const_cast<paddle::Tensor*>(&key_cache),
const_cast<paddle::Tensor*>(&value_cache));
} else if (cache_quant_type == "cache_int8" || cache_quant_type == "cache_fp8" || cache_quant_type == "block_wise_fp8") {
} else if (cache_quant_type == "cache_int8" || cache_quant_type == "cache_fp8") {
CascadeAppendWriteCacheKVC8QKV<data_t, 128, 64>(
meta_data,
*const_cast<paddle::Tensor*>(&key_cache),
@@ -1018,7 +988,7 @@ std::vector<paddle::Tensor> GQARopeWriteCacheKernel(
kv_num_blocks_data,
max_seq_len,
false, // is_scale_channel_wise
cache_quant_type,
cache_quant_type == "cache_fp8", // is_fp8
stream,
const_cast<paddle::Tensor*>(&key_cache),
const_cast<paddle::Tensor*>(&value_cache));
@@ -1068,6 +1038,30 @@ std::vector<paddle::Tensor> GQARopeWriteCacheKernel(
}
}
}
if (token_num < kv_token_num) {
AppendCacheKV<data_t, 128, 64>(
key_cache,
value_cache,
cache_k_dequant_scales.get(),
cache_v_dequant_scales.get(),
cache_k_zp.get(),
cache_v_zp.get(),
seq_lens_this_time,
seq_lens_decoder,
cu_seqlens_k,
block_tables,
cache_batch_ids,
cache_tile_ids,
cache_num_blocks,
max_blocks_per_seq,
kv_num_heads,
cache_quant_type,
&k,
&v,
stream
);
}
return {q, k, v, qkv_out};
}

View File

@@ -18,168 +18,6 @@
#include "mma_tensor_op.cuh"
#include "utils.cuh"
template <typename T, int VecSize = 1, typename InT = T>
__global__ void append_speculate_cache_T_rope_qk_norm_kernel(
const InT* __restrict__ qkv, // [token_num, num_heads + 2 * gqa_group_size,
// head_size]
T* __restrict__ key_cache, // [num_blocks, gqa_group_size, block_size,
// head_size // 2]
T* __restrict__ value_cache, // [num_blocks, gqa_group_size, block_size,
// head_size // 2]
T* __restrict__ q_out,
const int* __restrict__ block_tables, // [bsz, max_blocks_per_seq]
const int* __restrict__ batch_id_per_token, // [num_tokens]
const int* __restrict__ cu_seqlens_q,
const int* __restrict__ seq_lens_decoder, // [bsz]
const float* __restrict__ cos_emb,
const float* __restrict__ sin_emb,
const float*
qkv_out_scales, // [(num_heads + 2 * gqa_group_size) * head_size]
const T* qkv_biases, // [num_head + 2 * gqa_group_size, dim_head]
const int max_seq_len,
const int max_blocks_per_seq,
const int num_heads,
const int output_inner_dim,
const int head_size,
const int block_size,
const int elem_cnt,
const int gqa_group_size,
const float* q_norm_weight,
const float* k_norm_weight,
const float rms_norm_eps,
const bool rope_3d) {
using LoadT = AlignedVector<T, VecSize>;
using LoadFloat = AlignedVector<float, VecSize>;
using LoadInT = AlignedVector<InT, VecSize>;
constexpr int HalfVecSize = VecSize / 2;
using LoadEmbT = AlignedVector<float, HalfVecSize>;
LoadInT src_vec;
LoadFloat scale_vec;
LoadT bias_vec;
LoadEmbT cos_emb_vec;
LoadEmbT sin_emb_vec;
LoadFloat tmp_vec;
LoadFloat q_norm_vec;
LoadFloat k_norm_vec;
int64_t global_warp_idx = blockDim.y * blockIdx.x + threadIdx.y;
int64_t all_warp_num = gridDim.x * blockDim.y;
int64_t all_head_dim = elem_cnt / head_size;
const int64_t hidden_size = (num_heads + 2 * gqa_group_size) * head_size;
const int half_head_size = head_size / 2;
for (int global_hi = global_warp_idx; global_hi < all_head_dim; global_hi += all_warp_num) {
int64_t linear_index = global_hi * head_size + threadIdx.x * VecSize;
const int token_id = linear_index / hidden_size;
const int ori_bi = batch_id_per_token[token_id];
if (seq_lens_decoder[ori_bi] == 0) continue;
const int bias = linear_index % hidden_size;
const int hi = bias / head_size; // q + k + v
const int h_bias = bias % head_size;
const int start_token_idx = cu_seqlens_q[ori_bi];
const int write_seq_id =
seq_lens_decoder[ori_bi] + token_id - start_token_idx;
if (write_seq_id == 0) continue;
const int* block_table_now = block_tables + ori_bi * max_blocks_per_seq;
const int block_idx = block_table_now[write_seq_id / block_size];
if (block_idx < 0) {
printf(
"Fatal Error!!!, block idx %d when write_seq_id is %d\n some key var "
"%d %d %d %d\n",
block_idx,
write_seq_id,
ori_bi,
seq_lens_decoder[ori_bi],
token_id,
cu_seqlens_q[ori_bi]);
}
const int block_offset = write_seq_id % block_size;
const int write_q_idx =
token_id * output_inner_dim * head_size + hi * head_size + h_bias;
const int bias_idx = hi * head_size + h_bias;
Load<InT, VecSize>(&qkv[linear_index], &src_vec);
if (qkv_biases) {
Load<T, VecSize>(&qkv_biases[bias_idx], &bias_vec);
}
if (qkv_out_scales) {
Load<float, VecSize>(&qkv_out_scales[bias_idx], &scale_vec);
}
if (hi < num_heads + gqa_group_size) {
// q k rope
const int64_t emb_idx = write_seq_id * half_head_size + h_bias / 2;
uint32_t new_emb_idx = rope_3d ? emb_idx + ori_bi * max_seq_len * head_size : emb_idx;
Load<float, HalfVecSize>(&cos_emb[new_emb_idx], &cos_emb_vec);
Load<float, HalfVecSize>(&sin_emb[new_emb_idx], &sin_emb_vec);
}
float thread_m2 = 0.0f;
float warp_m2 = 0.0f;
#pragma unroll
for (int i = 0; i < HalfVecSize; i++) {
// add_bias + rope
float input_left = static_cast<float>(src_vec[2 * i]);
float input_right = static_cast<float>(src_vec[2 * i + 1]);
if (qkv_out_scales) {
input_left *= scale_vec[2 * i];
input_right *= scale_vec[2 * i + 1];
}
if (qkv_biases) {
input_left = input_left + static_cast<float>(bias_vec[2 * i]);
input_right = input_right + static_cast<float>(bias_vec[2 * i + 1]);
}
if (hi < num_heads + gqa_group_size) {
const float cos_tmp = cos_emb_vec[i];
const float sin_tmp = sin_emb_vec[i];
float tmp1 = input_left * cos_tmp - input_right * sin_tmp;
float tmp2 = input_right * cos_tmp + input_left * sin_tmp;
thread_m2 += tmp1 * tmp1 + tmp2 * tmp2;
tmp_vec[2 * i] = tmp1;
tmp_vec[2 * i + 1] = tmp2;
} else {
bias_vec[2 * i] = static_cast<T>(input_left);
bias_vec[2 * i + 1] = static_cast<T>(input_right);
}
}
if (hi < (num_heads + gqa_group_size)) {
WelfordWarpAllReduce<float, 32>(thread_m2, &warp_m2);
float row_variance =
max(warp_m2 / head_size, 0.0f);
float row_inv_var = Rsqrt(row_variance + rms_norm_eps);
if (hi < num_heads) {
Load<float, VecSize>(&q_norm_weight[threadIdx.x * VecSize], &q_norm_vec);
#pragma unroll
for (int i = 0; i < VecSize; i++) {
bias_vec[i] = static_cast<T>(tmp_vec[i] * row_inv_var * q_norm_vec[i]);
}
} else {
Load<float, VecSize>(&k_norm_weight[threadIdx.x * VecSize], &k_norm_vec);
#pragma unroll
for (int i = 0; i < VecSize; i++) {
bias_vec[i] = static_cast<T>(tmp_vec[i] * row_inv_var * k_norm_vec[i]);
}
}
}
if (hi < num_heads) {
// write q
Store<T, VecSize>(bias_vec, &q_out[write_q_idx]);
} else {
// write k/v
const int kv_head_idx = (hi - num_heads) % gqa_group_size;
const int tgt_idx = (block_idx * gqa_group_size * block_size * head_size +
kv_head_idx * block_size * head_size +
block_offset * head_size + h_bias);
// write
if (hi < num_heads + gqa_group_size) {
Store<T, VecSize>(bias_vec, &key_cache[tgt_idx]);
} else {
Store<T, VecSize>(bias_vec, &value_cache[tgt_idx]);
}
}
}
}
template <int VecSize = 4, int HeadDim = 128>
__global__ void append_clear_cache_int8_block(
uint8_t* __restrict__ key_cache, // [num_blocks, gqa_group_size,
@@ -355,8 +193,7 @@ __global__ void append_speculate_cache_rope_kernel(
const int head_size,
const int block_size,
const int elem_cnt,
const int gqa_group_size,
const bool rope_3d) {
const int gqa_group_size) {
using LoadT = AlignedVector<T, VecSize>;
using LoadFloat = AlignedVector<float, VecSize>;
using LoadInT = AlignedVector<InT, VecSize>;
@@ -416,9 +253,8 @@ __global__ void append_speculate_cache_rope_kernel(
if (hi < num_heads + gqa_group_size) {
// q k rope
const int64_t emb_idx = write_seq_id * half_head_size + h_bias / 2;
int64_t new_emb_idx = rope_3d ? emb_idx + ori_bi * max_seq_len * head_size : emb_idx;
Load<float, HalfVecSize>(&cos_emb[new_emb_idx], &cos_emb_vec);
Load<float, HalfVecSize>(&sin_emb[new_emb_idx], &sin_emb_vec);
Load<float, HalfVecSize>(&cos_emb[emb_idx], &cos_emb_vec);
Load<float, HalfVecSize>(&sin_emb[emb_idx], &sin_emb_vec);
}
#pragma unroll
for (int i = 0; i < HalfVecSize; i++) {
@@ -490,8 +326,7 @@ __global__ void append_speculate_cache_neox_rope_kernel(
const int head_size,
const int block_size,
const int elem_cnt,
const int gqa_group_size,
const bool rope_3d) {
const int gqa_group_size) {
using LoadT = AlignedVector<T, VecSize>;
using LoadFloat = AlignedVector<float, VecSize>;
using LoadInT = AlignedVector<InT, VecSize>;
@@ -555,9 +390,8 @@ __global__ void append_speculate_cache_neox_rope_kernel(
if (hi < num_heads + gqa_group_size) {
// q k rope
const int64_t emb_idx = write_seq_id * head_size + h_bias;
int64_t new_emb_idx = rope_3d ? emb_idx + ori_bi * max_seq_len * head_size * 2: emb_idx;
Load<float, VecSize>(&cos_emb[new_emb_idx], &cos_emb_vec);
Load<float, VecSize>(&sin_emb[new_emb_idx], &sin_emb_vec);
Load<float, VecSize>(&cos_emb[emb_idx], &cos_emb_vec);
Load<float, VecSize>(&sin_emb[emb_idx], &sin_emb_vec);
}
#pragma unroll
for (int i = 0; i < VecSize; i++) {
@@ -642,8 +476,7 @@ __global__ void append_speculate_cache_int8_rope_kernel(
const int block_size,
const float max_bound,
const float min_bound,
const int gqa_group_size,
const bool rope_3d) {
const int gqa_group_size) {
static_assert(HeadDim == 128, "just support HeadDim be 128 now!");
static_assert(VecSize == 4, "just support VecSize be 4 now, 32 * 4!");
constexpr int NUM_WARPS = 4;
@@ -689,9 +522,8 @@ __global__ void append_speculate_cache_int8_rope_kernel(
// q rope
const uint32_t emb_idx = write_seq_id * half_head_size + head_bias / 2;
uint32_t new_emb_idx = rope_3d ? emb_idx + bid * max_seq_len * HeadDim : emb_idx;
Load<float, HalfVecSize>(&cos_emb[new_emb_idx], &cos_emb_vec);
Load<float, HalfVecSize>(&sin_emb[new_emb_idx], &sin_emb_vec);
Load<float, HalfVecSize>(&cos_emb[emb_idx], &cos_emb_vec);
Load<float, HalfVecSize>(&sin_emb[emb_idx], &sin_emb_vec);
if (qkv_out_scales) {
Load<float, VecSize>(&qkv_out_scales[bias_idx], &out_scale_vec);
}
@@ -751,11 +583,10 @@ __global__ void append_speculate_cache_int8_rope_kernel(
T scale;
if (head_idx < num_heads + gqa_group_size) {
const uint32_t emb_idx = write_seq_id * half_head_size + head_bias / 2;
uint32_t new_emb_idx = rope_3d ? emb_idx + bid * max_seq_len * HeadDim : emb_idx;
Load<float, 1>(&cos_emb[new_emb_idx], &cos_emb_vec1);
Load<float, 1>(&cos_emb[new_emb_idx + 4], &cos_emb_vec2);
Load<float, 1>(&sin_emb[new_emb_idx], &sin_emb_vec1);
Load<float, 1>(&sin_emb[new_emb_idx + 4], &sin_emb_vec2);
Load<float, 1>(&cos_emb[emb_idx], &cos_emb_vec1);
Load<float, 1>(&cos_emb[emb_idx + 4], &cos_emb_vec2);
Load<float, 1>(&sin_emb[emb_idx], &sin_emb_vec1);
Load<float, 1>(&sin_emb[emb_idx + 4], &sin_emb_vec2);
scale = __ldg(&cache_k_scales[kv_head_idx]);
} else {
scale = __ldg(&cache_v_scales[kv_head_idx]);
@@ -877,8 +708,7 @@ __global__ void append_speculate_cache_int8_neox_rope_kernel(
const int block_size,
const float max_bound,
const float min_bound,
const int gqa_group_size,
const bool rope_3d) {
const int gqa_group_size) {
static_assert(HeadDim == 128, "just support HeadDim be 128 now!");
static_assert(VecSize == 4, "just support VecSize be 4 now, 32 * 4!");
constexpr int NUM_WARPS = 4;
@@ -927,9 +757,8 @@ __global__ void append_speculate_cache_int8_neox_rope_kernel(
// q rope
const uint32_t emb_idx = write_seq_id * HeadDim + head_bias;
uint32_t new_emb_idx = rope_3d ? emb_idx + bid * max_seq_len * HeadDim * 2 : emb_idx;
Load<float, VecSize>(&cos_emb[new_emb_idx], &cos_emb_vec);
Load<float, VecSize>(&sin_emb[new_emb_idx], &sin_emb_vec);
Load<float, VecSize>(&cos_emb[emb_idx], &cos_emb_vec);
Load<float, VecSize>(&sin_emb[emb_idx], &sin_emb_vec);
if (qkv_out_scales) {
Load<float, VecSize>(&qkv_out_scales[bias_idx_left],
&left_out_scale_vec);
@@ -1024,11 +853,10 @@ __global__ void append_speculate_cache_int8_neox_rope_kernel(
T scale;
const uint32_t emb_idx = write_seq_id * HeadDim + head_bias;
uint32_t new_emb_idx = rope_3d ? emb_idx + bid * max_seq_len * HeadDim * 2 : emb_idx;
Load<float, HALF_K_VEC_SIZE>(&cos_emb[new_emb_idx], &cos_emb_vec1);
Load<float, HALF_K_VEC_SIZE>(&cos_emb[new_emb_idx + 8], &cos_emb_vec2);
Load<float, HALF_K_VEC_SIZE>(&sin_emb[new_emb_idx], &sin_emb_vec1);
Load<float, HALF_K_VEC_SIZE>(&sin_emb[new_emb_idx + 8], &sin_emb_vec2);
Load<float, HALF_K_VEC_SIZE>(&cos_emb[emb_idx], &cos_emb_vec1);
Load<float, HALF_K_VEC_SIZE>(&cos_emb[emb_idx + 8], &cos_emb_vec2);
Load<float, HALF_K_VEC_SIZE>(&sin_emb[emb_idx], &sin_emb_vec1);
Load<float, HALF_K_VEC_SIZE>(&sin_emb[emb_idx + 8], &sin_emb_vec2);
scale = __ldg(&cache_k_scales[kv_head_idx]);
#pragma unroll
for (int i = 0; i < HALF_K_VEC_SIZE; i++) {
@@ -1260,8 +1088,7 @@ __global__ void append_speculate_cache_int4_rope_kernel(
const int block_size,
const float max_bound,
const float min_bound,
const int gqa_group_size,
const bool rope_3d) {
const int gqa_group_size) {
static_assert(HeadDim == 128, "just support HeadDim be 128 now!");
static_assert(VecSize == 4, "just support VecSize be 4 now, 32 * 4!");
constexpr int NUM_WARPS = 4;
@@ -1318,9 +1145,8 @@ __global__ void append_speculate_cache_int4_rope_kernel(
// Load<float, VecSize>(&qkv_out_scales[bias_idx], &out_scale_vec);
// q rope
const uint32_t emb_idx = write_seq_id * half_head_size + head_bias / 2;
uint32_t new_emb_idx = rope_3d ? emb_idx + bid * max_seq_len * HeadDim : emb_idx;
Load<float, HalfVecSize>(&cos_emb[new_emb_idx], &cos_emb_vec);
Load<float, HalfVecSize>(&sin_emb[new_emb_idx], &sin_emb_vec);
Load<float, HalfVecSize>(&cos_emb[emb_idx], &cos_emb_vec);
Load<float, HalfVecSize>(&sin_emb[emb_idx], &sin_emb_vec);
#pragma unroll
for (int i = 0; i < HalfVecSize; i++) {
// dequant + add_bias + rope
@@ -1409,11 +1235,10 @@ __global__ void append_speculate_cache_int4_rope_kernel(
// &out_scale_vec2);
if (head_idx < num_heads + gqa_group_size) {
const uint32_t emb_idx = write_seq_id * half_head_size + head_bias / 2;
uint32_t new_emb_idx = rope_3d ? emb_idx + bid * max_seq_len * HeadDim : emb_idx;
Load<float, 1>(&cos_emb[new_emb_idx], &cos_emb_vec1);
Load<float, 1>(&cos_emb[new_emb_idx + 4], &cos_emb_vec2);
Load<float, 1>(&sin_emb[new_emb_idx], &sin_emb_vec1);
Load<float, 1>(&sin_emb[new_emb_idx + 4], &sin_emb_vec2);
Load<float, 1>(&cos_emb[emb_idx], &cos_emb_vec1);
Load<float, 1>(&cos_emb[emb_idx + 4], &cos_emb_vec2);
Load<float, 1>(&sin_emb[emb_idx], &sin_emb_vec1);
Load<float, 1>(&sin_emb[emb_idx + 4], &sin_emb_vec2);
Load<T, HALF_K_VEC_SIZE>(&cache_k_scales[cache_idx], &scale_vec1);
Load<T, HALF_K_VEC_SIZE>(&cache_k_scales[cache_idx + 8], &scale_vec2);
Load<T, HALF_K_VEC_SIZE>(&cache_k_zero_points[cache_idx], &zp_vec1);
@@ -1606,8 +1431,7 @@ __global__ void append_speculate_cache_int4_neox_rope_kernel(
const int block_size,
const float max_bound,
const float min_bound,
const int gqa_group_size,
const bool rope_3d) {
const int gqa_group_size) {
static_assert(HeadDim == 128, "just support HeadDim be 128 now!");
static_assert(VecSize == 4, "just support VecSize be 4 now, 32 * 4!");
constexpr int NUM_WARPS = 4;
@@ -1757,11 +1581,10 @@ __global__ void append_speculate_cache_int4_neox_rope_kernel(
&right_out_scale_vec2);
const uint32_t emb_idx = write_seq_id * HeadDim + head_bias;
uint32_t new_emb_idx = rope_3d ? emb_idx + bid * max_seq_len * HeadDim : emb_idx;
Load<float, HALF_K_VEC_SIZE>(&cos_emb[new_emb_idx], &cos_emb_vec1);
Load<float, HALF_K_VEC_SIZE>(&cos_emb[new_emb_idx + 8], &cos_emb_vec2);
Load<float, HALF_K_VEC_SIZE>(&sin_emb[new_emb_idx], &sin_emb_vec1);
Load<float, HALF_K_VEC_SIZE>(&sin_emb[new_emb_idx + 8], &sin_emb_vec2);
Load<float, HALF_K_VEC_SIZE>(&cos_emb[emb_idx], &cos_emb_vec1);
Load<float, HALF_K_VEC_SIZE>(&cos_emb[emb_idx + 8], &cos_emb_vec2);
Load<float, HALF_K_VEC_SIZE>(&sin_emb[emb_idx], &sin_emb_vec1);
Load<float, HALF_K_VEC_SIZE>(&sin_emb[emb_idx + 8], &sin_emb_vec2);
Load<T, HALF_K_VEC_SIZE>(&cache_k_scales[left_cache_idx],
&left_scale_vec1);
Load<T, HALF_K_VEC_SIZE>(&cache_k_scales[left_cache_idx + 8],

View File

@@ -15,78 +15,6 @@
#include "speculate_write_cache_with_rope_kernel.h"
#include "utils.cuh"
template <typename T, typename QKV_TYPE>
void append_speculate_cache_rope_qk_norm(const QKV_TYPE* qkv,
T* key_cache,
T* value_cache,
T* qkv_out,
const int* block_tables,
const int* batch_id_per_token,
const int* cu_seqlens_q,
const int* seq_lens,
const int* seq_lens_encoder,
const float* cos_emb,
const float* sin_emb,
const float* qkv_out_scales,
const T* qkv_biases,
const int max_seq_len,
const int max_blocks_per_seq,
const int num_heads,
const int kv_num_heads,
const int dim_head,
const int block_size,
const int bsz,
const int token_num,
const cudaStream_t& stream,
const bool use_neox_style,
const float* q_norm_weight,
const float* k_norm_weight,
const float rms_norm_eps,
const bool rope_3d) {
int output_inner_dim = num_heads + 2 * kv_num_heads;
const uint32_t elem_nums =
use_neox_style ? token_num * (num_heads + 2 * kv_num_heads) * dim_head / 2
: token_num * (num_heads + 2 * kv_num_heads) * dim_head;
constexpr int HEAD_DIM = 128;
constexpr int PackSize = HEAD_DIM / kWarpSize;
const int pack_num = elem_nums / PackSize;
const int blocksize = 128;
int grid_size = 1;
GetNumBlocks<128>(pack_num, &grid_size);
if (use_neox_style) {
PD_THROW(
"append_speculate_cache_rope_qk_norm not support neox rope yet");
} else {
dim3 block_dim(kWarpSize, blocksize / kWarpSize, 1);
append_speculate_cache_T_rope_qk_norm_kernel<T, PackSize>
<<<grid_size, block_dim, 0, stream>>>(qkv,
key_cache,
value_cache,
qkv_out,
block_tables,
batch_id_per_token,
cu_seqlens_q,
seq_lens,
cos_emb,
sin_emb,
qkv_out_scales,
qkv_biases,
max_seq_len,
max_blocks_per_seq,
num_heads,
output_inner_dim,
dim_head,
block_size,
elem_nums,
kv_num_heads,
q_norm_weight,
k_norm_weight,
rms_norm_eps,
rope_3d);
}
}
// rope + write
template <typename T, typename QKV_TYPE>
void append_speculate_cache_rope(const QKV_TYPE* qkv,
@@ -111,8 +39,7 @@ void append_speculate_cache_rope(const QKV_TYPE* qkv,
const int bsz,
const int token_num,
const cudaStream_t& stream,
const bool use_neox_style,
const bool rope_3d) {
const bool use_neox_style) {
int output_inner_dim = num_heads + 2 * kv_num_heads;
const uint32_t elem_nums =
@@ -146,8 +73,7 @@ void append_speculate_cache_rope(const QKV_TYPE* qkv,
dim_head,
block_size,
elem_nums,
kv_num_heads,
rope_3d);
kv_num_heads);
} else {
append_speculate_cache_rope_kernel<T, PackSize>
<<<grid_size, threads_per_block, 0, stream>>>(
@@ -170,8 +96,7 @@ void append_speculate_cache_rope(const QKV_TYPE* qkv,
dim_head,
block_size,
elem_nums,
kv_num_heads,
rope_3d);
kv_num_heads);
}
}
@@ -200,8 +125,7 @@ void append_speculate_cache_int8_rope(const QKV_TYPE* qkv,
const int bsz,
const int token_num,
const cudaStream_t& stream,
const bool use_neox_style,
const bool rope_3d) {
const bool use_neox_style) {
constexpr int num_warps = 4;
const int all_warps =
((num_heads + 2 * kv_num_heads) + num_warps - 1) / num_warps * num_warps;
@@ -243,8 +167,7 @@ void append_speculate_cache_int8_rope(const QKV_TYPE* qkv,
block_size,
127.0f,
-127.0f,
kv_num_heads,
rope_3d);
kv_num_heads);
} else {
append_speculate_cache_int8_rope_kernel<T, 4, 0, 128, QKV_TYPE, IsFP8>
<<<grids, num_warps * 32, 0, stream>>>(qkv,
@@ -268,8 +191,7 @@ void append_speculate_cache_int8_rope(const QKV_TYPE* qkv,
block_size,
127.0f,
-127.0f,
kv_num_heads,
rope_3d);
kv_num_heads);
}
}
@@ -300,8 +222,7 @@ void append_speculate_cache_int4_rope(const QKV_TYPE* qkv,
const int bsz,
const int token_num,
const cudaStream_t& stream,
const bool use_neox_style,
const bool rope_3d) {
const bool use_neox_style) {
constexpr int num_warps = 4;
const int all_warps =
((num_heads + 2 * kv_num_heads) + num_warps - 1) / num_warps * num_warps;
@@ -345,8 +266,7 @@ void append_speculate_cache_int4_rope(const QKV_TYPE* qkv,
block_size,
7.0f,
-8.0f,
kv_num_heads,
rope_3d);
kv_num_heads);
} else {
append_speculate_cache_int4_rope_kernel<T, 4>
<<<grids, num_warps * 32, 0, stream>>>(qkv,
@@ -372,8 +292,7 @@ void append_speculate_cache_int4_rope(const QKV_TYPE* qkv,
block_size,
7.0f,
-8.0f,
kv_num_heads,
rope_3d);
kv_num_heads);
}
}
template <typename T, typename QKV_TYPE>
@@ -394,15 +313,11 @@ void SpeculateWriteCacheWithRoPEKernel(
const paddle::optional<paddle::Tensor>& cache_v_zp,
const std::string& cache_quant_type_str,
const bool use_neox_rotary_style,
const bool rope_3d,
const int max_seq_len,
cudaStream_t& stream,
paddle::Tensor* qkv_out,
paddle::Tensor* key_cache_out,
paddle::Tensor* value_cache_out,
const paddle::optional<paddle::Tensor>& q_norm_weight,
const paddle::optional<paddle::Tensor>& k_norm_weight,
const float rms_norm_eps) {
paddle::Tensor* value_cache_out) {
typedef cascade_attn_type_traits<T> traits_;
typedef cascade_attn_type_traits<QKV_TYPE> qkt_nv_type_;
typedef typename traits_::type DataType_;
@@ -427,185 +342,142 @@ void SpeculateWriteCacheWithRoPEKernel(
? rotary_embs.get().data<float>() + max_seq_len * dim_head
: rotary_embs.get().data<float>() + max_seq_len * dim_head / 2;
}
if (q_norm_weight && k_norm_weight) {
if (cache_quant_type_str == "none") {
append_speculate_cache_rope_qk_norm(
reinterpret_cast<const QKV_TYPE*>(qkv_ptr),
reinterpret_cast<DataType_*>(key_cache_out->data<T>()),
reinterpret_cast<DataType_*>(value_cache_out->data<T>()),
reinterpret_cast<DataType_*>(qkv_out->data<T>()),
block_tables.data<int>(),
batch_id_per_token.data<int>(),
cu_seqlens_q.data<int>(),
seq_lens.data<int>(),
seq_lens_encoder.data<int>(),
cos_emb,
sin_emb,
qkv_out_scales ? qkv_out_scales.get().data<float>() : nullptr,
qkv_biases ? reinterpret_cast<DataType_*>(
const_cast<T*>(qkv_biases.get().data<T>()))
: nullptr,
max_seq_len,
max_blocks_per_seq,
num_heads,
kv_num_heads,
dim_head,
block_size,
bsz,
token_nums,
stream,
use_neox_rotary_style,
reinterpret_cast<const float*>(q_norm_weight.get().data<float>()),
reinterpret_cast<const float*>(k_norm_weight.get().data<float>()),
rms_norm_eps,
rope_3d);
} else {
PD_THROW(
"append_decode_cache_rope_qk_norm not support cachekv quant yet");
}
if (cache_quant_type_str == "none") {
append_speculate_cache_rope(
reinterpret_cast<const QKV_TYPE*>(qkv_ptr),
reinterpret_cast<DataType_*>(key_cache_out->data<T>()),
reinterpret_cast<DataType_*>(value_cache_out->data<T>()),
reinterpret_cast<DataType_*>(qkv_out->data<T>()),
block_tables.data<int>(),
batch_id_per_token.data<int>(),
cu_seqlens_q.data<int>(),
seq_lens.data<int>(),
seq_lens_encoder.data<int>(),
cos_emb,
sin_emb,
qkv_out_scales ? qkv_out_scales.get().data<float>() : nullptr,
qkv_biases ? reinterpret_cast<DataType_*>(
const_cast<T*>(qkv_biases.get().data<T>()))
: nullptr,
max_seq_len,
max_blocks_per_seq,
num_heads,
kv_num_heads,
dim_head,
block_size,
bsz,
token_nums,
stream,
use_neox_rotary_style);
} else if (cache_quant_type_str == "cache_int8") {
append_speculate_cache_int8_rope(
reinterpret_cast<const QKV_TYPE*>(qkv_ptr),
key_cache_out->data<uint8_t>(),
value_cache_out->data<uint8_t>(),
reinterpret_cast<DataType_*>(qkv_out->data<T>()),
block_tables.data<int>(),
batch_id_per_token.data<int>(),
cu_seqlens_q.data<int>(),
seq_lens.data<int>(),
seq_lens_encoder.data<int>(),
cos_emb,
sin_emb,
qkv_out_scales ? qkv_out_scales.get().data<float>() : nullptr,
qkv_biases ? reinterpret_cast<DataType_*>(
const_cast<T*>(qkv_biases.get().data<T>()))
: nullptr,
cache_k_scale ? reinterpret_cast<DataType_*>(
const_cast<T*>(cache_k_scale.get().data<T>()))
: nullptr,
cache_v_scale ? reinterpret_cast<DataType_*>(
const_cast<T*>(cache_v_scale.get().data<T>()))
: nullptr,
max_seq_len,
max_blocks_per_seq,
num_heads,
kv_num_heads,
dim_head,
block_size,
bsz,
token_nums,
stream,
use_neox_rotary_style);
} else if (cache_quant_type_str == "cache_fp8") {
append_speculate_cache_int8_rope<DataType_, QKV_TYPE, true>(
reinterpret_cast<const QKV_TYPE*>(qkv_ptr),
key_cache_out->data<uint8_t>(),
value_cache_out->data<uint8_t>(),
reinterpret_cast<DataType_*>(qkv_out->data<T>()),
block_tables.data<int>(),
batch_id_per_token.data<int>(),
cu_seqlens_q.data<int>(),
seq_lens.data<int>(),
seq_lens_encoder.data<int>(),
cos_emb,
sin_emb,
qkv_out_scales ? qkv_out_scales.get().data<float>() : nullptr,
qkv_biases ? reinterpret_cast<DataType_*>(
const_cast<T*>(qkv_biases.get().data<T>()))
: nullptr,
cache_k_scale ? reinterpret_cast<DataType_*>(
const_cast<T*>(cache_k_scale.get().data<T>()))
: nullptr,
cache_v_scale ? reinterpret_cast<DataType_*>(
const_cast<T*>(cache_v_scale.get().data<T>()))
: nullptr,
max_seq_len,
max_blocks_per_seq,
num_heads,
kv_num_heads,
dim_head,
block_size,
bsz,
token_nums,
stream,
use_neox_rotary_style);
} else if (cache_quant_type_str == "cache_int4_zp") {
append_speculate_cache_int4_rope(
reinterpret_cast<const QKV_TYPE*>(qkv_ptr),
key_cache_out->data<uint8_t>(),
value_cache_out->data<uint8_t>(),
reinterpret_cast<DataType_*>(const_cast<T*>(qkv_out->data<T>())),
block_tables.data<int>(),
batch_id_per_token.data<int>(),
cu_seqlens_q.data<int>(),
seq_lens.data<int>(),
seq_lens_encoder.data<int>(),
cos_emb,
sin_emb,
qkv_out_scales ? qkv_out_scales.get().data<float>() : nullptr,
qkv_biases ? reinterpret_cast<DataType_*>(
const_cast<T*>(qkv_biases.get().data<T>()))
: nullptr,
cache_k_scale ? reinterpret_cast<DataType_*>(
const_cast<T*>(cache_k_scale.get().data<T>()))
: nullptr,
cache_v_scale ? reinterpret_cast<DataType_*>(
const_cast<T*>(cache_v_scale.get().data<T>()))
: nullptr,
cache_k_zp ? reinterpret_cast<DataType_*>(
const_cast<T*>(cache_k_zp.get().data<T>()))
: nullptr,
cache_v_zp ? reinterpret_cast<DataType_*>(
const_cast<T*>(cache_v_zp.get().data<T>()))
: nullptr,
max_seq_len,
max_blocks_per_seq,
num_heads,
kv_num_heads,
dim_head,
block_size,
bsz,
token_nums,
stream,
use_neox_rotary_style);
} else {
if (cache_quant_type_str == "none") {
append_speculate_cache_rope(
reinterpret_cast<const QKV_TYPE*>(qkv_ptr),
reinterpret_cast<DataType_*>(key_cache_out->data<T>()),
reinterpret_cast<DataType_*>(value_cache_out->data<T>()),
reinterpret_cast<DataType_*>(qkv_out->data<T>()),
block_tables.data<int>(),
batch_id_per_token.data<int>(),
cu_seqlens_q.data<int>(),
seq_lens.data<int>(),
seq_lens_encoder.data<int>(),
cos_emb,
sin_emb,
qkv_out_scales ? qkv_out_scales.get().data<float>() : nullptr,
qkv_biases ? reinterpret_cast<DataType_*>(
const_cast<T*>(qkv_biases.get().data<T>()))
: nullptr,
max_seq_len,
max_blocks_per_seq,
num_heads,
kv_num_heads,
dim_head,
block_size,
bsz,
token_nums,
stream,
use_neox_rotary_style,
rope_3d);
} else if (cache_quant_type_str == "cache_int8") {
append_speculate_cache_int8_rope(
reinterpret_cast<const QKV_TYPE*>(qkv_ptr),
key_cache_out->data<uint8_t>(),
value_cache_out->data<uint8_t>(),
reinterpret_cast<DataType_*>(qkv_out->data<T>()),
block_tables.data<int>(),
batch_id_per_token.data<int>(),
cu_seqlens_q.data<int>(),
seq_lens.data<int>(),
seq_lens_encoder.data<int>(),
cos_emb,
sin_emb,
qkv_out_scales ? qkv_out_scales.get().data<float>() : nullptr,
qkv_biases ? reinterpret_cast<DataType_*>(
const_cast<T*>(qkv_biases.get().data<T>()))
: nullptr,
cache_k_scale ? reinterpret_cast<DataType_*>(
const_cast<T*>(cache_k_scale.get().data<T>()))
: nullptr,
cache_v_scale ? reinterpret_cast<DataType_*>(
const_cast<T*>(cache_v_scale.get().data<T>()))
: nullptr,
max_seq_len,
max_blocks_per_seq,
num_heads,
kv_num_heads,
dim_head,
block_size,
bsz,
token_nums,
stream,
use_neox_rotary_style,
rope_3d);
} else if (cache_quant_type_str == "cache_fp8") {
append_speculate_cache_int8_rope<DataType_, QKV_TYPE, true>(
reinterpret_cast<const QKV_TYPE*>(qkv_ptr),
key_cache_out->data<uint8_t>(),
value_cache_out->data<uint8_t>(),
reinterpret_cast<DataType_*>(qkv_out->data<T>()),
block_tables.data<int>(),
batch_id_per_token.data<int>(),
cu_seqlens_q.data<int>(),
seq_lens.data<int>(),
seq_lens_encoder.data<int>(),
cos_emb,
sin_emb,
qkv_out_scales ? qkv_out_scales.get().data<float>() : nullptr,
qkv_biases ? reinterpret_cast<DataType_*>(
const_cast<T*>(qkv_biases.get().data<T>()))
: nullptr,
cache_k_scale ? reinterpret_cast<DataType_*>(
const_cast<T*>(cache_k_scale.get().data<T>()))
: nullptr,
cache_v_scale ? reinterpret_cast<DataType_*>(
const_cast<T*>(cache_v_scale.get().data<T>()))
: nullptr,
max_seq_len,
max_blocks_per_seq,
num_heads,
kv_num_heads,
dim_head,
block_size,
bsz,
token_nums,
stream,
use_neox_rotary_style,
rope_3d);
} else if (cache_quant_type_str == "cache_int4_zp") {
append_speculate_cache_int4_rope(
reinterpret_cast<const QKV_TYPE*>(qkv_ptr),
key_cache_out->data<uint8_t>(),
value_cache_out->data<uint8_t>(),
reinterpret_cast<DataType_*>(const_cast<T*>(qkv_out->data<T>())),
block_tables.data<int>(),
batch_id_per_token.data<int>(),
cu_seqlens_q.data<int>(),
seq_lens.data<int>(),
seq_lens_encoder.data<int>(),
cos_emb,
sin_emb,
qkv_out_scales ? qkv_out_scales.get().data<float>() : nullptr,
qkv_biases ? reinterpret_cast<DataType_*>(
const_cast<T*>(qkv_biases.get().data<T>()))
: nullptr,
cache_k_scale ? reinterpret_cast<DataType_*>(
const_cast<T*>(cache_k_scale.get().data<T>()))
: nullptr,
cache_v_scale ? reinterpret_cast<DataType_*>(
const_cast<T*>(cache_v_scale.get().data<T>()))
: nullptr,
cache_k_zp ? reinterpret_cast<DataType_*>(
const_cast<T*>(cache_k_zp.get().data<T>()))
: nullptr,
cache_v_zp ? reinterpret_cast<DataType_*>(
const_cast<T*>(cache_v_zp.get().data<T>()))
: nullptr,
max_seq_len,
max_blocks_per_seq,
num_heads,
kv_num_heads,
dim_head,
block_size,
bsz,
token_nums,
stream,
use_neox_rotary_style,
rope_3d);
} else {
PD_THROW(
"cache_quant_type_str should be one of [none, cache_int8, "
"cache_int4_zp]");
}
PD_THROW(
"cache_quant_type_str should be one of [none, cache_int8, "
"cache_int4_zp]");
}
}
@@ -628,15 +500,11 @@ template void SpeculateWriteCacheWithRoPEKernel<paddle::bfloat16, int>(
const paddle::optional<paddle::Tensor>& cache_v_zp,
const std::string& cache_quant_type_str,
const bool use_neox_rotary_style,
const bool rope_3d,
const int max_seq_len,
cudaStream_t& stream,
paddle::Tensor* qkv_out,
paddle::Tensor* key_cache_out,
paddle::Tensor* value_cache_out,
const paddle::optional<paddle::Tensor>& q_norm_weight,
const paddle::optional<paddle::Tensor>& k_norm_weight,
const float rms_norm_eps);
paddle::Tensor* value_cache_out);
template void
SpeculateWriteCacheWithRoPEKernel<paddle::bfloat16, paddle::bfloat16>(
@@ -658,15 +526,11 @@ SpeculateWriteCacheWithRoPEKernel<paddle::bfloat16, paddle::bfloat16>(
const paddle::optional<paddle::Tensor>& cache_v_zp,
const std::string& cache_quant_type_str,
const bool use_neox_rotary_style,
const bool rope_3d,
const int max_seq_len,
cudaStream_t& stream,
paddle::Tensor* qkv_out,
paddle::Tensor* key_cache_out,
paddle::Tensor* value_cache_out,
const paddle::optional<paddle::Tensor>& q_norm_weight,
const paddle::optional<paddle::Tensor>& k_norm_weight,
const float rms_norm_eps);
paddle::Tensor* value_cache_out);
template void SpeculateWriteCacheWithRoPEKernel<paddle::float16, int>(
const AppendAttnMetaData& meta_data,
@@ -687,15 +551,11 @@ template void SpeculateWriteCacheWithRoPEKernel<paddle::float16, int>(
const paddle::optional<paddle::Tensor>& cache_v_zp,
const std::string& cache_quant_type_str,
const bool use_neox_rotary_style,
const bool rope_3d,
const int max_seq_len,
cudaStream_t& stream,
paddle::Tensor* qkv_out,
paddle::Tensor* key_cache_out,
paddle::Tensor* value_cache_out,
const paddle::optional<paddle::Tensor>& q_norm_weight,
const paddle::optional<paddle::Tensor>& k_norm_weight,
const float rms_norm_eps);
paddle::Tensor* value_cache_out);
template void
@@ -718,12 +578,8 @@ SpeculateWriteCacheWithRoPEKernel<paddle::float16, paddle::float16>(
const paddle::optional<paddle::Tensor>& cache_v_zp,
const std::string& cache_quant_type_str,
const bool use_neox_rotary_style,
const bool rope_3d,
const int max_seq_len,
cudaStream_t& stream,
paddle::Tensor* qkv_out,
paddle::Tensor* key_cache_out,
paddle::Tensor* value_cache_out,
const paddle::optional<paddle::Tensor>& q_norm_weight,
const paddle::optional<paddle::Tensor>& k_norm_weight,
const float rms_norm_eps);
paddle::Tensor* value_cache_out);

View File

@@ -35,12 +35,8 @@ void SpeculateWriteCacheWithRoPEKernel(
const paddle::optional<paddle::Tensor>& cache_v_zp,
const std::string& cache_quant_type_str,
const bool use_neox_rotary_style,
const bool rope_3d,
const int max_seq_len,
cudaStream_t& stream,
paddle::Tensor* qkv_out,
paddle::Tensor* key_cache_out,
paddle::Tensor* value_cache_out,
const paddle::optional<paddle::Tensor>& q_norm_weight,
const paddle::optional<paddle::Tensor>& k_norm_weight,
const float rms_norm_eps);
paddle::Tensor* value_cache_out);

View File

@@ -56,7 +56,6 @@ CascadeAppendAttentionC8Kernel<paddle::bfloat16, paddle::bfloat16, false>(
const bool causal,
const bool is_decoder,
const bool enable_prefill,
const std::string& cache_quant_type_str,
cudaStream_t& stream,
paddle::Tensor* out);
@@ -104,6 +103,5 @@ CascadeAppendAttentionC8Kernel<paddle::bfloat16, paddle::bfloat16, true>(
const bool causal,
const bool is_decoder,
const bool enable_prefill,
const std::string& cache_quant_type_str,
cudaStream_t& stream,
paddle::Tensor* out);

View File

@@ -54,7 +54,6 @@ template void CascadeAppendAttentionC8Kernel<paddle::bfloat16, paddle::float8_e4
const bool causal,
const bool is_decoder,
const bool enable_prefill,
const std::string& cache_quant_type_str,
cudaStream_t& stream,
paddle::Tensor* out);
@@ -99,6 +98,5 @@ template void CascadeAppendAttentionC8Kernel<paddle::bfloat16, paddle::float8_e4
const bool causal,
const bool is_decoder,
const bool enable_prefill,
const std::string& cache_quant_type_str,
cudaStream_t& stream,
paddle::Tensor* out);

View File

@@ -54,7 +54,6 @@ template void CascadeAppendAttentionC8Kernel<paddle::bfloat16, int8_t, false>(
const bool causal,
const bool is_decoder,
const bool enable_prefill,
const std::string& cache_quant_type_str,
cudaStream_t& stream,
paddle::Tensor* out);
@@ -101,6 +100,5 @@ template void CascadeAppendAttentionC8Kernel<paddle::bfloat16, int8_t, true>(
const bool causal,
const bool is_decoder,
const bool enable_prefill,
const std::string& cache_quant_type_str,
cudaStream_t& stream,
paddle::Tensor* out);

View File

@@ -54,7 +54,6 @@ template void CascadeAppendAttentionC8Kernel<paddle::float16, paddle::float16, f
const bool causal,
const bool is_decoder,
const bool enable_prefill,
const std::string& cache_quant_type_str,
cudaStream_t& stream,
paddle::Tensor* out);
@@ -101,6 +100,5 @@ template void CascadeAppendAttentionC8Kernel<paddle::float16, paddle::float16, t
const bool causal,
const bool is_decoder,
const bool enable_prefill,
const std::string& cache_quant_type_str,
cudaStream_t& stream,
paddle::Tensor* out);

View File

@@ -54,7 +54,6 @@ template void CascadeAppendAttentionC8Kernel<paddle::float16, paddle::float8_e4m
const bool causal,
const bool is_decoder,
const bool enable_prefill,
const std::string& cache_quant_type_str,
cudaStream_t& stream,
paddle::Tensor* out);
@@ -100,6 +99,5 @@ template void CascadeAppendAttentionC8Kernel<paddle::float16, paddle::float8_e4m
const bool causal,
const bool is_decoder,
const bool enable_prefill,
const std::string& cache_quant_type_str,
cudaStream_t& stream,
paddle::Tensor* out);

View File

@@ -54,7 +54,6 @@ template void CascadeAppendAttentionC8Kernel<paddle::float16, int8_t, false>(
const bool causal,
const bool is_decoder,
const bool enable_prefill,
const std::string& cache_quant_type_str,
cudaStream_t& stream,
paddle::Tensor* out);
@@ -100,6 +99,5 @@ template void CascadeAppendAttentionC8Kernel<paddle::float16, int8_t, true>(
const bool causal,
const bool is_decoder,
const bool enable_prefill,
const std::string& cache_quant_type_str,
cudaStream_t& stream,
paddle::Tensor* out);

View File

@@ -43,7 +43,4 @@ EncoderWriteCacheWithRopeKernel<paddle::bfloat16, paddle::bfloat16>(
cudaStream_t& stream,
paddle::Tensor* qkv_out,
paddle::Tensor* key_cache_out,
paddle::Tensor* value_cache_out,
const paddle::optional<paddle::Tensor>& q_norm_weight,
const paddle::optional<paddle::Tensor>& k_norm_weight,
const float rms_norm_eps);
paddle::Tensor* value_cache_out);

View File

@@ -42,7 +42,4 @@ template void EncoderWriteCacheWithRopeKernel<paddle::bfloat16, int>(
cudaStream_t& stream,
paddle::Tensor* qkv_out,
paddle::Tensor* key_cache_out,
paddle::Tensor* value_cache_out,
const paddle::optional<paddle::Tensor>& q_norm_weight,
const paddle::optional<paddle::Tensor>& k_norm_weight,
const float rms_norm_eps);
paddle::Tensor* value_cache_out);

View File

@@ -42,7 +42,4 @@ template void EncoderWriteCacheWithRopeKernel<paddle::float16, paddle::float16>(
cudaStream_t& stream,
paddle::Tensor* qkv_out,
paddle::Tensor* key_cache_out,
paddle::Tensor* value_cache_out,
const paddle::optional<paddle::Tensor>& q_norm_weight,
const paddle::optional<paddle::Tensor>& k_norm_weight,
const float rms_norm_eps);
paddle::Tensor* value_cache_out);

View File

@@ -42,7 +42,4 @@ template void EncoderWriteCacheWithRopeKernel<paddle::float16, int>(
cudaStream_t& stream,
paddle::Tensor* qkv_out,
paddle::Tensor* key_cache_out,
paddle::Tensor* value_cache_out,
const paddle::optional<paddle::Tensor>& q_norm_weight,
const paddle::optional<paddle::Tensor>& k_norm_weight,
const float rms_norm_eps);
paddle::Tensor* value_cache_out);

View File

@@ -27,7 +27,6 @@ struct AppendAttnMetaData {
int head_dims;
int head_dims_v;
int max_blocks_per_seq;
const int *mask_offset = nullptr;
};
__forceinline__ __host__ __device__ int div_up(int a, int b) {
@@ -431,9 +430,6 @@ __forceinline__ __host__ __device__ void vec_cast<nv_bfloat16, float>(
} else if (group_size == 12) { \
constexpr size_t GROUP_SIZE = 12; \
__VA_ARGS__ \
} else if (group_size == 14) { \
constexpr size_t GROUP_SIZE = 14; \
__VA_ARGS__ \
} else if (group_size == 16) { \
constexpr size_t GROUP_SIZE = 16; \
__VA_ARGS__ \
@@ -441,15 +437,6 @@ __forceinline__ __host__ __device__ void vec_cast<nv_bfloat16, float>(
PD_THROW("not support the group_size", group_size); \
}
#define DISPATCH_DyCfp8(is_dynamic_cfp8, IsDynamicC8, ...) \
if (is_dynamic_cfp8) { \
constexpr bool IsDynamicC8 = true; \
__VA_ARGS__ \
} else { \
constexpr bool IsDynamicC8 = false; \
__VA_ARGS__ \
}
#define DISPATCH_MLA_GROUP_SIZE(group_size, GROUP_SIZE, ...) \
if (group_size == 8) { \
constexpr size_t GROUP_SIZE = 8; \
@@ -487,9 +474,6 @@ __forceinline__ __host__ __device__ void vec_cast<nv_bfloat16, float>(
if (causal) { \
constexpr bool CAUSAL = true; \
__VA_ARGS__ \
} else { \
constexpr bool CAUSAL = false; \
__VA_ARGS__ \
}
#define DISPATCH_ENABLE_PREFILL(enable_prefill, ENABLE_PREFILL, ...) \
@@ -575,37 +559,3 @@ template <typename T, bool IsFP8>inline __device__ static void convert_c8(T * re
convert_int8(result, source);
}
}
constexpr int kWarpSize = 32;
template<typename T>
inline __device__ void WelfordCombine1(T b_m2, T* m2) {
*m2 += b_m2;
}
template<typename T, int thread_group_width = kWarpSize>
__inline__ __device__ void WelfordWarpReduce(T thread_m2, T* m2) {
*m2 = thread_m2;
for (int mask = thread_group_width / 2; mask > 0; mask >>= 1) {
T b_m2 = __shfl_xor_sync(0xffffffff, *m2, mask);
WelfordCombine1(b_m2, m2);
}
}
template<typename T, int thread_group_width = kWarpSize>
__inline__ __device__ void WelfordWarpAllReduce(T thread_m2, T* m2) {
WelfordWarpReduce<T, thread_group_width>(thread_m2, m2);
}
template <typename T>
__inline__ __device__ T Rsqrt(T x);
template <>
__inline__ __device__ float Rsqrt<float>(float x) {
return rsqrt(x);
}
template <>
__inline__ __device__ double Rsqrt<double>(double x) {
return rsqrt(x);
}

View File

@@ -63,7 +63,7 @@ std::vector<paddle::Tensor> AppendAttention(
const paddle::Tensor &kv_num_blocks,
const paddle::Tensor &decoder_batch_ids,
const paddle::Tensor &decoder_tile_ids_per_batch,
const paddle::Tensor &decoder_num_blocks_cpu,
const paddle::Tensor &decoder_num_blocks,
const paddle::Tensor &set_max_lengths, const paddle::Tensor &max_len_kv,
const paddle::optional<paddle::Tensor> &rotary_embs,
const paddle::optional<paddle::Tensor> &attn_mask,
@@ -77,54 +77,7 @@ std::vector<paddle::Tensor> AppendAttention(
const paddle::optional<paddle::Tensor> &cache_v_zp,
const paddle::optional<paddle::Tensor> &out_linear_shifts,
const paddle::optional<paddle::Tensor> &out_linear_smooths,
const paddle::optional<paddle::Tensor> &mask_offset,
const paddle::optional<paddle::Tensor> &kv_signal_data,
const paddle::optional<paddle::Tensor>& q_norm_weight,
const paddle::optional<paddle::Tensor>& k_norm_weight,
const float rms_norm_eps,
const std::string &compute_dtype, const std::string &cache_quant_type_str,
const bool use_neox_rotary_style, const bool rope_3d,
const int max_input_length, const float quant_max_bound,
const float quant_min_bound, const float out_linear_in_scale,
const int encoder_block_shape_q, const int decoder_block_shape_q,
const int max_partition_size, const int encoder_max_partition_size,
const int speculate_max_draft_token_num, const bool causal,
const bool speculate_decoder);
void AppendAttentionWithOutput(
const paddle::Tensor &qkv, const paddle::Tensor &key_cache,
const paddle::Tensor &value_cache, const paddle::Tensor &seq_lens_encoder,
const paddle::Tensor &seq_lens_decoder,
const paddle::Tensor &seq_lens_this_time,
const paddle::Tensor &batch_id_per_token, const paddle::Tensor &cu_seqlens_q,
const paddle::Tensor &block_tables, const paddle::Tensor &encoder_batch_ids,
const paddle::Tensor &encoder_tile_ids_per_batch,
const paddle::Tensor &encoder_num_blocks,
const paddle::Tensor &kv_batch_ids,
const paddle::Tensor &kv_tile_ids_per_batch,
const paddle::Tensor &kv_num_blocks,
const paddle::Tensor &decoder_batch_ids,
const paddle::Tensor &decoder_tile_ids_per_batch,
const paddle::Tensor &decoder_num_blocks_cpu,
const paddle::Tensor &set_max_lengths, const paddle::Tensor &max_len_kv,
paddle::Tensor &fmha_out,
const paddle::optional<paddle::Tensor> &rotary_embs,
const paddle::optional<paddle::Tensor> &attn_mask,
const paddle::optional<paddle::Tensor> &qkv_bias,
const paddle::optional<paddle::Tensor> &qkv_out_scales,
const paddle::optional<paddle::Tensor> &cache_k_quant_scales,
const paddle::optional<paddle::Tensor> &cache_v_quant_scales,
const paddle::optional<paddle::Tensor> &cache_k_dequant_scales,
const paddle::optional<paddle::Tensor> &cache_v_dequant_scales,
const paddle::optional<paddle::Tensor> &cache_k_zp,
const paddle::optional<paddle::Tensor> &cache_v_zp,
const paddle::optional<paddle::Tensor> &out_linear_shifts,
const paddle::optional<paddle::Tensor> &out_linear_smooths,
const paddle::optional<paddle::Tensor> &mask_offset,
const paddle::optional<paddle::Tensor> &kv_signal_data,
const paddle::optional<paddle::Tensor>& q_norm_weight,
const paddle::optional<paddle::Tensor>& k_norm_weight,
const float rms_norm_eps,
const std::string &compute_dtype, const std::string &cache_quant_type_str,
const bool use_neox_rotary_style, const bool rope_3d,
const int max_input_length, const float quant_max_bound,
@@ -154,8 +107,7 @@ std::vector<paddle::Tensor> GQARopeWriteCacheKernel(
const paddle::optional<paddle::Tensor> &cache_v_zp,
const paddle::optional<paddle::Tensor> &kv_signal_data,
const int kv_token_num, const int max_seq_len,
const std::string &cache_quant_type,
const bool rope_3d);
const std::string &cache_quant_type);
std::vector<paddle::Tensor>
PreCacheLenConcat(const paddle::Tensor &seq_lens_decoder,
@@ -172,29 +124,11 @@ paddle::Tensor FusedExpertMoeFunc(
const std::string &quant_method, const int moe_topk,
const bool norm_topk_prob, const bool group_moe);
std::vector<paddle::Tensor> MacheteMMKernel(
paddle::Tensor const& A, paddle::Tensor const& B,
paddle::optional<paddle::Tensor> const& maybe_group_scales,
paddle::optional<paddle::Tensor> const& maybe_group_zeros,
paddle::optional<paddle::Tensor> const& maybe_channel_scales,
paddle::optional<paddle::Tensor> const& maybe_token_scales,
std::string const& b_type_str,
std::string const& maybe_out_type_str,
int64_t const& maybe_group_size,
std::string const& maybe_schedule);
std::vector<paddle::Tensor> MachetePrepackBKernel(
paddle::Tensor const& B, std::string const& a_type_str, std::string const& b_type_str,
std::string const& maybe_group_scales_type_str);
std::vector<std::string> MacheteSupportedSchedules(
std::string const& a_type_str, std::string const& b_type_str);
std::vector<paddle::Tensor> MoeExpertDispatch(
const paddle::Tensor &input, const paddle::Tensor &gating_output,
const paddle::optional<paddle::Tensor> &gating_correction_bias,
const paddle::optional<paddle::Tensor> &w4a8_in_scale, const int moe_topk,
const bool group_moe, const std::string &moe_quant_type, const bool topk_only_mode);
const bool group_moe, const bool topk_only_mode);
std::vector<paddle::Tensor>
MoETopKSelectKernel(const paddle::Tensor &gating_logits,
@@ -254,9 +188,7 @@ paddle::Tensor MoeExpertFFNFunc(
const paddle::optional<paddle::Tensor>& down_proj_scale,
const paddle::optional<paddle::Tensor>& down_proj_in_scale,
const paddle::optional<paddle::Tensor>& expert_idx_per_token,
const std::string& quant_method, const bool used_in_ep_low_latency,
const int estimate_total_token_nums,
const int hadamard_block_size);
const std::string& quant_method, const bool used_in_ep_low_latency);
paddle::Tensor MoeExpertFFNWint2Func(
const paddle::Tensor& permute_input,
@@ -299,27 +231,12 @@ paddle::Tensor OpenShmAndGetMetaSignalFunc(const int rank, const int device_id,
paddle::Tensor InitSignalLayerwiseFunc(const paddle::Tensor &kv_signal_metadata,
const int layer_id);
void GetBlockShapeAndSplitKVBlock(
std::vector<paddle::Tensor> GetBlockShapeAndSplitKVBlock(
const paddle::Tensor &seq_lens_encoder,
const paddle::Tensor &seq_lens_decoder,
const paddle::Tensor &seq_lens_this_time,
paddle::Tensor &decoder_batch_ids, // Inplace
paddle::Tensor &decoder_tile_ids_per_batch, // Inplace
paddle::Tensor &decoder_num_blocks_cpu, // Inplace, Pinned Memory
paddle::Tensor &decoder_num_blocks_device, // Inplace
paddle::Tensor &decoder_chunk_size_device, // Inplace
paddle::Tensor &max_len_tensor_cpu, // Inplace, Pinned Memory
paddle::Tensor &encoder_batch_ids, // Inplace
paddle::Tensor &encoder_tile_ids_per_batch, // Inplace
paddle::Tensor &encoder_num_blocks_x_cpu, // Inplace, Pinned Memory
paddle::Tensor &kv_batch_ids, // Inplace
paddle::Tensor &kv_tile_ids_per_batch, // Inplace
paddle::Tensor &kv_num_blocks_x_cpu, // Inplace, Pinned Memory
paddle::Tensor &max_len_kv_cpu, // Inplace, Pinned Memory
const int encoder_block_shape_q,
const int decoder_block_shape_q,
const int group_size,
const int block_size,
const int encoder_block_shape_q, const int decoder_block_shape_q,
const int group_size, const int block_size,
const int decoder_step_token_num);
std::vector<paddle::Tensor> GetPaddingOffset(const paddle::Tensor &input_ids,
@@ -349,12 +266,13 @@ void GetStopFlagsMulti(const paddle::Tensor &topk_ids,
const paddle::Tensor &seq_lens,
const paddle::Tensor &end_ids,
const paddle::Tensor &next_tokens,
const paddle::Tensor &pre_ids,
const paddle::Tensor &step_idx,
const paddle::Tensor &stop_seqs,
const paddle::Tensor &stop_seqs_len,
const bool beam_search);
void GetStopFlagsMultiSeqs(
const paddle::Tensor &topk_ids, const paddle::Tensor &pre_ids,
const paddle::Tensor &step_idx, const paddle::Tensor &stop_flags,
const paddle::Tensor &seq_lens, const paddle::Tensor &stop_seqs,
const paddle::Tensor &stop_seqs_len, const paddle::Tensor &end_ids);
void UpdateInputes(const paddle::Tensor &stop_flags,
const paddle::Tensor &not_need_stop, // only on cpu
@@ -388,11 +306,9 @@ void RecoverDecodeTask(const paddle::Tensor &stop_flags,
const paddle::Tensor &step_seq_lens_decoder,
const paddle::Tensor &block_tables,
const paddle::Tensor &is_block_step,
const paddle::optional<paddle::Tensor> &draft_tokens,
const paddle::optional<paddle::Tensor> &step_draft_tokens,
const paddle::optional<paddle::Tensor> &step_seq_lens_this_time,
const int block_size,
const int max_draft_tokens);
const int block_size);
paddle::Tensor
GroupSwigluWithMasked(const paddle::Tensor &fc1_out_tensor,
@@ -402,7 +318,7 @@ std::vector<paddle::Tensor> ExtractTextTokenOutput(
const paddle::Tensor &max_seq_len, const paddle::Tensor &max_seq_len_index,
const paddle::Tensor &mm_token_num_len,
const paddle::Tensor &seq_lens_this_time,
const paddle::Tensor &cu_seqlens_q, const paddle::Tensor &hidden_states);
const paddle::Tensor &cu_seqlens_q, const paddle::Tensor &score_text);
std::vector<paddle::Tensor> MoEDeepGEMMPermute(const paddle::Tensor &x,
const paddle::Tensor &topk_idx,
@@ -416,8 +332,8 @@ std::vector<paddle::Tensor> MoEDeepGEMMDePermute(
const paddle::Tensor &topk_idx, const paddle::Tensor &topk_weights);
void TextImageIndexOut(const paddle::Tensor &token_type_ids,
paddle::Tensor &text_input,
paddle::Tensor &image_input);
const paddle::Tensor &text_input,
const paddle::Tensor &image_input);
void TextImageGatherScatter(paddle::Tensor &input, paddle::Tensor &text_input,
paddle::Tensor &image_input,
@@ -475,18 +391,23 @@ std::vector<paddle::Tensor> MultiHeadLatentAttention(
const paddle::Tensor& query,
const paddle::Tensor& key_cache,
const paddle::Tensor& value_cache,
const paddle::Tensor& seq_lens_encoder,
const paddle::Tensor& seq_lens_decoder,
const paddle::Tensor& seq_lens_this_time,
const paddle::Tensor& cu_seqlens_q,
const paddle::Tensor& batch_id_per_token,
const paddle::Tensor& block_tables,
const paddle::Tensor& encoder_batch_ids,
const paddle::Tensor& encoder_tile_ids_per_batch,
const paddle::Tensor& encoder_num_blocks,
const paddle::Tensor& kv_batch_ids,
const paddle::Tensor& kv_tile_ids_per_batch,
const paddle::Tensor& kv_num_blocks,
const paddle::Tensor& decoder_batch_ids,
const paddle::Tensor& decoder_tile_ids_per_batch,
const paddle::Tensor& decoder_num_blocks_device,
const paddle::Tensor& decoder_chunk_size_device,
const paddle::Tensor& decoder_num_blocks,
const paddle::Tensor& decoder_num_blocks_cpu,
const paddle::Tensor& max_enc_len_this_time,
const paddle::Tensor& max_dec_len_this_time,
const paddle::Tensor& max_len_kv,
const paddle::optional<paddle::Tensor>& attn_mask,
@@ -600,7 +521,7 @@ paddle::Tensor FusedHadamardQuantFp8Func(
int64_t init_custom_all_reduce(const std::vector<int64_t>& fake_ipc_ptrs,
paddle::Tensor& rank_data, int64_t rank, bool full_nvlink);
void all_reduce(paddle::Tensor& inp, paddle::Tensor& out, int64_t _fa,
void all_reduce(int64_t _fa, paddle::Tensor& inp, paddle::Tensor& out,
int64_t reg_buffer, int64_t reg_buffer_sz_bytes);
void dispose(int64_t _fa);
@@ -683,7 +604,7 @@ void SpeculateVerify(
const paddle::Tensor &actual_draft_token_nums, const paddle::Tensor &topp,
int max_seq_len, int verify_window, bool enable_topp, bool benchmark_mode);
void SpeculateUpdate(const paddle::Tensor &seq_lens_encoder,
void SpeculateUpdateV3(const paddle::Tensor &seq_lens_encoder,
const paddle::Tensor &seq_lens_decoder,
const paddle::Tensor &not_need_stop,
const paddle::Tensor &draft_tokens,
@@ -714,22 +635,6 @@ void SpeculateSaveWithOutputMsgStatic(const paddle::Tensor& accept_tokens,
void SpeculateClearAcceptNums(const paddle::Tensor& accept_num,
const paddle::Tensor& seq_lens_decoder);
void SpeculateScheduleCache(const paddle::Tensor &draft_tokens,
const paddle::Tensor &block_tables,
const paddle::Tensor &stop_flags,
const paddle::Tensor &seq_lens_this_time,
const paddle::Tensor &seq_lens_decoder,
const paddle::Tensor &step_seq_lens_decoder,
const paddle::Tensor &step_draft_tokens,
const paddle::Tensor &step_seq_lens_this_time,
const paddle::Tensor &accept_num,
const paddle::Tensor &accept_tokens,
const paddle::Tensor &is_block_step,
const paddle::Tensor &not_need_stop,
const paddle::Tensor &stop_nums,
const int block_size,
const int max_draft_tokens);
void NgramMatch(const paddle::Tensor &input_ids,
const paddle::Tensor &input_ids_len,
const paddle::Tensor &pre_ids,
@@ -744,20 +649,6 @@ void NgramMatch(const paddle::Tensor &input_ids,
const int max_draft_tokens);
void HybridMtpNgram(const paddle::Tensor &input_ids,
const paddle::Tensor &input_ids_len,
const paddle::Tensor &pre_ids,
const paddle::Tensor &step_idx,
const paddle::Tensor &draft_token_num,
const paddle::Tensor &draft_tokens,
const paddle::Tensor &seq_lens_this_time,
const paddle::Tensor &seq_lens_decoder,
const paddle::Tensor &max_dec_len,
const int max_ngram_size,
const int min_ngram_size,
const int max_draft_tokens);
// MTP
void DraftModelPostprocess(const paddle::Tensor& base_model_draft_tokens,
const paddle::Tensor& base_model_seq_lens_this_time,
@@ -773,12 +664,9 @@ void DraftModelPreprocess(const paddle::Tensor& draft_tokens,
const paddle::Tensor& seq_lens_decoder,
const paddle::Tensor& step_idx,
const paddle::Tensor& not_need_stop,
const paddle::Tensor& is_block_step,
const paddle::Tensor& batch_drop,
const paddle::Tensor& pre_ids,
const paddle::Tensor& accept_tokens,
const paddle::Tensor& accept_num,
const paddle::Tensor& base_model_seq_lens_this_time,
const paddle::Tensor& base_model_seq_lens_encoder,
const paddle::Tensor& base_model_seq_lens_decoder,
const paddle::Tensor& base_model_step_idx,
@@ -787,8 +675,7 @@ void DraftModelPreprocess(const paddle::Tensor& draft_tokens,
const paddle::Tensor& base_model_draft_tokens,
const int max_draft_token,
const bool truncate_first_token,
const bool splitwise_prefill,
const bool kvcache_scheduler_v1);
const bool splitwise_prefill);
void DraftModelUpdate(const paddle::Tensor& inter_next_tokens,
@@ -869,33 +756,6 @@ void SpeculateStepPaddle(
const int encoder_decoder_block_num,
const int max_draft_tokens);
void MergePrefillDecodeOutput(
const paddle::Tensor &encoder_res,
const paddle::Tensor &decoder_res,
const paddle::Tensor &seq_lens_encoder,
const paddle::Tensor &seq_lens_decoder,
const paddle::Tensor &seq_lens_this_time,
const paddle::Tensor &cu_seq_q,
const int head_num,
const int head_dim,
const int max_token);
std::vector<paddle::Tensor> TopPSamplingReject(const paddle::Tensor &probs,
const paddle::Tensor &top_p,
const paddle::optional<paddle::Tensor> &top_k,
int64_t seed);
std::vector<paddle::Tensor> TopKRenorm(const paddle::Tensor &probs,
const paddle::Tensor &top_k);
std::vector<paddle::Tensor> MinPSamplingFromProbs(const paddle::Tensor &probs,
const paddle::Tensor &min_p);
void SaveOutMmsgStatic(const paddle::Tensor& x,
const paddle::Tensor& not_need_stop,
int64_t rank_id,
bool save_each_rank);
PYBIND11_MODULE(fastdeploy_ops, m) {
m.def("get_expert_token_num", &GetExpertTokenNum, py::arg("topk_ids"),
@@ -949,7 +809,6 @@ PYBIND11_MODULE(fastdeploy_ops, m) {
* append_attention
*/
m.def("append_attention", &AppendAttention, "append attention function");
m.def("append_attention_with_output", &AppendAttentionWithOutput, "append attention with output function");
/**
* gqa_rope_write_cache.cu
* gqa_rope_write_cache
@@ -981,7 +840,7 @@ PYBIND11_MODULE(fastdeploy_ops, m) {
m.def("moe_expert_dispatch", &MoeExpertDispatch, py::arg("input"),
py::arg("gating_output"), py::arg("gating_correction_bias"),
py::arg("w4a8_in_scale"), py::arg("moe_topk"), py::arg("group_moe"),
py::arg("moe_quant_type"), py::arg("topk_only_mode"), "moe export dispatch function");
py::arg("topk_only_mode"), "moe export dispatch function");
/**
* moe/fused_moe/ep_moe_prefill_func.cu
@@ -1005,33 +864,12 @@ PYBIND11_MODULE(fastdeploy_ops, m) {
m.def("per_token_quant_padding", &PerTokenQuantPadding, py::arg("input"),
py::arg("block_size"),
"per token per block quant and padding transpose scale");
"per token per block quant and padding tranpose scale");
m.def("masked_per_token_quant", &MaskedPerTokenQuant, py::arg("input"),
py::arg("recv_expert_count"), py::arg("block_size"),
"per token per block quant");
#ifdef ENABLE_MACHETE
/*machete/machete_mm.cu
* machete_mm
*/
m.def("machete_mm", &MacheteMMKernel, py::arg("A"), py::arg("B"), py::arg("maybe_group_scale"),
py::arg("maybe_group_zeros"), py::arg("maybe_channel_scales"), py::arg("maybe_token_scales"),
py::arg("b_type_str"), py::arg("maybe_out_type_str"), py::arg("maybe_group_size"),
py::arg("maybe_schedule"),
"machete mm function");
/*machete/machete_prepack_B.cu
* machete_prepack_B
*/
m.def("machete_prepack_B", &MachetePrepackBKernel, "machete prepacked B function");
/*machete/machete_supported_schedules.cu
* machete_supported_schedules
*/
m.def("machete_supported_schedules", &MacheteSupportedSchedules, "machete supported schedules function");
#endif
/**
* moe/fused_moe/moe_topk_select.cu
* moe_topk_select
@@ -1048,7 +886,7 @@ PYBIND11_MODULE(fastdeploy_ops, m) {
m.def("moe_expert_ffn", &MoeExpertFFNFunc, "moe export ffn function");
/**
* moe/fused_moe/moe_expert_ffn_wint2.cu
* moe/fused_moe/moe_ffn_wint2.cu
* moe_expert_ffn_wint2
*/
m.def("moe_expert_ffn_wint2", &MoeExpertFFNWint2Func, "moe export ffn wint2 function");
@@ -1116,6 +954,12 @@ PYBIND11_MODULE(fastdeploy_ops, m) {
m.def("set_stop_value_multi_ends", &GetStopFlagsMulti,
"update_inputs function");
/**
* stop_generation_multi_stop_seqs.cu
* set_stop_value_multi_seqs
*/
m.def("set_stop_value_multi_seqs", &GetStopFlagsMultiSeqs,
"update_inputs function");
/**
* update_inputs.cu
@@ -1245,7 +1089,7 @@ PYBIND11_MODULE(fastdeploy_ops, m) {
m.def("speculate_verify",&SpeculateVerify, "speculate_verify function");
m.def("speculate_update",&SpeculateUpdate, "Speculate Update Kernel");
m.def("speculate_update_v3",&SpeculateUpdateV3, "noaux_tc for Deepseekv3 MoE compute function");
m.def("speculate_set_value_by_flags_and_idx",&SpeculateSetValueByFlagsAndIdx, "speculate_set_value_by_flags_and_idx function");
@@ -1253,12 +1097,8 @@ PYBIND11_MODULE(fastdeploy_ops, m) {
m.def("speculate_clear_accept_nums",&SpeculateClearAcceptNums, "speculate_clear_accept_nums function");
m.def("speculate_schedule_cache",&SpeculateScheduleCache, "SpeculateScheduleCache function");
m.def("ngram_match", &NgramMatch, "ngram_match function");
m.def("hybird_mtp_ngram", &HybridMtpNgram, "ngram_match_mixed function");
m.def("draft_model_postprocess",&DraftModelPostprocess, "draft_model_postprocess function");
m.def("draft_model_preprocess",&DraftModelPreprocess, "draft_model_preprocess function");
@@ -1272,14 +1112,4 @@ PYBIND11_MODULE(fastdeploy_ops, m) {
m.def("mtp_step_paddle",&MTPStepPaddle, "mtp_step_paddle function");
m.def("speculate_step_paddle",&SpeculateStepPaddle, "speculate_step_paddle function");
m.def("merge_prefill_decode_output", &MergePrefillDecodeOutput, "merge_prefill_decode_output function");
m.def("rejection_top_p_sampling", &TopPSamplingReject, "rejection_top_p_sampling function");
m.def("top_k_renorm_probs", &TopKRenorm, "top_k_renorm_probs function");
m.def("min_p_sampling", &MinPSamplingFromProbs, "min_p_sampling function");
m.def("save_output", &SaveOutMmsgStatic, "save_output function");
}

View File

@@ -49,7 +49,7 @@ fptr_t init_custom_all_reduce(const std::vector<fptr_t>& fake_ipc_ptrs,
* Otherwise, _reg_buffer is assumed to be IPC-registered and inp is first
* copied into _reg_buffer.
*/
void all_reduce(paddle::Tensor& inp, paddle::Tensor& out, fptr_t _fa,
void all_reduce(fptr_t _fa, paddle::Tensor& inp, paddle::Tensor& out,
fptr_t _reg_buffer, int64_t reg_buffer_sz_bytes) {
auto fa = reinterpret_cast<paddle::CustomAllreduce*>(_fa);
auto stream = inp.stream();
@@ -163,12 +163,3 @@ fptr_t open_mem_handle(paddle::Tensor& mem_handle) {
void free_shared_buffer(fptr_t buffer) {
CUDACHECK(cudaFree(reinterpret_cast<void*>(buffer)));
}
PD_BUILD_STATIC_OP(all_reduce)
.Inputs({"inp",
"out"})
.Outputs({"new_out"})
.Attrs({"_fa: int64_t", "_reg_buffer: int64_t", "reg_buffer_sz_bytes: int64_t"})
.SetInplaceMap({{"out", "new_out"}})
.SetKernelFn(PD_KERNEL(all_reduce));

View File

@@ -6,8 +6,6 @@
// clang-format off
#include "cutlass_extensions/epilogue/broadcast_load_epilogue_c3x.hpp"
#include "cutlass_extensions/epilogue/broadcast_load_epilogue_array_c3x.hpp"
#include "helper.h"
// clang-format on
/*

View File

@@ -133,18 +133,10 @@ public:
template <typename TypeA, typename Arch>
struct LayoutDetailsB<TypeA, uint2b_t, Arch, typename platform::enable_if<Arch::kMinComputeCapability >= 75>::type>
{
static constexpr int ThreadblockK = 128 * 8 / cutlass::sizeof_bits<TypeA>::value; // 64
private:
static constexpr int ElementsPerCacheLine = 128 * 8 / sizeof_bits<uint2b_t>::value;
static constexpr int ColumnsInterleaved = ElementsPerCacheLine / ThreadblockK; // 8
public:
// using Layout = layout::ColumnMajor;
// static constexpr int ElementsPerAccess = 16; // at least 4-bytes
using Layout = layout::ColumnMajorTileInterleave<ThreadblockK, ColumnsInterleaved>;
static constexpr int ElementsPerAccess = 128 / cutlass::sizeof_bits<uint2b_t>::value; // 64
using Operator = cutlass::arch::OpMultiplyAddDequantizeInterleavedBToA;
static constexpr int ThreadblockK = 128 * 8 / cutlass::sizeof_bits<TypeA>::value;
using Layout = layout::RowMajor;
static constexpr int ElementsPerAccess = 128 / cutlass::sizeof_bits<TypeA>::value;
using Operator = cutlass::arch::OpMultiplyAdd;
};
template <typename TypeA, typename Arch>

View File

@@ -18,12 +18,14 @@
#include "cutlass_extensions/gemm/threadblock/default_dq_mma_multistage.h"
#include "cutlass_extensions/gemm/threadblock/default_dq_mma_pipelined.h"
#include "cutlass_extensions/gemm/threadblock/default_wint2x_mma.h"
#include "cutlass_extensions/gemm/threadblock/default_mma_bf16.h"
namespace cutlass {
namespace gemm {
namespace threadblock {
namespace cutlass
{
namespace gemm
{
namespace threadblock
{
////////////////////////////////////////////////////////////////////////////////
@@ -376,23 +378,38 @@ template <
struct DefaultMma<cutlass::half_t, LayoutA, kAlignmentA, uint2b_t, LayoutB, kAlignmentB, ElementAccumulator,
layout::RowMajor, arch::OpClassTensorOp, ArchTag, ThreadblockShape, WarpShape, InstructionShape, 2, Operator>
{
private:
using Mma = DefaultWint2xMma<half_t, LayoutA, kAlignmentA, uint2b_t, LayoutB, kAlignmentB,
ElementAccumulator, layout::RowMajor, arch::OpClassTensorOp, ArchTag, ThreadblockShape,
WarpShape, InstructionShape, 2, Operator>;
static cutlass::arch::CacheOperation::Kind const CacheOpA =
((sizeof_bits<half_t>::value * kAlignmentA) == 128) ? cutlass::arch::CacheOperation::Global
: cutlass::arch::CacheOperation::Always;
static cutlass::arch::CacheOperation::Kind const CacheOpB =
((sizeof_bits<half_t>::value * kAlignmentB) == 128) ? cutlass::arch::CacheOperation::Global
: cutlass::arch::CacheOperation::Always;
public:
// Define the MmaCore components
using MmaCore = typename Mma::MmaCore;
using MmaCore =
typename cutlass::gemm::threadblock::DefaultMmaCore<ThreadblockShape, WarpShape, InstructionShape, half_t,
LayoutA, half_t, LayoutB, ElementAccumulator, layout::RowMajor, arch::OpClassTensorOp, 3, Operator,
false, CacheOpA, CacheOpB>;
// Define iterators over tiles from the A operand
using IteratorA = typename Mma::IteratorA;
using ThreadMapA = typename MmaCore::IteratorThreadMapA;
using AccessTypeA = cutlass::Array<half_t, kAlignmentA>;
using IteratorA = cutlass::transform::threadblock::PredicatedTileAccessIterator<
cutlass::MatrixShape<ThreadblockShape::kM, ThreadblockShape::kK>, half_t, LayoutA, 1, ThreadMapA,
AccessTypeA>;
// Define iterators over tiles from the B operand
using IteratorB = typename Mma::IteratorB;
using ThreadMapB = typename MmaCore::IteratorThreadMapB;
using AccessTypeB = cutlass::Array<half_t, kAlignmentB>;
using IteratorB = cutlass::transform::threadblock::PredicatedTileAccessIterator<
cutlass::MatrixShape<ThreadblockShape::kK, ThreadblockShape::kN>, half_t, LayoutB, 0, ThreadMapB,
AccessTypeB>;
// Define the threadblock-scoped multistage matrix multiply
using ThreadblockMma = typename Mma::ThreadblockMma;
using ThreadblockMma = cutlass::gemm::threadblock::Wint2xMmaMultistage<typename MmaCore::Shape, IteratorA,
typename MmaCore::SmemIteratorA, MmaCore::kCacheOpA, IteratorB, typename MmaCore::SmemIteratorB,
MmaCore::kCacheOpB, ElementAccumulator, layout::RowMajor, typename MmaCore::MmaPolicy, 2>;
};
template <
@@ -424,23 +441,38 @@ struct DefaultMma<half_t, LayoutA, kAlignmentA, uint2b_t, LayoutB, kAlignmentB,
layout::RowMajor, arch::OpClassTensorOp, ArchTag, ThreadblockShape, WarpShape, InstructionShape, kStages, Operator,
false, SharedMemoryClear>
{
private:
using Mma = DefaultWint2xMma<half_t, LayoutA, kAlignmentA, uint2b_t, LayoutB, kAlignmentB,
ElementAccumulator, layout::RowMajor, arch::OpClassTensorOp, ArchTag, ThreadblockShape,
WarpShape, InstructionShape, kStages, Operator, SharedMemoryClear>;
static cutlass::arch::CacheOperation::Kind const CacheOpA =
((sizeof_bits<half_t>::value * kAlignmentA) == 128) ? cutlass::arch::CacheOperation::Global
: cutlass::arch::CacheOperation::Always;
static cutlass::arch::CacheOperation::Kind const CacheOpB =
((sizeof_bits<half_t>::value * kAlignmentB) == 128) ? cutlass::arch::CacheOperation::Global
: cutlass::arch::CacheOperation::Always;
public:
// Define the MmaCore components
using MmaCore = typename Mma::MmaCore;
using MmaCore =
typename cutlass::gemm::threadblock::DefaultMmaCore<ThreadblockShape, WarpShape, InstructionShape, half_t,
LayoutA, half_t, LayoutB, ElementAccumulator, layout::RowMajor, arch::OpClassTensorOp, kStages, Operator,
false, CacheOpA, CacheOpB>;
// Define iterators over tiles from the A operand
using IteratorA = typename Mma::IteratorA;
using ThreadMapA = typename MmaCore::IteratorThreadMapA;
using AccessTypeA = cutlass::Array<half_t, kAlignmentA>;
using IteratorA = cutlass::transform::threadblock::PredicatedTileAccessIterator<
cutlass::MatrixShape<ThreadblockShape::kM, ThreadblockShape::kK>, half_t, LayoutA, 1, ThreadMapA,
AccessTypeA>;
// Define iterators over tiles from the B operand
using IteratorB = typename Mma::IteratorB;
using ThreadMapB = typename MmaCore::IteratorThreadMapB;
using AccessTypeB = cutlass::Array<half_t, kAlignmentB>;
using IteratorB = cutlass::transform::threadblock::PredicatedTileAccessIterator<
cutlass::MatrixShape<ThreadblockShape::kK, ThreadblockShape::kN>, half_t, LayoutB, 0, ThreadMapB,
AccessTypeB>;
// Define the threadblock-scoped multistage matrix multiply
using ThreadblockMma = typename Mma::ThreadblockMma;
using ThreadblockMma = cutlass::gemm::threadblock::Wint2xMmaMultistage<typename MmaCore::Shape, IteratorA,
typename MmaCore::SmemIteratorA, MmaCore::kCacheOpA, IteratorB, typename MmaCore::SmemIteratorB,
MmaCore::kCacheOpB, ElementAccumulator, layout::RowMajor, typename MmaCore::MmaPolicy, kStages, SharedMemoryClear>;
};
} // namespace threadblock

View File

@@ -19,7 +19,7 @@
#include "cutlass/gemm/threadblock/default_mma.h"
#include "cutlass_extensions/gemm/threadblock/default_dq_mma_multistage.h"
#include "cutlass_extensions/gemm/threadblock/default_dq_mma_pipelined.h"
#include "cutlass_extensions/gemm/threadblock/default_wint2x_mma.h"
#include "cutlass_extensions/gemm/threadblock/wint2x_mma_multistage.h"
namespace cutlass {
namespace gemm {
@@ -379,23 +379,38 @@ template <
struct DefaultMma<cutlass::bfloat16_t, LayoutA, kAlignmentA, uint2b_t, LayoutB, kAlignmentB, ElementAccumulator,
layout::RowMajor, arch::OpClassTensorOp, ArchTag, ThreadblockShape, WarpShape, InstructionShape, 2, Operator>
{
private:
using Mma = DefaultWint2xMma<bfloat16_t, LayoutA, kAlignmentA, uint2b_t, LayoutB, kAlignmentB,
ElementAccumulator, layout::RowMajor, arch::OpClassTensorOp, ArchTag, ThreadblockShape,
WarpShape, InstructionShape, 2, Operator>;
static cutlass::arch::CacheOperation::Kind const CacheOpA =
((sizeof_bits<bfloat16_t>::value * kAlignmentA) == 128) ? cutlass::arch::CacheOperation::Global
: cutlass::arch::CacheOperation::Always;
static cutlass::arch::CacheOperation::Kind const CacheOpB =
((sizeof_bits<bfloat16_t>::value * kAlignmentB) == 128) ? cutlass::arch::CacheOperation::Global
: cutlass::arch::CacheOperation::Always;
public:
// Define the MmaCore components
using MmaCore = typename Mma::MmaCore;
using MmaCore =
typename cutlass::gemm::threadblock::DefaultMmaCore<ThreadblockShape, WarpShape, InstructionShape, bfloat16_t,
LayoutA, bfloat16_t, LayoutB, ElementAccumulator, layout::RowMajor, arch::OpClassTensorOp, 3, Operator,
false, CacheOpA, CacheOpB>;
// Define iterators over tiles from the A operand
using IteratorA = typename Mma::IteratorA;
using ThreadMapA = typename MmaCore::IteratorThreadMapA;
using AccessTypeA = cutlass::Array<bfloat16_t, kAlignmentA>;
using IteratorA = cutlass::transform::threadblock::PredicatedTileAccessIterator<
cutlass::MatrixShape<ThreadblockShape::kM, ThreadblockShape::kK>, bfloat16_t, LayoutA, 1, ThreadMapA,
AccessTypeA>;
// Define iterators over tiles from the B operand
using IteratorB = typename Mma::IteratorB;
using ThreadMapB = typename MmaCore::IteratorThreadMapB;
using AccessTypeB = cutlass::Array<bfloat16_t, kAlignmentB>;
using IteratorB = cutlass::transform::threadblock::PredicatedTileAccessIterator<
cutlass::MatrixShape<ThreadblockShape::kK, ThreadblockShape::kN>, bfloat16_t, LayoutB, 0, ThreadMapB,
AccessTypeB>;
// Define the threadblock-scoped multistage matrix multiply
using ThreadblockMma = typename Mma::ThreadblockMma;
using ThreadblockMma = cutlass::gemm::threadblock::Wint2xMmaMultistage<typename MmaCore::Shape, IteratorA,
typename MmaCore::SmemIteratorA, MmaCore::kCacheOpA, IteratorB, typename MmaCore::SmemIteratorB,
MmaCore::kCacheOpB, ElementAccumulator, layout::RowMajor, typename MmaCore::MmaPolicy, 2>;
};
template <
@@ -427,23 +442,38 @@ struct DefaultMma<bfloat16_t, LayoutA, kAlignmentA, uint2b_t, LayoutB, kAlignmen
layout::RowMajor, arch::OpClassTensorOp, ArchTag, ThreadblockShape, WarpShape, InstructionShape, kStages, Operator,
false, SharedMemoryClear>
{
private:
using Mma = DefaultWint2xMma<bfloat16_t, LayoutA, kAlignmentA, uint2b_t, LayoutB, kAlignmentB,
ElementAccumulator, layout::RowMajor, arch::OpClassTensorOp, ArchTag, ThreadblockShape,
WarpShape, InstructionShape, kStages, Operator, SharedMemoryClear>;
static cutlass::arch::CacheOperation::Kind const CacheOpA =
((sizeof_bits<bfloat16_t>::value * kAlignmentA) == 128) ? cutlass::arch::CacheOperation::Global
: cutlass::arch::CacheOperation::Always;
static cutlass::arch::CacheOperation::Kind const CacheOpB =
((sizeof_bits<bfloat16_t>::value * kAlignmentB) == 128) ? cutlass::arch::CacheOperation::Global
: cutlass::arch::CacheOperation::Always;
public:
// Define the MmaCore components
using MmaCore = typename Mma::MmaCore;
using MmaCore =
typename cutlass::gemm::threadblock::DefaultMmaCore<ThreadblockShape, WarpShape, InstructionShape, bfloat16_t,
LayoutA, bfloat16_t, LayoutB, ElementAccumulator, layout::RowMajor, arch::OpClassTensorOp, kStages, Operator,
false, CacheOpA, CacheOpB>;
// Define iterators over tiles from the A operand
using IteratorA = typename Mma::IteratorA;
using ThreadMapA = typename MmaCore::IteratorThreadMapA;
using AccessTypeA = cutlass::Array<bfloat16_t, kAlignmentA>;
using IteratorA = cutlass::transform::threadblock::PredicatedTileAccessIterator<
cutlass::MatrixShape<ThreadblockShape::kM, ThreadblockShape::kK>, bfloat16_t, LayoutA, 1, ThreadMapA,
AccessTypeA>;
// Define iterators over tiles from the B operand
using IteratorB = typename Mma::IteratorB;
using ThreadMapB = typename MmaCore::IteratorThreadMapB;
using AccessTypeB = cutlass::Array<bfloat16_t, kAlignmentB>;
using IteratorB = cutlass::transform::threadblock::PredicatedTileAccessIterator<
cutlass::MatrixShape<ThreadblockShape::kK, ThreadblockShape::kN>, bfloat16_t, LayoutB, 0, ThreadMapB,
AccessTypeB>;
// Define the threadblock-scoped multistage matrix multiply
using ThreadblockMma = typename Mma::ThreadblockMma;
using ThreadblockMma = cutlass::gemm::threadblock::Wint2xMmaMultistage<typename MmaCore::Shape, IteratorA,
typename MmaCore::SmemIteratorA, MmaCore::kCacheOpA, IteratorB, typename MmaCore::SmemIteratorB,
MmaCore::kCacheOpB, ElementAccumulator, layout::RowMajor, typename MmaCore::MmaPolicy, kStages, SharedMemoryClear>;
};
} // namespace threadblock

View File

@@ -1,182 +0,0 @@
/***************************************************************************************************
* Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
#pragma once
#include "cutlass/gemm/threadblock/default_mma_core_sm80.h"
namespace cutlass {
namespace gemm {
namespace threadblock {
/// Partial specialization:
///
/// A: row-major
/// B: uint2b_t, column-major
/// Operator: tensor op class
///
/// This uses the default warp-level operator given tile sizes
template <
/// Shape of threadblock-scoped matrix multiply operator (concept:
/// GemmShape)
typename Shape_,
/// Shape of warp-level matrix multiply operator (concept: GemmShape)
typename WarpShape_,
/// Shape of one matrix production operation (concept: GemmShape)
typename InstructionShape_,
/// Data type of A operand
typename ElementA_,
/// Data type of accumulator
typename ElementC_,
/// Layout of accumulator
typename LayoutC_,
/// Number of stages
int Stages,
/// Operation performed by MMA
typename Operator_,
/// Cache operation of operand A
cutlass::arch::CacheOperation::Kind CacheOpA,
/// Cache operation of operand B
cutlass::arch::CacheOperation::Kind CacheOpB>
struct DefaultMmaCore<Shape_, WarpShape_, InstructionShape_, ElementA_,
layout::RowMajor, uint2b_t, layout::ColumnMajor,
ElementC_, LayoutC_, arch::OpClassTensorOp, Stages,
Operator_, false, CacheOpA, CacheOpB> {
using Shape = Shape_;
using WarpShape = WarpShape_;
using InstructionShape = InstructionShape_;
using ElementA = ElementA_;
using LayoutA = layout::RowMajor;
using ElementB = uint2b_t;
using LayoutB = layout::ColumnMajor;
using ElementC = ElementC_;
using LayoutC = LayoutC_;
static int const kStages = Stages;
static cutlass::arch::CacheOperation::Kind const kCacheOpA = CacheOpA;
static cutlass::arch::CacheOperation::Kind const kCacheOpB = CacheOpB;
/// Number of warps present
using WarpCount = GemmShape<Shape::kM / WarpShape::kM,
Shape::kN / WarpShape::kN,
Shape::kK / WarpShape::kK>;
// Divisility requirements
static_assert(
!(Shape::kM % WarpShape::kM) && !(Shape::kN % WarpShape::kN),
"Threadblock-scoped GEMM should be divisible by warp-scoped GEMM size.");
/// Number of threads per warp
static int const kWarpSize = warp::WarpSize<arch::OpClassTensorOp>::value;
/// Size of a threadblock-scoped access
static int const kAccessSizeInBits = 128;
/// Number of threads total
static int const kThreads = WarpCount::kCount * kWarpSize;
/// Size of a threadblock-scoped access of B
static constexpr int kMaxThreadsForB =
(Shape::kK * Shape::kN * sizeof_bits<ElementB>::value) / kAccessSizeInBits;
static constexpr int kThreadsForB =
kMaxThreadsForB > kThreads ? kThreads : kMaxThreadsForB;
/// Default Operator
using Operator = Operator_;
// Warp thread arrangement
static int const kWarpThreadArrangementContiguousA =
Shape::kK / (kAccessSizeInBits / sizeof_bits<ElementA>::value);
static int const kWarpThreadArrangementStridedA =
kWarpSize / kWarpThreadArrangementContiguousA;
static int const kWarpThreadArrangementContiguousB =
Shape::kK / (kAccessSizeInBits / sizeof_bits<ElementB>::value);
static int const kWarpThreadArrangementStridedB =
kWarpSize / kWarpThreadArrangementContiguousB;
//
// Shared memory layouts
//
using SmemLayoutA = layout::RowMajorTensorOpMultiplicandCrosswise<
sizeof_bits<ElementA>::value, Shape::kK>;
// Shared memory layout
using SmemLayoutB = layout::ColumnMajorTensorOpMultiplicandCrosswise<
sizeof_bits<ElementB>::value, Shape::kK>;
//
// Iterators to write to shared memory
//
/// ThreadMap of iterator A
using IteratorThreadMapA = transform::PitchLinearWarpRakedThreadMap<
layout::PitchLinearShape<Shape::kK, Shape::kM>, kThreads,
layout::PitchLinearShape<kWarpThreadArrangementContiguousA,
kWarpThreadArrangementStridedA>,
kAccessSizeInBits / sizeof_bits<ElementA>::value>;
/// Shared memory iterator to A operand
using SmemIteratorA = transform::threadblock::RegularTileAccessIterator<
MatrixShape<Shape::kM, Shape::kK>, ElementA, SmemLayoutA, 0,
IteratorThreadMapA>;
/// ThreadMap of iterator B
using IteratorThreadMapB = transform::PitchLinearWarpRakedThreadMap<
layout::PitchLinearShape<Shape::kK, Shape::kN>, kThreadsForB,
layout::PitchLinearShape<kWarpThreadArrangementContiguousB,
kWarpThreadArrangementStridedB>,
kAccessSizeInBits / sizeof_bits<ElementB>::value>;
/// Shared memory iterator to B operand
using SmemIteratorB = transform::threadblock::RegularTileAccessIterator<
MatrixShape<Shape::kK, Shape::kN>, ElementB, SmemLayoutB, 1,
IteratorThreadMapB>;
//
// Warp-level matrix multiply operator
//
// Define the warp-level tensor op
using MmaTensorOp = typename cutlass::gemm::warp::DefaultMmaTensorOp<
WarpShape, InstructionShape, ElementA, SmemLayoutA, ElementB, SmemLayoutB,
ElementC, LayoutC, Operator, WarpCount::kK>::Type;
/// Policy used to define MmaPipelined
using MmaPolicy = MmaPolicy<MmaTensorOp, MatrixShape<0, 0>,
MatrixShape<0, 0>, WarpCount::kK>;
};
} // namespace threadblock
} // namespace gemm
} // namespace cutlass

View File

@@ -1,246 +0,0 @@
/*
* SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: Apache-2.0
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once
#include "cutlass_extensions/arch/mma.h"
#include "cutlass_extensions/gemm/threadblock/default_dq_mma.h"
#include "cutlass_extensions/gemm/threadblock/default_mma_core.h"
#include "cutlass_extensions/gemm/threadblock/wint2x_mma_multistage.h"
#include "cutlass_extensions/gemm/threadblock/wint2x_params_accessor.h"
namespace cutlass {
namespace gemm {
namespace threadblock {
////////////////////////////////////////////////////////////////////////////////
template <typename ThreadblockShape, typename ElementT, int GroupSize>
struct DefaultQuantParamsIterators {
private:
static constexpr int kAlignment = 128 / sizeof_bits<ElementT>::value;
static_assert((ThreadblockShape::kN % kAlignment) == 0, "");
static constexpr int kRows =
(GroupSize == -1) ? 1 : (ThreadblockShape::kK + GroupSize - 1) / GroupSize;
static constexpr int kColumns = ThreadblockShape::kN;
using IteratorThreadMap = transform::PitchLinearStripminedThreadMap<
layout::PitchLinearShape<kColumns, kRows>,
kColumns / kAlignment, kAlignment>;
public:
using Iterator = cutlass::transform::threadblock::PredicatedTileIterator<
MatrixShape<kRows, kColumns>, ElementT, layout::RowMajor, 0,
IteratorThreadMap, kAlignment>;
using SmemIterator = Iterator;
};
template <typename ThreadblockShape, int GroupSize>
struct DefaultQuantParamsIterators<ThreadblockShape, uint4b_t, GroupSize> {
private:
static constexpr int kAlignment = 32 / sizeof_bits<uint4b_t>::value;
static_assert((ThreadblockShape::kN % kAlignment) == 0, "");
static constexpr int kRows =
(GroupSize == -1) ? 1 : (ThreadblockShape::kK + 2 * GroupSize - 1) / (2 * GroupSize);
static constexpr int kColumns =
(GroupSize == -1) ? ThreadblockShape::kN : ThreadblockShape::kN * 2;
using IteratorThreadMap = transform::PitchLinearStripminedThreadMap<
layout::PitchLinearShape<kColumns, kRows>,
kColumns / kAlignment, kAlignment>;
public:
using AccessType = cutlass::Array<uint4b_t, kAlignment>;
using Iterator = cutlass::transform::threadblock::PredicatedTileAccessIterator<
MatrixShape<kRows, kColumns>, uint4b_t, layout::RowMajor,
0, IteratorThreadMap, AccessType>;
using SmemIterator = Iterator;
};
template <
/// Element type for A matrix operand
typename ElementA_,
/// Layout type for A matrix operand
typename LayoutA_,
/// Access granularity of A matrix in units of elements
int kAlignmentA,
/// Element type for B matrix operand
typename ElementB_,
/// Layout type for B matrix operand
typename LayoutB_,
/// Access granularity of B matrix in units of elements
int kAlignmentB,
/// Element type for internal accumulation
typename ElementAccumulator_,
/// Layout type for C and D matrix operands
typename LayoutC_,
/// Operator class tag
typename OperatorClass_,
/// Tag indicating architecture to tune for
typename ArchTag_,
/// Threadblock-level tile size (concept: GemmShape)
typename ThreadblockShape_,
/// Warp-level tile size (concept: GemmShape)
typename WarpShape_,
/// Instruction-level tile size (concept: GemmShape)
typename InstructionShape_,
/// Number of stages used in the pipelined mainloop
int Stages,
/// Operation performed by GEMM
typename Operator_,
/// Use zfill or predicate for out-of-bound cp.async
SharedMemoryClearOption SharedMemoryClear = SharedMemoryClearOption::kNone>
struct DefaultWint2xMma;
////////////////////////////////////////////////////////////////////////////////
template <
/// Type for element A
typename ElementA,
/// Layout type for A matrix operand
typename LayoutA,
/// Access granularity of A matrix in units of elements
int kAlignmentA,
/// Type for element B
typename ElementB,
/// Layout type for B matrix operand
typename LayoutB,
/// Access granularity of B matrix in units of elements
int kAlignmentB,
/// Element type for internal accumulation
typename ElementAccumulator,
/// Operator class tag
typename OperatorClass,
/// Tag indicating architecture to tune for
typename ArchTag,
/// Threadblock-level tile size (concept: GemmShape)
typename ThreadblockShape,
/// Warp-level tile size (concept: GemmShape)
typename WarpShape,
/// Instruction-level tile size (concept: GemmShape)
typename InstructionShape,
/// Stages in GEMM
int kStages,
/// Operator performed by GEMM
typename Operator,
/// Use zfill or predicate for out-of-bound cp.async
SharedMemoryClearOption SharedMemoryClear>
struct DefaultWint2xMma<ElementA, LayoutA, kAlignmentA, ElementB, LayoutB, kAlignmentB, ElementAccumulator,
layout::RowMajor, OperatorClass, ArchTag, ThreadblockShape, WarpShape, InstructionShape,
kStages, Operator, SharedMemoryClear>
{
public:
static_assert(platform::is_same<ElementA, half_t>::value || platform::is_same<ElementA, bfloat16_t>::value,
"Element A must be fp16 or bf16");
static_assert(platform::is_same<ElementB, uint2b_t>::value,
"Element B must be uint2b_t");
static_assert(platform::is_same<Operator, arch::OpMultiplyAddDequantizeInterleavedBToA>::value,
"Mma multistage must dequantize after ldsm");
using ElementSuperScale = ElementA;
using ElementLocalScale = uint4b_t;
using ElementCodeScaleZp = float;
static constexpr int kGroupSize = 64;
static cutlass::arch::CacheOperation::Kind const CacheOpA = ((sizeof_bits<ElementA>::value * kAlignmentA) == 128)
? cutlass::arch::CacheOperation::Global
: cutlass::arch::CacheOperation::Always;
static cutlass::arch::CacheOperation::Kind const CacheOpB = ((sizeof_bits<ElementB>::value * kAlignmentB) == 128)
? cutlass::arch::CacheOperation::Global
: cutlass::arch::CacheOperation::Always;
// Define the MmaCore components
// Mma core does not depend on stages, so pass in at least 3 here to mma multistage pieces are created
using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore<ThreadblockShape, WarpShape, InstructionShape,
ElementA, LayoutA, ElementB, layout::ColumnMajor, ElementAccumulator, layout::RowMajor, OperatorClass,
std::max(kStages, 3), Operator, false, CacheOpA, CacheOpB>;
// Define iterators over tiles from the A operand
using ThreadMapA = typename MmaCore::IteratorThreadMapA;
using AccessTypeA = cutlass::Array<ElementA, kAlignmentA>;
using IteratorA = cutlass::transform::threadblock::PredicatedTileAccessIterator<
cutlass::MatrixShape<ThreadblockShape::kM, ThreadblockShape::kK>, ElementA, LayoutA, 1, ThreadMapA,
AccessTypeA>;
private:
static constexpr int kColumnsInterleaved = LayoutB::kColumnsInterleaved;
static constexpr int kRowsPerTile = LayoutB::kRowsPerTile;
static_assert(!(MmaCore::Shape::kN % kColumnsInterleaved), "ThreadblockShape must be disivle by kColumnsInterleaved");
static_assert(kRowsPerTile == MmaCore::Shape::kK, "");
using ThreadMapB = typename MmaCore::IteratorThreadMapB;
using WarpArrangement = typename ThreadMapB::Detail::WarpThreadArrangement;
static_assert(!(WarpArrangement::kStrided % kColumnsInterleaved), "");
using IteratorShapeB = MatrixShape<
MmaCore::Shape::kK * kColumnsInterleaved, MmaCore::Shape::kN / kColumnsInterleaved>;
using InterleavedThreadMapB = transform::PitchLinearWarpRakedThreadMap<
layout::PitchLinearShape<IteratorShapeB::kRow, IteratorShapeB::kColumn>,
ThreadMapB::kThreads,
layout::PitchLinearShape<WarpArrangement::kContiguous * kColumnsInterleaved,
WarpArrangement::kStrided / kColumnsInterleaved>,
MmaCore::kAccessSizeInBits / sizeof_bits<ElementB>::value>;
public:
// Define iterators over tiles from the B operand
using AccessTypeB = cutlass::Array<ElementB, kAlignmentB>;
using IteratorB = cutlass::transform::threadblock::PredicatedTileAccessIterator<
IteratorShapeB, ElementB, layout::ColumnMajor, 0, InterleavedThreadMapB,
AccessTypeB>;
private:
// Define iterators over tiles from extra quant params for B operand
using IteratorSuperScale = typename DefaultQuantParamsIterators<
ThreadblockShape, ElementSuperScale, -1>::Iterator;
using SmemIteratorSuperScale = typename DefaultQuantParamsIterators<
ThreadblockShape, ElementSuperScale, -1>::SmemIterator;
using IteratorLocalScale = typename DefaultQuantParamsIterators<
ThreadblockShape, ElementLocalScale, kGroupSize>::Iterator;
using SmemIteratorLocalScale = typename DefaultQuantParamsIterators<
ThreadblockShape, ElementLocalScale, kGroupSize>::SmemIterator;
using IteratorCodeScaleZp = typename DefaultQuantParamsIterators<
ThreadblockShape, ElementCodeScaleZp, -1>::Iterator;
using SmemIteratorCodeScaleZp = typename DefaultQuantParamsIterators<
ThreadblockShape, ElementCodeScaleZp, -1>::Iterator;
public:
using QuantParamsAccessor = Wint2ParamsAccessor<
ElementA, ThreadblockShape, IteratorSuperScale, SmemIteratorSuperScale,
IteratorLocalScale, SmemIteratorLocalScale,
IteratorCodeScaleZp, SmemIteratorCodeScaleZp, kStages, kGroupSize>;
// Define the threadblock-scoped multistage matrix multiply
using ThreadblockMma = cutlass::gemm::threadblock::Wint2xMmaMultistage<
typename MmaCore::Shape,
IteratorA, typename MmaCore::SmemIteratorA, MmaCore::kCacheOpA,
IteratorB, typename MmaCore::SmemIteratorB, MmaCore::kCacheOpB,
ElementAccumulator, layout::RowMajor, typename MmaCore::MmaPolicy,
kStages, QuantParamsAccessor, SharedMemoryClear>;
};
} // namespace threadblock
} // namespace gemm
} // namespace cutlass

View File

@@ -63,8 +63,8 @@ template <
typename Policy_,
/// Number of stages,
int Stages,
/// Size of extra quantized params
typename QuantParamsShape>
/// Used for partial specialization
typename Enable = bool>
class Wint2xMmaBase {
public:
///< Size of the Gemm problem - concept: gemm::GemmShape<>
@@ -89,18 +89,10 @@ public:
GemmShape<Shape::kM / WarpGemm::kM, Shape::kN / WarpGemm::kN,
Shape::kK / WarpGemm::kK>;
/// Number of warp-level GEMM operations
/// Number of warp-level GEMM oeprations
static int const kWarpGemmIterations =
(WarpGemm::kK / Operator::Policy::MmaShape::kK);
/// Number of warp-level GEMM operations per load for B
static constexpr int kWarpGemmIterationsPerLoadForB =
Operator::IteratorB::InstructionShape::kRow / Operator::InstructionShape::kK;
static_assert(!(kWarpGemmIterations % kWarpGemmIterationsPerLoadForB), "");
static constexpr int kWarpLoadIterationsForB =
kWarpGemmIterations / kWarpGemmIterationsPerLoadForB;
/// Number of stages
static int const kStages = Stages;
@@ -112,6 +104,8 @@ public:
using TensorRefB =
TensorRef<typename Operator::ElementB, typename Operator::LayoutB>;
// using TensorRefZippedB = TensorRef<uint8_t, typename Operator::LayoutB>;
static_assert(kWarpGemmIterations > 1,
"The pipelined structure requires at least two warp-level "
"GEMM operations.");
@@ -136,11 +130,20 @@ public:
Shape::kK * kStages + Policy::SmemPaddingA::kColumn>;
/// Shape of the B matrix operand in shared memory
using ShapeB = MatrixShape<Shape::kK * kStages + Policy::SmemPaddingB::kRow,
using ShapeB = MatrixShape<Shape::kK + Policy::SmemPaddingB::kRow,
Shape::kN + Policy::SmemPaddingB::kColumn>;
/// Shape of all quant params in shared memory
using QuantParamsShapeB = QuantParamsShape;
// w uint8; local_scale uint8;
constexpr static int kZippedRowsPerStages =
Shape::kK / 4 + (Shape::kK + 127) / 128;
// code_scale float; code_zp float; super_scale ElementB
constexpr static int kColumnWiseParamsRows = 2 * sizeof(float) +
sizeof_bits<typename Operator::ElementB>::value / 8;
using ZippedShapeB = MatrixShape<kColumnWiseParamsRows + kZippedRowsPerStages * kStages, Shape::kN>;
using NopaddingShapeB = MatrixShape<Shape::kK, Shape::kN>;
public:
//
@@ -153,8 +156,12 @@ public:
/// Buffer for B operand
AlignedBuffer<typename Operator::ElementB, ShapeB::kCount> operand_B;
/// Buffer for extra quant params of B operand
AlignedBuffer<uint8_t, QuantParamsShapeB::kCount> operand_quant_params_B;
/// Buffer for quanted B operand
AlignedBuffer<uint8_t, ZippedShapeB::kCount> operand_zipped_B;
/// Buffer for unzip B operand
AlignedBuffer<typename Operator::ElementB, NopaddingShapeB::kCount>
operand_unzip_B;
public:
//
@@ -184,6 +191,14 @@ public:
TensorRefB operand_B_ref() {
return TensorRefB{operand_B.data(), LayoutB()};
}
CUTLASS_HOST_DEVICE
uint8_t *operand_zipped_B_ptr() { return operand_zipped_B.data(); }
CUTLASS_HOST_DEVICE
typename Operator::ElementB *operand_unzip_B_ptr() {
return operand_unzip_B.data();
}
};
protected:

View File

@@ -45,8 +45,7 @@
#include "cutlass_extensions/arch/memory_copy_sm80.h"
#include "cutlass_extensions/gemm/threadblock/wint2x_mma_base.h"
#include "cutlass_extensions/gemm/threadblock/wint2x_params_accessor.h"
#include "cutlass_extensions/gemm/warp/mma_tensorop_wint2x_dequantizer.h"
#include "cutlass_extensions/gemm/threadblock/wint2x_tile_dequanter.h"
/////////////////////////////////////////////////////////////////////////////////////////////////
@@ -87,15 +86,15 @@ template <
typename Policy_,
/// Number of stages,
int Stages,
/// Accessor for extra quantized params
typename QuantParamsAccessor_,
/// Use zfill or predicate for out-of-bound cp.async
SharedMemoryClearOption SharedMemoryClear = SharedMemoryClearOption::kNone>
SharedMemoryClearOption SharedMemoryClear = SharedMemoryClearOption::kNone,
/// Used for partial specialization
typename Enable = bool>
class Wint2xMmaMultistage :
public Wint2xMmaBase<Shape_, Policy_, Stages, typename QuantParamsAccessor_::QuantParamsShape> {
public Wint2xMmaBase<Shape_, Policy_, Stages> {
public:
///< Base class
using Base = Wint2xMmaBase<Shape_, Policy_, Stages, typename QuantParamsAccessor_::QuantParamsShape>;
using Base = Wint2xMmaBase<Shape_, Policy_, Stages>;
///< Size of the Gemm problem - concept: gemm::GemmShape<>
using Shape = Shape_;
///< Iterates over tiles of A operand in global memory
@@ -108,11 +107,8 @@ public:
using LayoutC = LayoutC_;
///< Policy describing tuning details
using Policy = Policy_;
/// Accessor for extra quantized params
using QuantParamsAccessor = QuantParamsAccessor_;
using QuantArguments = typename QuantParamsAccessor::Arguments;
static constexpr int kInterleave = IteratorB::Shape::kRow / Shape::kK;
using ZippedShapeB = typename Base::SharedStorage::ZippedShapeB;
using SmemIteratorA = SmemIteratorA_;
using SmemIteratorB = SmemIteratorB_;
@@ -133,18 +129,6 @@ public:
/// Minimum architecture is Sm80 to support cp.async
using ArchTag = arch::Sm80;
//using LayoutScale = typename QuantParamsAccessor::IteratorSuperScale::Layout;
using LayoutScale = layout::RowMajor;
using WarpTransformedFragmentB = typename Operator::TransformedFragmentB;
using WarpDequantizer =
warp::MmaTensorOpWin2xDequantizer<Operator,
typename Base::WarpGemm,
Operand::kB,
typename WarpTransformedFragmentB::Element,
LayoutScale,
QuantParamsAccessor::kGroupSize>;
static_assert(sizeof(WarpDequantizer) > 0, "WarpDequantizer template instantiation failed");
/// Complex transform on A operand
static ComplexTransform const kTransformA = Operator::kTransformA;
@@ -190,37 +174,18 @@ public:
using WarpTransformedFragmentA = typename Operator::TransformedFragmentA;
using WarpTransformedFragmentB = typename Operator::TransformedFragmentB;
using FragmentSuperScale = typename WarpDequantizer::FragmentSuperScale;
using FragmentCodeScaleZp = typename WarpDequantizer::FragmentCodeScaleZp;
using FragmentLocalScale = typename WarpDequantizer::FragmentLocalScale;
/// Temporary accumulator to facilitate staged-accumulation
FragmentC tmp_accum_;
/// Pair of A fragments used to overlap shared memory loads and math instructions
WarpTransformedFragmentA warp_frag_A_[2];
WarpLoadedFragmentA warp_loaded_frag_A_[2];
WarpTransformedFragmentA warp_transformed_frag_A_[2];
/// Pair of B fragments used to overlap shared memory loads and math instructions
WarpLoadedFragmentB warp_loaded_frag_B_;
WarpTransformedFragmentB warp_frag_B_[2];
/// channel-wise quant params
FragmentCodeScaleZp warp_frag_code_scale_;
FragmentCodeScaleZp warp_frag_code_zp_;
FragmentSuperScale warp_frag_super_scale_;
/// group-wise quant params
FragmentLocalScale warp_frag_local_scale_;
WarpLoadedFragmentB warp_loaded_frag_B_[2];
WarpTransformedFragmentB warp_transformed_frag_B_[2];
};
using ElementA = typename IteratorA::Element;
using ElementB = typename IteratorB::Element;
using LayoutDetailsForB = kernel::LayoutDetailsB<ElementA, ElementB, ArchTag>;
static constexpr bool IsTileInterleaveLayout =
layout::IsColumnMajorTileInterleave<typename LayoutDetailsForB::Layout>::value;
static_assert(!IsTileInterleaveLayout || (IsTileInterleaveLayout && (Shape::kK == LayoutDetailsForB::ThreadblockK)),
"Layout K must match threadblockK");
private:
@@ -237,18 +202,17 @@ public:
/// Iterator to write threadblock-scoped tile of B operand to shared memory
SmemIteratorB smem_iterator_B_;
/// Accessor for extra quant params for B
QuantParamsAccessor quant_params_accessor_B_;
// Wint2 unzip operator
WarpDequantizer warp_dequantizer_;
/// Shared memory write stage index
int smem_write_stage_idx_;
/// Shared memory read stage index
int smem_read_stage_idx_;
uint8_t* column_wise_smem_ptr_B_;
uint8_t* smem_zipped_ptr_B_;
int smem_zipped_bytes_per_stage_B_;
public:
/// Construct from tensor references
@@ -262,15 +226,10 @@ public:
int warp_idx,
///< ID of each thread within a warp
int lane_idx
) : Base(shared_storage, thread_idx, warp_idx, lane_idx),
):
Base(shared_storage, thread_idx, warp_idx, lane_idx),
smem_iterator_A_(shared_storage.operand_A_ref(), thread_idx),
smem_iterator_B_(shared_storage.operand_B_ref(), thread_idx),
quant_params_accessor_B_(shared_storage.operand_quant_params_B.data(), thread_idx, warp_idx, lane_idx),
warp_dequantizer_(quant_params_accessor_B_.super_scale_ref(),
quant_params_accessor_B_.local_scale_ref(),
quant_params_accessor_B_.code_scale_ref(),
quant_params_accessor_B_.code_zp_ref(),
(warp_idx % (Base::WarpCount::kM * Base::WarpCount::kN)) / Base::WarpCount::kM, lane_idx),
smem_write_stage_idx_(0),
smem_read_stage_idx_(0)
{
@@ -291,6 +250,11 @@ public:
{warp_idx_m, Base::kWarpGemmIterations * warp_idx_k});
this->warp_tile_iterator_B_.add_tile_offset(
{Base::kWarpGemmIterations * warp_idx_k, warp_idx_n});
column_wise_smem_ptr_B_ = shared_storage.operand_zipped_B_ptr();
smem_zipped_ptr_B_ = column_wise_smem_ptr_B_ + Base::SharedStorage::kColumnWiseParamsRows * ZippedShapeB::kColumn;
smem_zipped_bytes_per_stage_B_ = Base::SharedStorage::kZippedRowsPerStages * ZippedShapeB::kColumn;
}
/// Advance shared memory read-iterators to the next stage
@@ -302,22 +266,28 @@ public:
if (smem_read_stage_idx_ == Base::kStages) {
// Wrap back around to the 'start' of the circular buffer in shared memory
this->warp_tile_iterator_A_.add_tile_offset({0, -Base::kStages * Policy::kPartitionsK * Base::kWarpGemmIterations});
this->warp_tile_iterator_B_.add_tile_offset({-Base::kStages * Policy::kPartitionsK * Base::kWarpLoadIterationsForB, 0});
// this->warp_tile_iterator_B_.add_tile_offset({-Base::kStages * Policy::kPartitionsK * Base::kWarpGemmIterations, 0});
smem_read_stage_idx_ = 0;
}
this->warp_tile_iterator_B_.add_tile_offset({-Policy::kPartitionsK * Base::kWarpGemmIterations, 0});
}
/// Advance global memory read-iterators and shared memory write-iterators to the stage
template <typename TileDequanterB>
CUTLASS_DEVICE
void advance_smem_write_stage(IteratorA &iterator_A, IteratorB &iterator_B)
void advance_smem_write_stage(
IteratorA &iterator_A,
IteratorB &iterator_B,
TileDequanterB &tile_dequanter_B)
{
// Advance global iterators
iterator_A.add_tile_offset({0, 1});
iterator_B.add_tile_offset({1, 0});
//iterator_B.add_tile_offset({1, 0});
tile_dequanter_B.AddTileOffset({1, 0});
// Advance shared iterators
smem_iterator_A_.add_tile_offset({0, 1});
smem_iterator_B_.add_tile_offset({1, 0});
//smem_iterator_B_.add_tile_offset({1, 0});
// Increment shared memory write stage index
++smem_write_stage_idx_;
@@ -325,7 +295,7 @@ public:
if (smem_write_stage_idx_ == Base::kStages) {
// Wrap back around to the 'start' of the circular buffer in shared memory
smem_iterator_A_.add_tile_offset({0, -Base::kStages});
smem_iterator_B_.add_tile_offset({-Base::kStages, 0});
//smem_iterator_B_.add_tile_offset({-Base::kStages, 0});
smem_write_stage_idx_ = 0;
}
}
@@ -368,14 +338,9 @@ public:
}
}
template <bool GlobalToSharedB>
CUTLASS_DEVICE
void copy_tiles_and_advance_B(IteratorB &iterator_B, int group_start_B = 0) {
if constexpr (SharedMemoryClear == SharedMemoryClearOption::kZfill) {
if (threadIdx.x >= IteratorB::ThreadMap::kThreads) {
return;
}
}
iterator_B.set_iteration_index(group_start_B *
IteratorB::kAccessesPerVector);
this->smem_iterator_B_.set_iteration_index(group_start_B);
@@ -395,14 +360,13 @@ public:
CUTLASS_PRAGMA_UNROLL
for (int v = 0; v < IteratorB::kAccessesPerVector; ++v) {
auto gmem_ptr = iterator_B.get();
bool is_valid = (threadIdx.x < IteratorB::ThreadMap::kThreads) ? iterator_B.valid() : false;
if (SharedMemoryClear == SharedMemoryClearOption::kZfill) {
cutlass::arch::cp_async_zfill<kSrcBytes, kCacheOpB>(
dst_ptr + v, gmem_ptr, is_valid);
cutlass::arch::copy_zfill<kSrcBytes, kCacheOpB, GlobalToSharedB>(
dst_ptr + v, gmem_ptr, iterator_B.valid());
} else {
cutlass::arch::cp_async<kSrcBytes, kCacheOpB>(
dst_ptr + v, gmem_ptr, is_valid);
cutlass::arch::copy<kSrcBytes, kCacheOpB, GlobalToSharedB>(
dst_ptr + v, gmem_ptr, iterator_B.valid());
}
++iterator_B;
@@ -411,6 +375,7 @@ public:
++this->smem_iterator_B_;
}
}
__syncthreads();
}
CUTLASS_DEVICE
@@ -434,6 +399,8 @@ public:
IteratorA::ThreadMap::kElementsPerAccess /
IteratorA::kAccessesPerVector / 8;
int src_bytes = (iterator_A.valid() ? kSrcBytes : 0);
cutlass::arch::cp_async_zfill<kSrcBytes, kCacheOpA>(
dst_ptr + v, iterator_A.get(), iterator_A.valid());
@@ -444,12 +411,9 @@ public:
}
}
template <bool GlobalToSharedB, bool InitStage>
CUTLASS_DEVICE
void copy_tiles_and_advance_per_stage_B(IteratorB &iterator_B) {
if (threadIdx.x >= IteratorB::ThreadMap::kThreads) {
return;
}
iterator_B.set_iteration_index(0);
this->smem_iterator_B_.set_iteration_index(0);
@@ -469,23 +433,35 @@ public:
IteratorB::ThreadMap::kElementsPerAccess /
IteratorB::kAccessesPerVector / 8;
cutlass::arch::cp_async_zfill<kSrcBytes, kCacheOpB>(
dst_ptr + v, iterator_B.get(), iterator_B.valid());
if (InitStage) {
cutlass::arch::copy_zfill<kSrcBytes, kCacheOpB, GlobalToSharedB>(
dst_ptr + v, iterator_B.get(), iterator_B.valid());
} else {
if (SharedMemoryClear == SharedMemoryClearOption::kZfill) {
cutlass::arch::copy_zfill<kSrcBytes, kCacheOpB, GlobalToSharedB>(
dst_ptr + v, gmem_ptr, iterator_B.valid());
} else {
cutlass::arch::copy<kSrcBytes, kCacheOpB, GlobalToSharedB>(
dst_ptr + v, gmem_ptr, iterator_B.valid());
}
}
++iterator_B;
}
++this->smem_iterator_B_;
}
__syncthreads();
}
/// GEMM prologue. Bootstrap the global->shared memory pipeline by fetching
/// the global fragments needed by the first kStages-1 threadblock mainloop iterations
template <typename TileDequanterB>
CUTLASS_DEVICE
void prologue(
IteratorA &iterator_A, ///< [in|out] iterator over A operand in global memory
IteratorB &iterator_B, ///< [in|out] iterator over B operand in global memory
QuantArguments &mma_quant_args, ///< iterators for extra quant params for B
TileDequanterB &tile_dequanter_B,
int &gemm_k_iterations) ///< [in|out] number of threadblock mainloop iterations remaining
{
// Issue several complete stages
@@ -500,18 +476,11 @@ public:
copy_tiles_and_advance_per_stage_A(iterator_A);
// Async copy zipped B to shared memory.
copy_tiles_and_advance_per_stage_B(iterator_B);
// Async copy other quantized params to shared memory, local_scale, code_scale, code_zp, super_scale.
if (stage == 0) {
quant_params_accessor_B_.copy_tiles_and_advance_per_stage<true>(mma_quant_args, stage);
} else {
quant_params_accessor_B_.copy_tiles_and_advance_per_stage<false>(mma_quant_args, stage);
}
tile_dequanter_B.Load(smem_zipped_ptr_B_ + (stage % Base::kStages) * smem_zipped_bytes_per_stage_B_,
column_wise_smem_ptr_B_, stage);
// Move to the next write stage
advance_smem_write_stage(iterator_A, iterator_B);
quant_params_accessor_B_.advance_smem_write_stage(mma_quant_args);
advance_smem_write_stage(iterator_A, iterator_B, tile_dequanter_B);
// Defines the boundary of a stage of cp.async.
cutlass::arch::cp_async_fence();
@@ -541,10 +510,6 @@ public:
++last_smem_iterator_A;
}
if (threadIdx.x >= IteratorB::ThreadMap::kThreads) {
return;
}
/// Iterator to write threadblock-scoped tile of B operand to shared memory
SmemIteratorB last_smem_iterator_B(this->smem_iterator_B_);
typename IteratorB::AccessType zero_B;
@@ -577,57 +542,57 @@ public:
}
/// Perform a threadblock mainloop iteration of matrix multiply-accumulate
template <typename TileDequanterB>
CUTLASS_DEVICE
void mac_loop_iter(
PipeState &pipe_state, ///< [in|out] loop-carried pipeline state
FragmentC &accum, ///< [in|out] destination accumulator tile
IteratorA &iterator_A, ///< [in|out] iterator over A operand in global memory
IteratorB &iterator_B, ///< [in|out] iterator over B operand in global memory
QuantArguments &mma_quant_args, ///< iterators for extra quant params for B
int &gemm_k_iterations, ///< [in|out] number of threadblock mainloop iterations remaining
TileDequanterB &tile_dequanter_B, ///< [in|out] tile dequantizer for B operand
int &gemm_k_iterations, ///< [in|out] number of threadblock mainloop iterations remaining
int stage)
{
const int mma_stage = stage - Base::kStages + 1;
// Unroll the warp-level MMA tiles of a threadblock's mainloop iteration
CUTLASS_PRAGMA_UNROLL
for (int warp_mma_k = 0; warp_mma_k < Base::kWarpGemmIterations; ++warp_mma_k) {
int warp_k_compute_offset_B = warp_mma_k % Base::kWarpGemmIterationsPerLoadForB;
if (warp_k_compute_offset_B == Base::kWarpGemmIterationsPerLoadForB - 1) {
// Load the next warp-tile's B fragment from shared memory
this->warp_tile_iterator_B_.set_kgroup_index(((warp_mma_k + 1) % Base::kWarpGemmIterations) / Base::kWarpLoadIterationsForB);
this->warp_tile_iterator_B_.load(pipe_state.warp_loaded_frag_B_);
++this->warp_tile_iterator_B_;
}
// load next-tile of group-wise local_scale from shared memory
if (warp_mma_k == Base::kWarpGemmIterations - 1) {
warp_dequantizer_.load(pipe_state.warp_frag_local_scale_);
}
// CUTLASS_TRACE_DEVICE(" [MMa] stage=%d, warp_mma_k=%d", stage, warp_mma_k);
// Load the next warp-tile's A fragment from shared memory
this->warp_tile_iterator_A_.set_kgroup_index((warp_mma_k + 1) % Base::kWarpGemmIterations);
this->warp_tile_iterator_A_.load(pipe_state.warp_frag_A_[(warp_mma_k + 1) % 2]);
this->warp_tile_iterator_A_.load(pipe_state.warp_loaded_frag_A_[(warp_mma_k + 1) % 2]);
++this->warp_tile_iterator_A_;
// dequantizes next warp-tile
warp_dequantizer_.dequantize(pipe_state.warp_frag_local_scale_,
pipe_state.warp_frag_code_scale_,
pipe_state.warp_frag_code_zp_,
pipe_state.warp_frag_super_scale_,
pipe_state.warp_loaded_frag_B_,
pipe_state.warp_frag_B_[(warp_mma_k + 1) % 2],
((warp_mma_k == Base::kWarpGemmIterations - 1) ? (mma_stage + 1) : mma_stage) * Shape::kK,
(warp_mma_k + 1) % Base::kWarpGemmIterationsPerLoadForB);
if (warp_mma_k + 1 == Base::kWarpGemmIterations) {
// Unpack and dequant the first stage of B.
int unpack_stage = stage - Base::kStages + 2;
tile_dequanter_B.UnpackAndDequant(smem_zipped_ptr_B_ + (unpack_stage % Base::kStages) * smem_zipped_bytes_per_stage_B_,
column_wise_smem_ptr_B_, unpack_stage);
// Copy dequatized data to shared memory used by mma core.
copy_tiles_and_advance_per_stage_B<false, false>(iterator_B);
}
// Load the next warp-tile's B fragment from shared memory
this->warp_tile_iterator_B_.set_kgroup_index((warp_mma_k + 1) % Base::kWarpGemmIterations);
this->warp_tile_iterator_B_.load(pipe_state.warp_loaded_frag_B_[(warp_mma_k + 1) % 2]);
++this->warp_tile_iterator_B_;
// Except for the first warp-tile, all warp-tiles convert their incoming shared memory fragments as necessary
if (warp_mma_k > 0) {
warp_mma_.transform(
pipe_state.warp_transformed_frag_A_[warp_mma_k % 2],
pipe_state.warp_transformed_frag_B_[warp_mma_k % 2],
pipe_state.warp_loaded_frag_A_[warp_mma_k % 2],
pipe_state.warp_loaded_frag_B_[warp_mma_k % 2]);
}
// Execute the current warp-tile of MMA operations
if constexpr (Detail::kStagedAccumulation) {
if (Detail::kStagedAccumulation) {
warp_mma_(
pipe_state.tmp_accum_,
pipe_state.warp_frag_A_[warp_mma_k % 2],
pipe_state.warp_frag_B_[warp_mma_k % 2],
pipe_state.warp_transformed_frag_A_[warp_mma_k % 2],
pipe_state.warp_transformed_frag_B_[warp_mma_k % 2],
pipe_state.tmp_accum_
);
@@ -639,22 +604,22 @@ public:
} else {
warp_mma_(
accum,
pipe_state.warp_frag_A_[warp_mma_k % 2],
pipe_state.warp_frag_B_[warp_mma_k % 2],
accum);
pipe_state.warp_transformed_frag_A_[warp_mma_k % 2],
pipe_state.warp_transformed_frag_B_[warp_mma_k % 2],
accum
);
}
// Except for the last warp-tile, all warp-tiles issue their share of
// global->shared fragment copies
if (warp_mma_k < Base::kWarpGemmIterations - 1) {
int group_start_iteration_A = warp_mma_k * Detail::kAccessesPerGroupA;
int group_start_iteration_B = warp_mma_k * Detail::kAccessesPerGroupB;
copy_tiles_and_advance_A(iterator_A, group_start_iteration_A);
copy_tiles_and_advance_B(iterator_B, group_start_iteration_B);
if (warp_mma_k == 0) {
quant_params_accessor_B_.copy_tiles_and_advance_per_stage<false>(mma_quant_args, stage);
tile_dequanter_B.Load(smem_zipped_ptr_B_ + (stage % Base::kStages) * smem_zipped_bytes_per_stage_B_,
column_wise_smem_ptr_B_, stage);
}
}
@@ -663,15 +628,9 @@ public:
// - moves to the next global fetch stage
if (warp_mma_k + 2 == Base::kWarpGemmIterations) {
// Performs the last warp-tile's share of global->shared fragment copies
if constexpr (Detail::AsyncCopyIterationsPerStageA >= Base::kWarpGemmIterations) {
int group_start_iteration_A = (warp_mma_k + 1) * Detail::kAccessesPerGroupA;
copy_tiles_and_advance_A(iterator_A, group_start_iteration_A);
}
int group_start_iteration_A = (warp_mma_k + 1) * Detail::kAccessesPerGroupA;
if constexpr (Detail::AsyncCopyIterationsPerStageB >= Base::kWarpGemmIterations) {
int group_start_iteration_B = (warp_mma_k + 1) * Detail::kAccessesPerGroupB;
copy_tiles_and_advance_B(iterator_B, group_start_iteration_B);
}
copy_tiles_and_advance_A(iterator_A, group_start_iteration_A);
// Inserts a memory fence between stages of cp.async instructions.
cutlass::arch::cp_async_fence();
@@ -680,66 +639,69 @@ public:
gmem_wait();
// Move to the next global fetch stage
advance_smem_write_stage(iterator_A, iterator_B);
quant_params_accessor_B_.advance_smem_write_stage(mma_quant_args);
advance_smem_write_stage(iterator_A, iterator_B, tile_dequanter_B);
advance_smem_read_stage();
int byte_offset = quant_params_accessor_B_.advance_smem_read_stage();
warp_dequantizer_.add_pointer_offset(byte_offset);
// Disable global fetching when done with global fetch iterations
--gemm_k_iterations;
iterator_A.clear_mask(gemm_k_iterations == 0);
iterator_B.clear_mask(gemm_k_iterations == 0);
quant_params_accessor_B_.clear_mask(mma_quant_args, gemm_k_iterations == 0);
iterator_B.clear_mask(gemm_k_iterations == (-Base::kStages + 1));
}
// The last warp-tile also converts the shared memory fragments used by
// the first warp-tile of the next iteration, if necessary (so we can
// immediately start issuing MMA instructions at the top of the loop )
if (warp_mma_k + 1 == Base::kWarpGemmIterations) {
warp_mma_.transform(
pipe_state.warp_transformed_frag_A_[(warp_mma_k + 1) % 2],
pipe_state.warp_transformed_frag_B_[(warp_mma_k + 1) % 2],
pipe_state.warp_loaded_frag_A_[(warp_mma_k + 1) % 2],
pipe_state.warp_loaded_frag_B_[(warp_mma_k + 1) % 2]);
}
}
}
/// Perform the specified number of threadblock mainloop iterations of matrix
/// multiply-accumulate. Assumes prologue has been initiated.
template <typename TileDequanterB>
CUTLASS_DEVICE
void gemm_iters(
int gemm_k_iterations, ///< number of threadblock mainloop iterations
FragmentC &accum, ///< [in|out] accumulator tile
IteratorA &iterator_A, ///< [in|out] iterator over A operand in global memory
IteratorB &iterator_B, ///< [in|out] iterator over B operand in global memory
QuantArguments &mma_quant_args)
IteratorB &iterator_B,
TileDequanterB &tile_dequanter_B) ///< [in|out] iterator over B operand in global memory
{
PipeState pipe_state;
// Unpack and dequant the first stage of B.
tile_dequanter_B.UnpackAndDequant(smem_zipped_ptr_B_, column_wise_smem_ptr_B_, 0);
// Disable global fetching if done with global fetch iterations
iterator_A.clear_mask(gemm_k_iterations == 0);
iterator_B.clear_mask(gemm_k_iterations == 0);
quant_params_accessor_B_.clear_mask(mma_quant_args, gemm_k_iterations == 0);
// Load first warp-tile's B fragment from shared memory
this->warp_tile_iterator_B_.set_kgroup_index(0);
this->warp_tile_iterator_B_.load(pipe_state.warp_loaded_frag_B_);
++this->warp_tile_iterator_B_;
warp_dequantizer_.load(pipe_state.warp_frag_code_scale_,
pipe_state.warp_frag_code_zp_,
pipe_state.warp_frag_super_scale_);
warp_dequantizer_.load(pipe_state.warp_frag_local_scale_);
iterator_B.clear_mask(gemm_k_iterations == (-Base::kStages + 1));
// Load first warp-tile's A fragment from shared memory
this->warp_tile_iterator_A_.set_kgroup_index(0);
this->warp_tile_iterator_A_.load(pipe_state.warp_frag_A_[0]);
this->warp_tile_iterator_A_.load(pipe_state.warp_loaded_frag_A_[0]);
++this->warp_tile_iterator_A_;
// Dequantize B to in register
warp_dequantizer_.dequantize(pipe_state.warp_frag_local_scale_,
pipe_state.warp_frag_code_scale_,
pipe_state.warp_frag_code_zp_,
pipe_state.warp_frag_super_scale_,
pipe_state.warp_loaded_frag_B_,
pipe_state.warp_frag_B_[0],
0,
0);
// Copy dequatized data to shared memory used by mma core.
copy_tiles_and_advance_per_stage_B<false, true>(iterator_B);
if constexpr (Detail::kStagedAccumulation) {
// Load first warp-tile's B fragment from shared memory
this->warp_tile_iterator_B_.set_kgroup_index(0);
this->warp_tile_iterator_B_.load(pipe_state.warp_loaded_frag_B_[0]);
++this->warp_tile_iterator_B_;
// Transform, if necessary, the first warp-tile's shared memory fragments
warp_mma_.transform(
pipe_state.warp_transformed_frag_A_[0],
pipe_state.warp_transformed_frag_B_[0],
pipe_state.warp_loaded_frag_A_[0],
pipe_state.warp_loaded_frag_B_[0]);
if (Detail::kStagedAccumulation) {
pipe_state.tmp_accum_.clear();
}
@@ -753,13 +715,13 @@ public:
accum,
iterator_A,
iterator_B,
mma_quant_args,
tile_dequanter_B,
gemm_k_iterations,
stage);
stage += 1;
}
if constexpr (Detail::kStagedAccumulation) {
if (Detail::kStagedAccumulation) {
plus<FragmentC> plus_accum;
accum = plus_accum(accum, pipe_state.tmp_accum_);
}
@@ -799,12 +761,14 @@ public:
else
{
this->warp_tile_iterator_A_.add_tile_offset({0, ((Base::kStages - 2) * kStageIters)});
this->warp_tile_iterator_B_.add_tile_offset({((Base::kStages - 2) * kStageIters), 0});
//this->warp_tile_iterator_B_.add_tile_offset({((Base::kStages - 2) * kStageIters), 0});
this->warp_tile_iterator_B_.add_tile_offset({(-2 * kStageIters), 0});
}
smem_read_stage_idx_ = smem_write_stage_idx_;
}
/// Perform a threadblock-scoped matrix multiply-accumulate, pre-load B to shared memory.
template <typename TileDequanterB>
CUTLASS_DEVICE
void operator()(
///< problem size of GEMM
@@ -815,13 +779,13 @@ public:
IteratorA iterator_A,
///< iterator over B operand in global memory
IteratorB iterator_B,
///< iterators for extra quant params for B
QuantArguments mma_quant_args,
///< pre-load and dequantize B to shared memory
TileDequanterB tile_dequanter_B,
///< initial value of accumulator
FragmentC const &src_accum) {
// Prologue (start fetching iterations of global fragments into shared memory)
prologue(iterator_A, iterator_B, mma_quant_args, gemm_k_iterations);
prologue(iterator_A, iterator_B, tile_dequanter_B, gemm_k_iterations);
// Wait until we have at least one completed global fetch stage
gmem_wait();
@@ -830,7 +794,7 @@ public:
accum = src_accum;
// Perform the MAC-iterations
gemm_iters(gemm_k_iterations, accum, iterator_A, iterator_B, mma_quant_args);
gemm_iters(gemm_k_iterations, accum, iterator_A, iterator_B, tile_dequanter_B);
}
};

View File

@@ -1,315 +0,0 @@
// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "cutlass/arch/memory_sm80.h"
#include "cutlass/cutlass.h"
#include "cutlass/gemm/gemm.h"
#include "cutlass/matrix_shape.h"
#include "cutlass/trace.h"
namespace cutlass {
namespace gemm {
namespace threadblock {
template <
/// Original data type
typename T,
/// Size of the Gemm problem - concept: gemm::GemmShape<>
typename Shape_,
/// Iterators over super scales in global memory
typename IteratorSuperScale_,
/// Iterators over super scales in shared memory
typename SmemIteratorSuperScale_,
/// Iterators over local scales in global memory
typename IteratorLocalScale_,
/// Iterators over local scales in shared memory
typename SmemIteratorLocalScale_,
/// Iterators over code scales and zps in global memory
typename IteratorCodeScaleZp_,
/// Iterators over code scales and zps in shared memory
typename SmemIteratorCodeScaleZp_,
/// Number of stages,
int Stages_,
/// Group size for quantization
int GroupSize_>
class Wint2ParamsAccessor {
public:
static_assert(platform::is_same<T, half_t>::value || platform::is_same<T, bfloat16_t>::value,
"T must be fp16 or bf16");
using ElementType = T;
using Shape = Shape_;
using IteratorSuperScale = IteratorSuperScale_;
using SmemIteratorSuperScale = SmemIteratorSuperScale_;
using IteratorLocalScale = IteratorLocalScale_;
using SmemIteratorLocalScale = SmemIteratorLocalScale_;
using IteratorCodeScaleZp = IteratorCodeScaleZp_;
using SmemIteratorCodeScaleZp = SmemIteratorCodeScaleZp_;
constexpr static int kStages = Stages_;
constexpr static int kGroupSize = GroupSize_;
using ElementSuperScale = typename IteratorSuperScale::Element;
using LayoutSuperScale = typename IteratorSuperScale::Layout;
/// local_scale uint4 and group-wise
using ElementLocalScale = typename IteratorLocalScale::Element;
using LayoutLocalScale = typename IteratorLocalScale::Layout;
static_assert(platform::is_same<ElementLocalScale, uint4b_t>::value,
"local_scale's type must be uint4b_t.");
using ElementCodeScaleZp = typename IteratorCodeScaleZp::Element;
using LayoutCodeScaleZp = typename IteratorCodeScaleZp::Layout;
/// 2 uint4b_t values are stored in a single uint8_t
constexpr static int kStagesPerLocalScaleLoad = 2 * kGroupSize / Shape::kK;
constexpr static int kLocalScaleRows =
IteratorLocalScale::Shape::kRow * IteratorLocalScale::Shape::kColumn * sizeof_bits<ElementLocalScale>::value / 8 / Shape::kN;
using SmemElement = uint8_t;
constexpr static int kSmemRows =
kLocalScaleRows * kStages + sizeof(ElementSuperScale) + sizeof(ElementCodeScaleZp) * 2;
constexpr static int kSmemColumns = Shape::kN;
using QuantParamsShape = MatrixShape<kSmemRows, kSmemColumns>;
constexpr static int kSuperScaleSmemOffset = 0;
constexpr static int kCodeScaleSmemOffset = kSmemColumns * sizeof(ElementSuperScale);
constexpr static int kCodeZpSmemOffset = kCodeScaleSmemOffset + kSmemColumns * sizeof(ElementCodeScaleZp);
constexpr static int kLocalScaleSmemOffset = kCodeZpSmemOffset + kSmemColumns * sizeof(ElementCodeScaleZp);
/// TensorRef type for loading element from a tensor
using SuperTensorRef = cutlass::TensorRef<ElementSuperScale, LayoutSuperScale>;
using LocalTensorRef = cutlass::TensorRef<ElementLocalScale, LayoutLocalScale>;
using CodeTensorRef = cutlass::TensorRef<ElementCodeScaleZp, LayoutCodeScaleZp>;
struct Arguments {
IteratorSuperScale iterator_super_scale;
IteratorLocalScale iterator_local_scale;
IteratorCodeScaleZp iterator_code_scale;
IteratorCodeScaleZp iterator_code_zp;
int local_scale_pointer_offset;
CUTLASS_DEVICE
Arguments(IteratorSuperScale iterator_super_scale,
IteratorLocalScale iterator_local_scale,
IteratorCodeScaleZp iterator_code_scale,
IteratorCodeScaleZp iterator_code_zp,
int local_scale_pointer_offset)
: iterator_super_scale(iterator_super_scale),
iterator_local_scale(iterator_local_scale),
iterator_code_scale(iterator_code_scale),
iterator_code_zp(iterator_code_zp),
local_scale_pointer_offset(local_scale_pointer_offset) {}
};
private:
//
// Data members
//
/// Begin address of shared memory
uint8_t* smem_pointer_;
/// Iterator to write threadblock-scoped tile of super scale operand to shared memory
SmemIteratorSuperScale smem_iterator_super_scale_;
/// Iterator to write threadblock-scoped tile of local scale operand to shared memory
SmemIteratorLocalScale smem_iterator_local_scale_;
/// Iterator to write threadblock-scoped tile of code scale operand to shared memory
SmemIteratorCodeScaleZp smem_iterator_code_scale_;
/// Iterator to write threadblock-scoped tile of code zp operand to shared memory
SmemIteratorCodeScaleZp smem_iterator_code_zp_;
/// Shared memory write stage index
int smem_write_stage_idx_;
/// Shared memory read stage index
int smem_read_stage_idx_;
CUTLASS_DEVICE
ElementSuperScale* get_super_scale_smem_ptr() {
return reinterpret_cast<ElementSuperScale*>(smem_pointer_ + kSuperScaleSmemOffset);
}
CUTLASS_DEVICE
ElementLocalScale* get_local_scale_smem_ptr() {
return reinterpret_cast<ElementLocalScale*>(smem_pointer_ + kLocalScaleSmemOffset);
}
CUTLASS_DEVICE
ElementCodeScaleZp* get_code_scale_smem_ptr() {
return reinterpret_cast<ElementCodeScaleZp*>(smem_pointer_ + kCodeScaleSmemOffset);
}
CUTLASS_DEVICE
ElementCodeScaleZp* get_code_zp_smem_ptr() {
return reinterpret_cast<ElementCodeScaleZp*>(smem_pointer_ + kCodeZpSmemOffset);
}
public:
/// Construct from tensor references
CUTLASS_DEVICE
Wint2ParamsAccessor(
///< prointer of shared memory
uint8_t* smem_pointer,
///< ID within the threadblock
int thread_idx,
///< ID of warp
int warp_idx,
///< ID of each thread within a warp
int lane_idx)
: smem_pointer_(smem_pointer),
smem_iterator_super_scale_(LayoutSuperScale(IteratorSuperScale::Shape::kColumn),
get_super_scale_smem_ptr(), {1, IteratorSuperScale::Shape::kColumn}, thread_idx),
smem_iterator_local_scale_(LayoutLocalScale(IteratorLocalScale::Shape::kColumn),
get_local_scale_smem_ptr(), {1, IteratorLocalScale::Shape::kColumn}, thread_idx),
smem_iterator_code_scale_(LayoutCodeScaleZp(IteratorCodeScaleZp::Shape::kColumn),
get_code_scale_smem_ptr(), {1, IteratorCodeScaleZp::Shape::kColumn}, thread_idx),
smem_iterator_code_zp_(LayoutCodeScaleZp(IteratorCodeScaleZp::Shape::kColumn),
get_code_zp_smem_ptr(), {1, IteratorCodeScaleZp::Shape::kColumn}, thread_idx),
smem_write_stage_idx_(0),
smem_read_stage_idx_(0) {}
CUTLASS_DEVICE
SuperTensorRef super_scale_ref() {
return {get_super_scale_smem_ptr(), LayoutSuperScale(IteratorSuperScale::Shape::kColumn)};
}
CUTLASS_DEVICE
LocalTensorRef local_scale_ref() {
return {get_local_scale_smem_ptr(), LayoutLocalScale(IteratorLocalScale::Shape::kColumn)};
}
CUTLASS_DEVICE
CodeTensorRef code_scale_ref() {
return {get_code_scale_smem_ptr(), LayoutCodeScaleZp(IteratorCodeScaleZp::Shape::kColumn)};
}
CUTLASS_DEVICE
CodeTensorRef code_zp_ref() {
return {get_code_zp_smem_ptr(), LayoutCodeScaleZp(IteratorCodeScaleZp::Shape::kColumn)};
}
template <bool IsFirstStage>
CUTLASS_DEVICE
void copy_tiles_and_advance_per_stage(Arguments &quant_args, int stage) {
if constexpr (IsFirstStage) {
// Load channel-wise super_scale to shared memory, which only needs to be done once.
typename IteratorSuperScale::Fragment tb_frag_super_scale;
tb_frag_super_scale.clear();
quant_args.iterator_super_scale.load(tb_frag_super_scale);
this->smem_iterator_super_scale_.store(tb_frag_super_scale);
// Load channel-wise code_scale to shared memory, which only needs to be done once.
typename IteratorCodeScaleZp::Fragment tb_frag_code_scale;
tb_frag_code_scale.clear();
quant_args.iterator_code_scale.load(tb_frag_code_scale);
this->smem_iterator_code_scale_.store(tb_frag_code_scale);
// Load channel-wise code_zp to shared memory, which only needs to be done once.
typename IteratorCodeScaleZp::Fragment tb_frag_code_zp;
tb_frag_code_zp.clear();
quant_args.iterator_code_zp.load(tb_frag_code_zp);
this->smem_iterator_code_zp_.store(tb_frag_code_zp);
}
if ((stage % kStagesPerLocalScaleLoad) == 0) {
// Load group-wise local_scale to shared memory, which only needs to be done at each stage.
// Since 2 uint4b_t values of local_scale are saved in a single uint8_t, local_scale needs to be loaded once every two stages.
using AccessType = typename IteratorLocalScale::AccessType;
cutlass::arch::CacheOperation::Kind const kCacheOp = (sizeof_bits<AccessType>::value == 128)
? cutlass::arch::CacheOperation::Global : cutlass::arch::CacheOperation::Always;
quant_args.iterator_local_scale.set_iteration_index(0);
this->smem_iterator_local_scale_.set_iteration_index(0);
// Async Copy for local_scale
CUTLASS_PRAGMA_UNROLL
for (int j = 0; j < IteratorLocalScale::ThreadMap::Iterations::kCount; ++j) {
AccessType *dst_ptr =
reinterpret_cast<AccessType *>(this->smem_iterator_local_scale_.get());
CUTLASS_PRAGMA_UNROLL
for (int v = 0; v < IteratorLocalScale::kAccessesPerVector; ++v) {
auto gmem_ptr = quant_args.iterator_local_scale.get();
int const kSrcBytes =
sizeof_bits<typename IteratorLocalScale::Element>::value *
IteratorLocalScale::ThreadMap::kElementsPerAccess /
IteratorLocalScale::kAccessesPerVector / 8;
cutlass::arch::cp_async<kSrcBytes, kCacheOp>(
dst_ptr + v, gmem_ptr, quant_args.iterator_local_scale.valid());
}
++quant_args.iterator_local_scale;
}
++this->smem_iterator_local_scale_;
}
}
CUTLASS_DEVICE
void advance_smem_write_stage(Arguments &quant_args) {
if (smem_write_stage_idx_ % kStagesPerLocalScaleLoad == 0) {
// Advance global iterators
quant_args.iterator_local_scale.add_pointer_offset(quant_args.local_scale_pointer_offset);
// Advance shared iterators
int smem_pointer_offset = IteratorLocalScale::Shape::kRow * IteratorLocalScale::Shape::kColumn;
smem_iterator_local_scale_.add_pointer_offset(smem_pointer_offset);
}
// Increment shared memory write stage index
++smem_write_stage_idx_;
if (smem_write_stage_idx_ == kStagesPerLocalScaleLoad * kStages) {
// Wrap back around to the 'start' of the circular buffer in shared memory
int pointer_offset = - kStages * IteratorLocalScale::Shape::kRow * IteratorLocalScale::Shape::kColumn;
smem_iterator_local_scale_.add_pointer_offset(pointer_offset);
smem_write_stage_idx_ = 0;
}
}
CUTLASS_DEVICE
int advance_smem_read_stage() {
int byte_offset = 0;
++smem_read_stage_idx_;
if (smem_read_stage_idx_ % kStagesPerLocalScaleLoad == 0) {
byte_offset = kLocalScaleRows * kSmemColumns;
}
if (smem_read_stage_idx_ == kStagesPerLocalScaleLoad * kStages) {
smem_read_stage_idx_ = 0;
byte_offset = - (kStages - 1) * kLocalScaleRows * kSmemColumns;
}
return byte_offset;
}
CUTLASS_DEVICE
int clear_mask(Arguments &quant_args, bool cond) {
quant_args.iterator_local_scale.clear_mask(cond);
}
};
} // namespace threadblock
} // namespace gemm
} // namespace cutlass

View File

@@ -0,0 +1,130 @@
// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "cutlass/gemm_coord.h"
#include "cutlass/trace.h"
#include "cutlass_extensions/gemm/threadblock/wint2x_unzip.h"
namespace cutlass {
namespace gemm {
namespace threadblock {
template <typename ElementT, typename ScaleElementT, int Rows, int Columns,
int Stages, int NumThreads, WintQuantMethod Method>
struct TileDequanter {
using WeightQuantTraits = WintQuantTraits<ElementT, Method>;
using MmaElementT = typename WeightQuantTraits::MmaWeightType;
using QuantArguments = typename WeightQuantTraits::Arguments;
using UnzipAndDequantFunctor =
UnzipAndDequantFunctor<MmaElementT, Method, Rows, Columns, NumThreads>;
static constexpr bool kUseSharedMemory = true;
static constexpr int kRows = Rows;
static constexpr int kColumns = Columns;
static constexpr int kStages = Stages;
MmaElementT *out_smem_ptr{nullptr};
char *pointer{nullptr};
int64_t ldm{0};
cutlass::MatrixCoord tb_offset;
cutlass::MatrixCoord extent;
ScaleElementT *super_scale_ptr{nullptr};
cutlass::MatrixCoord tb_offset_scale;
QuantArguments quant_args;
int64_t block_start_rows[kStages];
bool need_preload{true};
UnzipAndDequantFunctor unzip_functor;
CUTLASS_DEVICE
TileDequanter(MmaElementT *out_smem_ptr, char *pointer, int64_t ldm,
const cutlass::MatrixCoord &extent,
const cutlass::MatrixCoord &tb_offset,
ScaleElementT *super_scale_ptr,
const cutlass::MatrixCoord &tb_offset_scale,
const QuantArguments &quant_args)
: out_smem_ptr(out_smem_ptr), pointer(pointer), ldm(ldm), extent(extent),
tb_offset(tb_offset), super_scale_ptr(super_scale_ptr),
tb_offset_scale(tb_offset_scale), quant_args(quant_args) {}
CUTLASS_DEVICE
MmaElementT *GetOutPtr() { return out_smem_ptr; }
CUTLASS_DEVICE
void AddTileOffset(const cutlass::MatrixCoord &tile_offset) {
tb_offset.row() += tile_offset.row() * kRows;
tb_offset.column() += tile_offset.column() * kColumns;
tb_offset_scale.column() += tile_offset.column() * kColumns;
}
CUTLASS_DEVICE
void Load(uint8_t *zipped_smem_ptr, uint8_t *column_wise_smem_ptr, int stage) {
int zipped_row = WeightQuantTraits::CaclPackedDim(tb_offset.row());
if (tb_offset.row() >= extent.row() ||
tb_offset.column() >= extent.column()) {
return;
}
block_start_rows[stage % kStages] = tb_offset.row();
using ZippedT = typename WeightQuantTraits::WeightType;
ZippedT *in_ptr = reinterpret_cast<ZippedT *>(pointer) + zipped_row * ldm +
tb_offset.column();
ScaleElementT *scale_ptr = super_scale_ptr + tb_offset_scale.column();
if constexpr (Method == WintQuantMethod::kWeightOnlyInt2) {
const uint8_t *local_scale_ptr = quant_args.local_scale_ptr +
(tb_offset.row() / 128) * ldm +
tb_offset_scale.column();
const float *code_scale_ptr =
quant_args.code_scale_ptr + tb_offset_scale.column();
const float *code_zp_ptr =
quant_args.code_zp_ptr + tb_offset_scale.column();
typename UnzipAndDequantFunctor::Arguments args(zipped_smem_ptr, column_wise_smem_ptr);
unzip_functor.LoadAsync(in_ptr, local_scale_ptr, code_scale_ptr, code_zp_ptr,
scale_ptr, &args, ldm, need_preload);
need_preload = false;
} else {
// CUTLASS_TRACE_DEVICE("Not Supported!");
}
}
CUTLASS_DEVICE
void UnpackAndDequant(uint8_t *zipped_smem_ptr, uint8_t *column_wise_smem_ptr, int stage) {
int64_t block_start_row = block_start_rows[stage % kStages];
if (block_start_row >= extent.row()) {
return;
}
if constexpr (Method == WintQuantMethod::kWeightOnlyInt2) {
typename UnzipAndDequantFunctor::Arguments args(zipped_smem_ptr, column_wise_smem_ptr);
unzip_functor.ComputeVectorized(args, out_smem_ptr, block_start_row);
} else {
// CUTLASS_TRACE_DEVICE("Not Supported!");
}
}
};
} // namespace threadblock
} // namespace gemm
} // namespace cutlass

View File

@@ -41,9 +41,12 @@
#include "cutlass_extensions/arch/mma.h"
#include "cutlass_extensions/gemm/warp/mma_tensorop_compute_B_with_f16.h"
namespace cutlass {
namespace gemm {
namespace warp {
namespace cutlass
{
namespace gemm
{
namespace warp
{
/////////////////////////////////////////////////////////////////////////////////////////////////
@@ -78,7 +81,7 @@ private:
// Shape for computing the FP16s
using ComputeInstructionShape = InstructionShape_;
// Chosen so we get K=16 for int8, K=32 for int4, K=64 for int2.
// Chosen so we get K=16 for int8 and K=32 for int4.
static constexpr int LoadInstructionK = 128 / sizeof_bits<ElementB>::value;
// Shape for loading the narrow data type from shared memory

View File

@@ -58,12 +58,15 @@
/////////////////////////////////////////////////////////////////////////////////////////////////
namespace cutlass {
namespace gemm {
namespace warp {
namespace cutlass
{
namespace gemm
{
namespace warp
{
/////////////////////////////////////////////////////////////////////////////////////////////////
/// Structure to compute the matrix product targeting Tensor Cores, for the case when A is floating point and B is quantized integer.
/// Structure to compute the matrix product targeting CUDA cores and SIMT math instructions.
template <
/// Size of the Gemm problem - concept: gemm::GemmShape<>
typename Shape_,
@@ -294,235 +297,6 @@ public:
}
};
/////////////////////////////////////////////////////////////////////////////////////////////////
/// Structure to compute the matrix product targeting Tensor Cores, for the case when A is floating point and B is quantized integer.
/// Specialization for B of uint2b_t.
template <
/// Size of the Gemm problem - concept: gemm::GemmShape<>
typename Shape_,
/// Data type of A elements
typename ElementA_,
/// Layout of A matrix (concept: MatrixLayout)
typename LayoutA_,
/// Layout of B matrix (concept: MatrixLayout)
typename LayoutB_,
/// Element type of C matrix
typename ElementC_,
/// Layout of C matrix (concept: MatrixLayout)
typename LayoutC_,
/// Policy describing warp-level MmaTensorOp (concept: MmaTensorOp policy)
typename Policy_,
/// Instruction shape to override shared memory iterators with
typename SharedMemoryInstructionShape_,
/// Number of partitions along K dimension
int PartitionsK_,
/// Store the accumulators in row major or column major. Row major is used
/// when output layout is interleaved.
bool AccumulatorsInRowMajor>
class MmaTensorOpComputeBWithF16<
Shape_,
ElementA_,
LayoutA_,
uint2b_t,
LayoutB_,
ElementC_,
LayoutC_,
Policy_,
SharedMemoryInstructionShape_,
PartitionsK_,
AccumulatorsInRowMajor>
{
public:
/// Shape of warp-level matrix operation (concept: GemmShape)
using Shape = Shape_;
/// Data type of multiplicand A
using ElementA = ElementA_;
/// Layout of multiplicand A
using LayoutA = LayoutA_;
/// Data type of multiplicand B
using ElementB = uint2b_t;
/// Layout of multiplicand B
using LayoutB = LayoutB_;
/// Data type of accumulator matrix C
using ElementC = ElementC_;
/// Layout of accumulator matrix C
using LayoutC = LayoutC_;
/// Shape of the warp in units of thread (concept: MmaLanePolicySimt)
using Policy = Policy_;
/// Underlying matrix multiply operator (concept: arch::Mma)
using ArchMmaOperator = typename Policy::Operator;
/// Indicates math operator
using MathOperator = typename ArchMmaOperator::Operator;
/// Architecture tag from underlying instruction
using ArchTag = typename ArchMmaOperator::ArchTag;
static_assert((platform::is_same<typename ArchMmaOperator::ElementA, half_t>::value
&& platform::is_same<typename ArchMmaOperator::ElementB, half_t>::value)
|| (platform::is_same<typename ArchMmaOperator::ElementA, bfloat16_t>::value
&& platform::is_same<typename ArchMmaOperator::ElementB, bfloat16_t>::value
&& ArchTag::kMinComputeCapability >= 80),
"MmaTensorOpCvtBToA only supports underlying HMMA/QMMA");
static_assert(platform::is_same<ElementA, half_t>::value
|| (platform::is_same<ElementA, bfloat16_t>::value && ArchTag::kMinComputeCapability >= 80),
"MmaTensorOpCvtBToA only supports Fp16 A or Bf16 A on Ampere+");
/// Indicates class of matrix operator
using OperatorClass = arch::OpClassTensorOp;
/// Shape of underlying instruction
using InstructionShape = typename ArchMmaOperator::Shape;
/// Instruction shape to override shared memory iterators with
using SharedMemoryInstructionShape = SharedMemoryInstructionShape_;
static_assert(
SharedMemoryInstructionShape::kM == InstructionShape::kM, "M dimension of compute instruction must match load");
static_assert(
SharedMemoryInstructionShape::kN == InstructionShape::kN, "N dimension of compute instruction must match load");
static constexpr int kExpansionFactor = SharedMemoryInstructionShape::kK / InstructionShape::kK;
static_assert(!(Shape::kK % SharedMemoryInstructionShape::kK), "");
/// Complex transform on A operand
static ComplexTransform const kTransformA = ComplexTransform::kNone;
/// Complex transform on B operand
static ComplexTransform const kTransformB = ComplexTransform::kNone;
/// Number of threads participating in warp-level matrix product
static int const kThreadCount = 32;
/// Number of partitions along K dimension
static int const kPartitionsK = PartitionsK_;
public:
/// Iterates over the A operand in memory
using IteratorA
= MmaTensorOpMultiplicandTileIterator<MatrixShape<Shape::kM, Shape::kK>, Operand::kA, ElementA, LayoutA,
MatrixShape<InstructionShape::kM, InstructionShape::kK>, Policy::OpDelta::kRow, kThreadCount, kPartitionsK>;
/// Storage for A tile
using FragmentA = typename IteratorA::Fragment;
/// Storage for transformed A tile
using TransformedFragmentA = Array<typename ArchMmaOperator::ElementA, FragmentA::kElements>;
/// Iterates over the B operand in memory
using IteratorB = MmaTensorOpMultiplicandTileIterator<MatrixShape<Shape::kK, Shape::kN>, Operand::kB, ElementB,
LayoutB, MatrixShape<SharedMemoryInstructionShape::kK, InstructionShape::kN>, Policy::OpDelta::kRow,
kThreadCount, kPartitionsK>;
/// Storage for B tile
using FragmentB = typename IteratorB::Fragment;
/// Storage for transformed B tile
using TransformedFragmentB =
Array<typename ArchMmaOperator::ElementB, FragmentB::kElements / kExpansionFactor>;
/// Iterates over the C operand in memory
using IteratorC = MmaTensorOpAccumulatorTileIterator<MatrixShape<Shape::kM, Shape::kN>, ElementC, LayoutC,
typename ArchMmaOperator::Shape, typename Policy::OpDelta>;
/// Storage for C tile
using FragmentC = typename IteratorC::Fragment;
/// Number of mma operations performed
using MmaIterations = MatrixShape<(Shape::kM + ArchMmaOperator::Shape::kM - 1) / ArchMmaOperator::Shape::kM,
(Shape::kN + ArchMmaOperator::Shape::kN - 1) / ArchMmaOperator::Shape::kN>;
public:
/// Underlying matrix multiply operator (concept: arch::Mma)
ArchMmaOperator mma;
public:
//
// Methods
//
/// Ctor
CUTLASS_DEVICE
MmaTensorOpComputeBWithF16() {}
/// Performs a warp-level matrix multiply-accumulate operation
CUTLASS_DEVICE
void operator()(FragmentC& D, TransformedFragmentA const& A, TransformedFragmentB const& B, FragmentC const& C) const
{
using MmaOperandA = typename ArchMmaOperator::FragmentA;
using MmaOperandB = typename ArchMmaOperator::FragmentB;
using MmaOperandC = typename ArchMmaOperator::FragmentC;
D = C;
MmaOperandA const* ptr_A = reinterpret_cast<MmaOperandA const*>(&A);
MmaOperandB const* ptr_B = reinterpret_cast<MmaOperandB const*>(&B);
MmaOperandC* ptr_D = reinterpret_cast<MmaOperandC*>(&D);
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ < 800)
// Serpentine visitation order maximizing reuse of Rb
CUTLASS_PRAGMA_UNROLL
for (int n = 0; n < MmaIterations::kColumn; ++n)
{
CUTLASS_PRAGMA_UNROLL
for (int m = 0; m < MmaIterations::kRow; ++m)
{
int m_serpentine = ((n % 2) ? (MmaIterations::kRow - 1 - m) : m);
if (AccumulatorsInRowMajor)
{ // matrix B is reordered
mma(ptr_D[n + m_serpentine * MmaIterations::kColumn], ptr_A[m_serpentine], ptr_B[n],
ptr_D[n + m_serpentine * MmaIterations::kColumn]);
}
else
{
mma(ptr_D[m_serpentine + n * MmaIterations::kRow], ptr_A[m_serpentine], ptr_B[n],
ptr_D[m_serpentine + n * MmaIterations::kRow]);
}
}
}
#elif defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800)
// Serpentine visitation order maximizing reuse of Ra
CUTLASS_PRAGMA_UNROLL
for (int m = 0; m < MmaIterations::kRow; ++m)
{
CUTLASS_PRAGMA_UNROLL
for (int n = 0; n < MmaIterations::kColumn; ++n)
{
int n_serpentine = ((m % 2) ? (MmaIterations::kColumn - 1 - n) : n);
if (AccumulatorsInRowMajor)
{ // matrix B is reordered
mma(ptr_D[n_serpentine + m * MmaIterations::kColumn], ptr_A[m], ptr_B[n_serpentine],
ptr_D[n_serpentine + m * MmaIterations::kColumn]);
}
else
{
mma(ptr_D[m + n_serpentine * MmaIterations::kRow], ptr_A[m], ptr_B[n_serpentine],
ptr_D[m + n_serpentine * MmaIterations::kRow]);
}
}
}
#else
assert(0);
#endif
}
};
/////////////////////////////////////////////////////////////////////////////////////////////////
} // namespace warp

View File

@@ -1,442 +0,0 @@
/***************************************************************************************************
* Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights
*reserved. SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice,
*this list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
*ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
*LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
*CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
*SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
*INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
*CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
*ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
*POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
/*! \file
\brief Defines iterators used by warp-level matrix multiply operations
targeting Tensor Cores.
*/
#pragma once
#include "cutlass/cutlass.h"
#include "cutlass/array.h"
#include "cutlass/matrix_shape.h"
#include "cutlass/numeric_types.h"
#include "cutlass/tensor_ref.h"
#include "cutlass/arch/arch.h"
#include "cutlass/arch/memory_sm75.h"
#include "cutlass/gemm/gemm.h"
#include "cutlass/layout/matrix.h"
#include "cutlass/layout/pitch_linear.h"
#include "cutlass/layout/tensor.h"
#include "cutlass/functional.h"
#include "cutlass/platform/platform.h"
#include "cutlass_extensions/interleaved_numeric_conversion.h"
namespace cutlass {
namespace gemm {
namespace warp {
namespace detail {
template <typename T>
struct DataTypeTraits;
template <>
struct DataTypeTraits<bfloat16_t> {
using Type = __nv_bfloat16;
using DualType = __nv_bfloat162;
};
template <>
struct DataTypeTraits<half_t> {
using Type = __half;
using DualType = __half2;
};
template <typename T, int N, typename Enable = void>
struct LocalScaleConverter {
using FragmentSource = Array<uint8_t, N>;
using FragmentResult = Array<T, N>;
CUTLASS_DEVICE
static void Apply(FragmentSource const& local_scale_frag,
FragmentResult const& super_scale_frag,
FragmentResult& scale_frag,
int shift_bit) {
constexpr uint32_t kLocalScaleMask = 0xf;
CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < N; ++i) {
int32_t shifted_value = (static_cast<int32_t>(local_scale_frag[i]) >> shift_bit) & kLocalScaleMask;
scale_frag[i] = static_cast<T>(shifted_value) * super_scale_frag[i];
}
}
};
template <int N>
struct LocalScaleConverter<half_t, N, typename platform::enable_if<N % 4 == 0>::type> {
using FragmentSource = Array<uint8_t, N>;
using FragmentResult = Array<half_t, N>;
CUTLASS_DEVICE
static void Apply(FragmentSource const& local_scale_frag,
FragmentResult const& super_scale_frag,
FragmentResult& scale_frag,
int shift_bit) {
constexpr uint32_t immLut = (0xf0 & 0xcc) | 0xaa;
constexpr uint32_t MASK = 0x000f000f;
// 2^10 = 1024
constexpr uint32_t I4s_TO_FP16s_MAGIC_NUM = 0x64006400;
// -2^10 = -1024
constexpr uint32_t FP16_BIAS = 0xE400E400;
// 1.0
constexpr uint32_t FP16_ONE = 0x3C003C00;
__half2* scale_ptr = reinterpret_cast<__half2 *>(&scale_frag);
__half2 const* super_scale_ptr = reinterpret_cast<__half2 const*>(&super_scale_frag);
uint32_t const* local_scale_ptr = reinterpret_cast<uint32_t const*>(&local_scale_frag);
CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < N / 4; ++i) {
int i4s = local_scale_ptr[i] >> shift_bit;
// unpack: 0, 1
int32_t low = __byte_perm(i4s, i4s, 0xF1F0);
int32_t unpack0 = lop3<immLut>(low, MASK, I4s_TO_FP16s_MAGIC_NUM);
// unpack: 2, 3
int32_t high = __byte_perm(i4s, i4s, 0xF3F2);
int32_t unpack1 = lop3<immLut>(high, MASK, I4s_TO_FP16s_MAGIC_NUM);
__half2 scale0 = __hfma2(*reinterpret_cast<__half2*>(&unpack0),
*reinterpret_cast<const __half2*>(&FP16_ONE),
*reinterpret_cast<const __half2*>(&FP16_BIAS));
__half2 scale1 = __hfma2(*reinterpret_cast<__half2*>(&unpack1),
*reinterpret_cast<const __half2*>(&FP16_ONE),
*reinterpret_cast<const __half2*>(&FP16_BIAS));
scale_ptr[2 * i] = __hmul2(scale0, super_scale_ptr[2 * i]);
scale_ptr[2 * i + 1] = __hmul2(scale1, super_scale_ptr[2 * i + 1]);
}
}
};
template <int N>
struct LocalScaleConverter<bfloat16_t, N, typename platform::enable_if<N % 4 == 0>::type> {
using FragmentSource = Array<uint8_t, N>;
using FragmentResult = Array<bfloat16_t, N>;
CUTLASS_DEVICE
static void Apply(FragmentSource const& local_scale_frag,
FragmentResult const& super_scale_frag,
FragmentResult& scale_frag,
int shift_bit) {
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800) && defined(ENABLE_BF16))
constexpr uint32_t immLut = (0xF0 & 0xCC) | 0xAA;
constexpr uint32_t MASK = 0x000F000F;
constexpr uint32_t I4s_TO_BF16s_MAGIC_NUM = 0x43004300;
constexpr uint32_t BF16_BIAS = 0xC300C300;
constexpr uint32_t BF16_ONE = 0x3F803F80;
__nv_bfloat162* scale_ptr = reinterpret_cast<__nv_bfloat162 *>(&scale_frag);
__nv_bfloat162 const* super_scale_ptr = reinterpret_cast<__nv_bfloat162 const*>(&super_scale_frag);
uint32_t const* local_scale_ptr = reinterpret_cast<uint32_t const*>(&local_scale_frag);
CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < N / 4; ++i) {
int i4s = local_scale_ptr[i] >> shift_bit;
// unpack: 0, 1
int32_t low = __byte_perm(i4s, i4s, 0xF1F0);
int32_t unpack0 = lop3<immLut>(low, MASK, I4s_TO_BF16s_MAGIC_NUM);
// unpack: 2, 3
int32_t high = __byte_perm(i4s, i4s, 0xF3F2);
int32_t unpack1 = lop3<immLut>(high, MASK, I4s_TO_BF16s_MAGIC_NUM);
nv_bfloat162 scale0 = __hfma2(*reinterpret_cast<nv_bfloat162*>(&unpack0),
*reinterpret_cast<const nv_bfloat162*>(&BF16_ONE),
*reinterpret_cast<const nv_bfloat162*>(&BF16_BIAS));
nv_bfloat162 scale1 = __hfma2(*reinterpret_cast<nv_bfloat162*>(&unpack1),
*reinterpret_cast<const nv_bfloat162*>(&BF16_ONE),
*reinterpret_cast<const nv_bfloat162*>(&BF16_BIAS));
scale_ptr[2 * i] = __hmul2(scale0, super_scale_ptr[2 * i]);
scale_ptr[2 * i + 1] = __hmul2(scale1, super_scale_ptr[2 * i + 1]);
}
#else
// Slow path not implemented here on purpose. If we need to do HMMA on older arch, scale conversion should
// happen before scales are stored to shared memory and we should use the fp16 dequantizer. This will avoid
// numerous conversion instructions in GEMM main loop.
arch::device_breakpoint();
#endif
}
};
} // namespace detail
////////////////////////////////////////////////////////////////////////////////
template <
/// Matrix multiply operator
typename MmaOperator_,
/// Size of the matrix to load (concept: MatrixShape)
typename Shape_,
/// Operand identity
Operand Operand,
/// Data type of Scale elements
typename ElementOperand_,
/// Layout of operand
typename Layout_,
/// Group size for quantization
int GroupSize_,
///
typename Enable = void>
class MmaTensorOpWin2xDequantizer {
//static_assert(false, "Not Supported!");
};
////////////////////////////////////////////////////////////////////////////////
// Bfloat specialization for Ampere
template <
/// Underlying matrix multiply operator (concept: MmaTensorOp)
typename MmaOperator_,
/// Shape of the warp level matrix multiply (concept: GemmShape)
typename Shape_,
/// Data type of Scale elements
typename ElementOperand_,
/// Group size for quantization
int GroupSize_>
class MmaTensorOpWin2xDequantizer<
MmaOperator_,
Shape_,
Operand::kB,
ElementOperand_,
layout::RowMajor,
GroupSize_>
//typename platform::enable_if<MmaOperator_::ArchTag::kMinComputeCapability >= 80
// && platform::is_same<typename MmaOperator_::ArchMmaOperator::LayoutB, layout::ColumnMajor>::value>::type>
{
public:
static_assert(platform::is_same<ElementOperand_, half_t>::value || platform::is_same<ElementOperand_, bfloat16_t>::value,
"T must be fp16 or bf16");
/// Mma Operator
using MmaOperator = MmaOperator_;
// The architecture specific mma ooperator being used
using ArchMmaOperator = typename MmaOperator::ArchMmaOperator;
// Mma Instruction Shape
using InstructionShape = typename ArchMmaOperator::Shape;
/// Warp mma shape
using Shape = Shape_;
/// Type of mma operand
using ElementOperand = ElementOperand_;
/// Layout of the scales in shared memory
using Layout = layout::RowMajor;
/// Group size for quantization
static constexpr int kGroupSize = GroupSize_;
/// Type of input
using ElementB = typename MmaOperator::FragmentB::Element;
static_assert(platform::is_same<ElementB, uint2b_t>::value, "ElementB must be uint2b_t");
/// Type of the scales
using ElementLocalScale = uint4b_t;
using ElementSuperScale = ElementOperand;
using ElementCodeScaleZp = float;
// Fragment to hold scale data to apply to B before mma
// We need 1 fp16 per matrix iteration in the N dimension
static constexpr int kWarpIterationsAlongN = MmaOperator::MmaIterations::kColumn;
// use uint8_t to save 2 4-bits local scales
using FragmentLocalScale = Array<uint8_t, kWarpIterationsAlongN>;
using FragmentSuperScale = Array<ElementSuperScale, kWarpIterationsAlongN>;
using FragmentCodeScaleZp = Array<ElementCodeScaleZp, kWarpIterationsAlongN>;
/// Fragment to hold B data before Mma
using FragmentInput = Array<ElementB, MmaOperator::FragmentB::kElements>;
// This is the ratio of the load instruction vs the compute instruction.
static constexpr int kExpansionFactor = MmaOperator::IteratorB::InstructionShape::kRow / InstructionShape::kK;
static constexpr int kNumPacks = sizeof_bits<uint8_t>::value / sizeof_bits<ElementB>::value;
static constexpr int kUnpackFactor = MmaOperator::FragmentB::kElements / (kWarpIterationsAlongN * kNumPacks);
static constexpr int kUnpackInterval = kExpansionFactor / kUnpackFactor;
/// Unpack 4 uint2b_t values compreseed in a uint8_t to floating points.
using Uint2Converter = FastInterleavedAndBiasedNumericArrayConverter<
ElementOperand, ElementB, MmaOperator::FragmentB::kElements / kUnpackFactor>;
using FragmentInputUnpack = typename Uint2Converter::result_type;
/// Fragment to hold internal scales before Mma
using FragmentScale = Array<ElementOperand, FragmentLocalScale::kElements>;
/// Fragment of dequantized B
using FragmentOutput = Array<ElementOperand, MmaOperator::FragmentB::kElements / kExpansionFactor>;
/// TensorRef type for loading element from a tensor
using SuperTensorRef = cutlass::TensorRef<ElementSuperScale, Layout>;
using LocalTensorRef = cutlass::TensorRef<ElementLocalScale, Layout>;
using CodeTensorRef = cutlass::TensorRef<ElementCodeScaleZp, Layout>;
private:
//
// Data members
//
uint8_t* pointer_local_scale_;
ElementCodeScaleZp* pointer_code_scale_;
ElementCodeScaleZp* pointer_code_zp_;
ElementSuperScale* pointer_super_scale_;
//FragmentInputUnpack unpacked_frag_;
FragmentScale scale_frag_;
public:
CUTLASS_DEVICE
MmaTensorOpWin2xDequantizer(SuperTensorRef smem_super_scale,
LocalTensorRef smem_local_scale,
CodeTensorRef smem_code_scale,
CodeTensorRef smem_code_zp,
int warp_idx_n,
int lane_idx) {
int warp_offset = warp_idx_n * Shape::kN;
int quad = lane_idx / 4;
int thread_offset = warp_offset + quad;
pointer_super_scale_ = smem_super_scale.data() + thread_offset;
pointer_code_scale_ = smem_code_scale.data() + thread_offset;
pointer_code_zp_ = smem_code_zp.data() + thread_offset;
pointer_local_scale_ = reinterpret_cast<uint8_t *>(smem_local_scale.data()) + thread_offset;
}
/// Channel-wise params, need to load just once
CUTLASS_DEVICE
void load(FragmentCodeScaleZp& code_scale_frag,
FragmentCodeScaleZp& code_zp_frag,
FragmentSuperScale& super_scale_frag) {
CUTLASS_PRAGMA_UNROLL
for (int mma_n_iter = 0; mma_n_iter < kWarpIterationsAlongN; ++mma_n_iter) {
super_scale_frag[mma_n_iter] = pointer_super_scale_[mma_n_iter * InstructionShape::kN]; // bank conflict
code_scale_frag[mma_n_iter] = pointer_code_scale_[mma_n_iter * InstructionShape::kN];
code_zp_frag[mma_n_iter] = pointer_code_zp_[mma_n_iter * InstructionShape::kN];
}
}
/// Group-wise params, need to load multiple times
CUTLASS_DEVICE
void load(FragmentLocalScale& local_scale_frag) {
CUTLASS_PRAGMA_UNROLL
for (int mma_n_iter = 0; mma_n_iter < kWarpIterationsAlongN; ++mma_n_iter) {
local_scale_frag[mma_n_iter] = pointer_local_scale_[mma_n_iter * InstructionShape::kN]; // bank conflict
}
}
CUTLASS_DEVICE
void dequantize(const FragmentLocalScale& local_scale_frag,
const FragmentCodeScaleZp& code_scale_frag,
const FragmentCodeScaleZp& code_zp_frag,
const FragmentSuperScale& super_scale_frag,
const FragmentInput& input_frag,
FragmentOutput& output_frag,
int tb_offset_k,
int warp_k_compute_offset) {
if constexpr (kUnpackInterval != 1) {
// unsupport now
arch::device_breakpoint();
}
typename Uint2Converter::source_type source_frag;
int in_offset = warp_k_compute_offset * kUnpackInterval;
uint8_t const* ptr_input = reinterpret_cast<uint8_t const*>(&input_frag);
uint8_t* ptr_source = reinterpret_cast<uint8_t *>(&source_frag);
CUTLASS_PRAGMA_UNROLL
for (int mma_n_iter = 0; mma_n_iter < kWarpIterationsAlongN; ++mma_n_iter) {
ptr_source[mma_n_iter] = ptr_input[mma_n_iter * kUnpackFactor + in_offset];
}
FragmentInputUnpack unpacked_frag = Uint2Converter::convert(source_frag, code_scale_frag, code_zp_frag);
// dequantize local_scale
if (warp_k_compute_offset == 0) {
using LocalScaleConverter = detail::LocalScaleConverter<ElementOperand, FragmentLocalScale::kElements>;
// special for TileRows = 64
int local_scale_shift = (((tb_offset_k / kGroupSize) + 1) & 1) * 4;
LocalScaleConverter::Apply(local_scale_frag, super_scale_frag, scale_frag_, local_scale_shift);
}
// unscale
// After applying LOP3 optimizations for performance, the B operand requires data rearrangement.
// reorder: [0, 4, 1, 5, 2, 6, 3, 7, 8, 12, 9, 13, 10, 14, 11, 15]
const int kWarpIterationsAlongK = FragmentOutput::kElements / kWarpIterationsAlongN;
using Type = typename detail::DataTypeTraits<ElementOperand>::Type;
using DualType = typename detail::DataTypeTraits<ElementOperand>::DualType;
Type* output_ptr = reinterpret_cast<Type *>(&output_frag);
DualType const* unpacked_ptr = reinterpret_cast<DualType const*>(&unpacked_frag);
DualType const* scale_ptr = reinterpret_cast<DualType const*>(&scale_frag_);
CUTLASS_PRAGMA_UNROLL
for (int mma_n_iter = 0; mma_n_iter < kWarpIterationsAlongN; mma_n_iter += 2) {
int mapped_idx_base = (mma_n_iter / 2) * kWarpIterationsAlongK;
DualType scalex2 = scale_ptr[mma_n_iter / 2];
CUTLASS_PRAGMA_UNROLL
for (int mma_k_iter = 0; mma_k_iter < kWarpIterationsAlongK; ++mma_k_iter) {
DualType unpacked_valuex2 = unpacked_ptr[mapped_idx_base + mma_k_iter];
DualType scaled_value = __hmul2(unpacked_valuex2, scalex2);
output_ptr[mma_n_iter * kWarpIterationsAlongK + mma_k_iter] = scaled_value.x;
output_ptr[(mma_n_iter + 1) * kWarpIterationsAlongK + mma_k_iter] = scaled_value.y;
}
}
}
/// Add an offset to pointer in units of elements.
/// Only group-wise params needs.
CUTLASS_DEVICE
void add_pointer_offset(int64_t const& offset) {
pointer_local_scale_ += offset;
}
};
////////////////////////////////////////////////////////////////////////////////
} // namespace warp
} // namespace gemm
} // namespace cutlass
////////////////////////////////////////////////////////////////////////////////

View File

@@ -39,25 +39,18 @@
#include "cutlass/array.h"
#include "cutlass/half.h"
#include "cutlass/numeric_types.h"
#include "cutlass/trace.h"
namespace cutlass {
template <int lut>
__device__ inline int lop3(int a, int b, int c) {
int res;
asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n"
: "=r"(res)
: "r"(a), "r"(b), "r"(c), "n"(lut));
return res;
}
namespace cutlass
{
// This converter is meant to be used with data interleaved in a 32-bit register where the even elements are in the low
// bits and the odd elemeents are in the high bits of the register. In addition, it assumes elements were originally
// signed and had a bias of 2**(b-1) added (where b is the number of bits in the type) to make all numbers unsigned.
// This converter will uninterleave the data and subtract the bias while converting to the result type.
template <typename T, typename S, int N>
struct FastInterleavedAndBiasedNumericArrayConverter;
struct FastInterleavedAndBiasedNumericArrayConverter
{
};
template <>
struct FastInterleavedAndBiasedNumericArrayConverter<half_t, uint8_t, 4>
@@ -447,329 +440,6 @@ struct FastInterleavedAndBiasedNumericArrayConverter<bfloat16_t, uint4b_t, N>
}
};
template <>
struct FastInterleavedAndBiasedNumericArrayConverter<half_t, uint2b_t, 16>
{
using result_type = Array<half_t, 16>;
using source_type = Array<uint2b_t, 16>;
using ScaleComputeT = float;
using code_type = Array<ScaleComputeT, 4>;
CUTLASS_DEVICE
static result_type convert(source_type const& source, ScaleComputeT code_scale, ScaleComputeT code_zp)
{
uint32_t const i8s = reinterpret_cast<uint32_t const&>(source);
// 2^23 = 8388608
static constexpr uint32_t FP32_BASE = 0x4B000000;
float fp32_intermediates[4];
uint32_t* fp32_intermediates_casted = reinterpret_cast<uint32_t*>(fp32_intermediates);
fp32_intermediates_casted[0] = __byte_perm(i8s, FP32_BASE, 0x7650);
fp32_intermediates_casted[1] = __byte_perm(i8s, FP32_BASE, 0x7651);
fp32_intermediates_casted[2] = __byte_perm(i8s, FP32_BASE, 0x7652);
fp32_intermediates_casted[3] = __byte_perm(i8s, FP32_BASE, 0x7653);
asm volatile("sub.f32 %0, %1, %2;\n" : "=r"(fp32_intermediates_casted[0]) : "r"(fp32_intermediates_casted[0]), "r"(FP32_BASE));
asm volatile("sub.f32 %0, %1, %2;\n" : "=r"(fp32_intermediates_casted[1]) : "r"(fp32_intermediates_casted[1]), "r"(FP32_BASE));
asm volatile("sub.f32 %0, %1, %2;\n" : "=r"(fp32_intermediates_casted[2]) : "r"(fp32_intermediates_casted[2]), "r"(FP32_BASE));
asm volatile("sub.f32 %0, %1, %2;\n" : "=r"(fp32_intermediates_casted[3]) : "r"(fp32_intermediates_casted[3]), "r"(FP32_BASE));
int32_t decode_value[4];
ScaleComputeT new_code_zp = code_zp + 0.5f;
decode_value[0] = __float2int_rd(fmaf(fp32_intermediates[0], code_scale, new_code_zp));
decode_value[1] = __float2int_rd(fmaf(fp32_intermediates[1], code_scale, new_code_zp));
decode_value[2] = __float2int_rd(fmaf(fp32_intermediates[2], code_scale, new_code_zp));
decode_value[3] = __float2int_rd(fmaf(fp32_intermediates[3], code_scale, new_code_zp));
return convert_impl(decode_value);
}
CUTLASS_DEVICE
static result_type convert(source_type const& source, code_type const& code_scale, code_type const& code_zp)
{
uint32_t const i8s = reinterpret_cast<uint32_t const&>(source);
// 2^23 = 8388608
static constexpr uint32_t FP32_BASE = 0x4B000000;
float fp32_intermediates[4];
uint32_t* fp32_intermediates_casted = reinterpret_cast<uint32_t*>(fp32_intermediates);
fp32_intermediates_casted[0] = __byte_perm(i8s, FP32_BASE, 0x7650);
fp32_intermediates_casted[1] = __byte_perm(i8s, FP32_BASE, 0x7651);
fp32_intermediates_casted[2] = __byte_perm(i8s, FP32_BASE, 0x7652);
fp32_intermediates_casted[3] = __byte_perm(i8s, FP32_BASE, 0x7653);
asm volatile("sub.f32 %0, %1, %2;\n" : "=r"(fp32_intermediates_casted[0]) : "r"(fp32_intermediates_casted[0]), "r"(FP32_BASE));
asm volatile("sub.f32 %0, %1, %2;\n" : "=r"(fp32_intermediates_casted[1]) : "r"(fp32_intermediates_casted[1]), "r"(FP32_BASE));
asm volatile("sub.f32 %0, %1, %2;\n" : "=r"(fp32_intermediates_casted[2]) : "r"(fp32_intermediates_casted[2]), "r"(FP32_BASE));
asm volatile("sub.f32 %0, %1, %2;\n" : "=r"(fp32_intermediates_casted[3]) : "r"(fp32_intermediates_casted[3]), "r"(FP32_BASE));
int32_t decode_value[4];
decode_value[0] = __float2int_rd(fmaf(fp32_intermediates[0], code_scale[0], code_zp[0] + 0.5f));
decode_value[1] = __float2int_rd(fmaf(fp32_intermediates[1], code_scale[1], code_zp[1] + 0.5f));
decode_value[2] = __float2int_rd(fmaf(fp32_intermediates[2], code_scale[2], code_zp[2] + 0.5f));
decode_value[3] = __float2int_rd(fmaf(fp32_intermediates[3], code_scale[3], code_zp[3] + 0.5f));
return convert_impl(decode_value);
}
CUTLASS_DEVICE
static result_type convert_impl(int32_t* decode_value)
{
result_type result;
static constexpr uint32_t immLut = (0xF0 & 0xCC) | 0xAA;
static constexpr uint32_t MASK = 0x003F003F;
// 2^10 = 1024
static constexpr uint32_t EX = 0x64006400;
uint32_t* h = reinterpret_cast<uint32_t*>(&result);
int32_t q0 = __byte_perm(decode_value[0], decode_value[1], 0x5410);
int32_t q1 = __byte_perm(decode_value[2], decode_value[3], 0x5410);
h[0] = lop3<immLut>(q0 >> 9, MASK, EX);
h[1] = lop3<immLut>(q0 >> 6, MASK, EX);
h[2] = lop3<immLut>(q0 >> 3, MASK, EX);
h[3] = lop3<immLut>(q0, MASK, EX);
h[4] = lop3<immLut>(q1 >> 9, MASK, EX);
h[5] = lop3<immLut>(q1 >> 6, MASK, EX);
h[6] = lop3<immLut>(q1 >> 3, MASK, EX);
h[7] = lop3<immLut>(q1, MASK, EX);
// 1024 + 32 = 1056
static constexpr uint32_t SUB = 0x64206420;
asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(h[0]) : "r"(h[0]), "r"(SUB));
asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(h[1]) : "r"(h[1]), "r"(SUB));
asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(h[2]) : "r"(h[2]), "r"(SUB));
asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(h[3]) : "r"(h[3]), "r"(SUB));
asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(h[4]) : "r"(h[4]), "r"(SUB));
asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(h[5]) : "r"(h[5]), "r"(SUB));
asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(h[6]) : "r"(h[6]), "r"(SUB));
asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(h[7]) : "r"(h[7]), "r"(SUB));
return result;
}
CUTLASS_DEVICE
result_type operator()(source_type const& s, ScaleComputeT code_scale, ScaleComputeT code_zp)
{
return convert(s, code_scale, code_zp);
}
};
template <>
struct FastInterleavedAndBiasedNumericArrayConverter<bfloat16_t, uint2b_t, 16>
{
using result_type = Array<bfloat16_t, 16>;
using source_type = Array<uint2b_t, 16>;
using ScaleComputeT = float;
using code_type = Array<ScaleComputeT, 4>;
CUTLASS_DEVICE
static result_type convert(source_type const& source, ScaleComputeT code_scale, ScaleComputeT code_zp)
{
uint32_t const i8s = reinterpret_cast<uint32_t const&>(source);
// 2^23 = 8388608
static constexpr uint32_t FP32_BASE = 0x4B000000;
float fp32_intermediates[4];
uint32_t* fp32_intermediates_casted = reinterpret_cast<uint32_t*>(fp32_intermediates);
fp32_intermediates_casted[0] = __byte_perm(i8s, FP32_BASE, 0x7650);
fp32_intermediates_casted[1] = __byte_perm(i8s, FP32_BASE, 0x7651);
fp32_intermediates_casted[2] = __byte_perm(i8s, FP32_BASE, 0x7652);
fp32_intermediates_casted[3] = __byte_perm(i8s, FP32_BASE, 0x7653);
asm volatile("sub.f32 %0, %1, %2;\n" : "=r"(fp32_intermediates_casted[0]) : "r"(fp32_intermediates_casted[0]), "r"(FP32_BASE));
asm volatile("sub.f32 %0, %1, %2;\n" : "=r"(fp32_intermediates_casted[1]) : "r"(fp32_intermediates_casted[1]), "r"(FP32_BASE));
asm volatile("sub.f32 %0, %1, %2;\n" : "=r"(fp32_intermediates_casted[2]) : "r"(fp32_intermediates_casted[2]), "r"(FP32_BASE));
asm volatile("sub.f32 %0, %1, %2;\n" : "=r"(fp32_intermediates_casted[3]) : "r"(fp32_intermediates_casted[3]), "r"(FP32_BASE));
int32_t decode_value[4];
ScaleComputeT new_code_zp = code_zp + 0.5f;
decode_value[0] = __float2int_rd(fmaf(fp32_intermediates[0], code_scale, new_code_zp));
decode_value[1] = __float2int_rd(fmaf(fp32_intermediates[1], code_scale, new_code_zp));
decode_value[2] = __float2int_rd(fmaf(fp32_intermediates[2], code_scale, new_code_zp));
decode_value[3] = __float2int_rd(fmaf(fp32_intermediates[3], code_scale, new_code_zp));
return convert_impl(decode_value);
}
CUTLASS_DEVICE
static result_type convert(source_type const& source, code_type const& code_scale, code_type const& code_zp)
{
uint32_t const i8s = reinterpret_cast<uint32_t const&>(source);
// 2^23 = 8388608
static constexpr uint32_t FP32_BASE = 0x4B000000;
float fp32_intermediates[4];
uint32_t* fp32_intermediates_casted = reinterpret_cast<uint32_t*>(fp32_intermediates);
fp32_intermediates_casted[0] = __byte_perm(i8s, FP32_BASE, 0x7650);
fp32_intermediates_casted[1] = __byte_perm(i8s, FP32_BASE, 0x7651);
fp32_intermediates_casted[2] = __byte_perm(i8s, FP32_BASE, 0x7652);
fp32_intermediates_casted[3] = __byte_perm(i8s, FP32_BASE, 0x7653);
asm volatile("sub.f32 %0, %1, %2;\n" : "=r"(fp32_intermediates_casted[0]) : "r"(fp32_intermediates_casted[0]), "r"(FP32_BASE));
asm volatile("sub.f32 %0, %1, %2;\n" : "=r"(fp32_intermediates_casted[1]) : "r"(fp32_intermediates_casted[1]), "r"(FP32_BASE));
asm volatile("sub.f32 %0, %1, %2;\n" : "=r"(fp32_intermediates_casted[2]) : "r"(fp32_intermediates_casted[2]), "r"(FP32_BASE));
asm volatile("sub.f32 %0, %1, %2;\n" : "=r"(fp32_intermediates_casted[3]) : "r"(fp32_intermediates_casted[3]), "r"(FP32_BASE));
int32_t decode_value[4];
decode_value[0] = __float2int_rd(fmaf(fp32_intermediates[0], code_scale[0], code_zp[0] + 0.5f));
decode_value[1] = __float2int_rd(fmaf(fp32_intermediates[1], code_scale[1], code_zp[1] + 0.5f));
decode_value[2] = __float2int_rd(fmaf(fp32_intermediates[2], code_scale[2], code_zp[2] + 0.5f));
decode_value[3] = __float2int_rd(fmaf(fp32_intermediates[3], code_scale[3], code_zp[3] + 0.5f));
return convert_impl(decode_value);
}
CUTLASS_DEVICE
static result_type convert_impl(int32_t* decode_value)
{
result_type result;
static constexpr uint32_t immLut = (0xF0 & 0xCC) | 0xAA;
static constexpr uint32_t MASK = 0x003F003F;
// 2^7 = 128
static constexpr uint32_t EX = 0x43004300;
uint32_t* h = reinterpret_cast<uint32_t*>(&result);
int32_t q0 = __byte_perm(decode_value[0], decode_value[1], 0x5410);
int32_t q1 = __byte_perm(decode_value[2], decode_value[3], 0x5410);
h[0] = lop3<immLut>(q0 >> 9, MASK, EX);
h[1] = lop3<immLut>(q0 >> 6, MASK, EX);
h[2] = lop3<immLut>(q0 >> 3, MASK, EX);
h[3] = lop3<immLut>(q0, MASK, EX);
h[4] = lop3<immLut>(q1 >> 9, MASK, EX);
h[5] = lop3<immLut>(q1 >> 6, MASK, EX);
h[6] = lop3<immLut>(q1 >> 3, MASK, EX);
h[7] = lop3<immLut>(q1, MASK, EX);
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900) && defined(ENABLE_BF16))
// 128 + 32 = 160
static constexpr uint32_t SUB = 0x43204320;
asm volatile("sub.bf16x2 %0, %1, %2;\n" : "=r"(h[0]) : "r"(h[0]), "r"(SUB));
asm volatile("sub.bf16x2 %0, %1, %2;\n" : "=r"(h[1]) : "r"(h[1]), "r"(SUB));
asm volatile("sub.bf16x2 %0, %1, %2;\n" : "=r"(h[2]) : "r"(h[2]), "r"(SUB));
asm volatile("sub.bf16x2 %0, %1, %2;\n" : "=r"(h[3]) : "r"(h[3]), "r"(SUB));
asm volatile("sub.bf16x2 %0, %1, %2;\n" : "=r"(h[4]) : "r"(h[4]), "r"(SUB));
asm volatile("sub.bf16x2 %0, %1, %2;\n" : "=r"(h[5]) : "r"(h[5]), "r"(SUB));
asm volatile("sub.bf16x2 %0, %1, %2;\n" : "=r"(h[6]) : "r"(h[6]), "r"(SUB));
asm volatile("sub.bf16x2 %0, %1, %2;\n" : "=r"(h[7]) : "r"(h[7]), "r"(SUB));
#else
// 1.0
static constexpr uint32_t MUL = 0x3F803F80;
// -160
static constexpr uint32_t ADD = 0xC320C320;
asm volatile("fma.rn.bf16x2 %0, %1, %2, %3;\n" : "=r"(h[0]) : "r"(h[0]), "r"(MUL), "r"(ADD));
asm volatile("fma.rn.bf16x2 %0, %1, %2, %3;\n" : "=r"(h[1]) : "r"(h[1]), "r"(MUL), "r"(ADD));
asm volatile("fma.rn.bf16x2 %0, %1, %2, %3;\n" : "=r"(h[2]) : "r"(h[2]), "r"(MUL), "r"(ADD));
asm volatile("fma.rn.bf16x2 %0, %1, %2, %3;\n" : "=r"(h[3]) : "r"(h[3]), "r"(MUL), "r"(ADD));
asm volatile("fma.rn.bf16x2 %0, %1, %2, %3;\n" : "=r"(h[4]) : "r"(h[4]), "r"(MUL), "r"(ADD));
asm volatile("fma.rn.bf16x2 %0, %1, %2, %3;\n" : "=r"(h[5]) : "r"(h[5]), "r"(MUL), "r"(ADD));
asm volatile("fma.rn.bf16x2 %0, %1, %2, %3;\n" : "=r"(h[6]) : "r"(h[6]), "r"(MUL), "r"(ADD));
asm volatile("fma.rn.bf16x2 %0, %1, %2, %3;\n" : "=r"(h[7]) : "r"(h[7]), "r"(MUL), "r"(ADD));
#endif
return result;
}
CUTLASS_DEVICE
result_type operator()(source_type const& s, ScaleComputeT code_scale, ScaleComputeT code_zp)
{
return convert(s, code_scale, code_zp);
}
};
template <typename T, int N>
struct FastInterleavedAndBiasedNumericArrayConverter<T, uint2b_t, N>
{
static_assert(platform::is_same<T, half_t>::value || platform::is_same<T, bfloat16_t>::value,
"T must be fp16 or bf16");
static constexpr int kVecWidth = 16;
static_assert(!(N % kVecWidth), "N must be multiple of 16.");
using result_type = Array<T, N>;
using source_type = Array<uint2b_t, N>;
using code_type = Array<float, N / kVecWidth>;
CUTLASS_DEVICE
static result_type convert(source_type const& source, code_type const& code_scale, code_type const& code_zp)
{
using scalar_result_type = typename result_type::Element;
using scalar_source_type = typename source_type::Element;
FastInterleavedAndBiasedNumericArrayConverter<scalar_result_type, scalar_source_type, kVecWidth>
convert_vector_;
result_type result;
using vec_result = Array<scalar_result_type, kVecWidth>;
using vec_source = Array<scalar_source_type, kVecWidth>;
vec_result* result_ptr = reinterpret_cast<vec_result*>(&result);
vec_source const* source_ptr = reinterpret_cast<vec_source const*>(&source);
CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < N / kVecWidth; ++i)
{
result_ptr[i] = convert_vector_(source_ptr[i], code_scale[i], code_zp[i]);
}
return result;
}
CUTLASS_DEVICE
static result_type convert(source_type const& source, Array<float, N / 4> const& code_scale, Array<float, N / 4> const& code_zp)
{
using scalar_result_type = typename result_type::Element;
using scalar_source_type = typename source_type::Element;
using Converter = FastInterleavedAndBiasedNumericArrayConverter<scalar_result_type, scalar_source_type, kVecWidth>;
result_type result;
using vec_result = typename Converter::result_type;
using vec_source = typename Converter::source_type;
using vec_code = typename Converter::code_type;
vec_result* result_ptr = reinterpret_cast<vec_result*>(&result);
vec_source const* source_ptr = reinterpret_cast<vec_source const*>(&source);
vec_code const* code_scale_ptr = reinterpret_cast<vec_code const*>(&code_scale);
vec_code const* code_zp_ptr = reinterpret_cast<vec_code const*>(&code_zp);
CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < N / kVecWidth; ++i)
{
result_ptr[i] = Converter::convert(source_ptr[i], code_scale_ptr[i], code_zp_ptr[i]);
}
return result;
}
CUTLASS_DEVICE
result_type operator()(source_type const& s, code_type const& code_scale, code_type const& code_zp)
{
return convert(s, code_scale, code_zp);
}
};
/////////////////////////////////////////////////////////////////////////////////////////////////
} // namespace cutlass

View File

@@ -125,13 +125,10 @@ struct WintQuantTraits<ElementT, WintQuantMethod::kWeightOnlyInt2> {
static constexpr int32_t kNumPackedValues = 4;
static constexpr int32_t kPackedSize = 16;
using LocalScaleType = uint4b_t;
using CodeScaleZpType = float;
struct Arguments {
uint8_t *local_scale_ptr; // quanted 4-bits
float *code_scale_ptr;
float *code_zp_ptr;
const uint8_t *local_scale_ptr; // quanted 4-bits
const float *code_scale_ptr;
const float *code_zp_ptr;
};
CUTLASS_DEVICE

View File

@@ -117,7 +117,7 @@ class LeftGELUAndMul {
CUTLASS_HOST_DEVICE
FragmentOutput operator()(FragmentAccumulator const &lhs,
FragmentAccumulator const &rhs) const {
// Convert source to internal compute numeric type
// Convert source to interal compute numeric type
NumericArrayConverter<ElementCompute, ElementAccumulator, kCount, Round>
accumulator_to_compute;

View File

@@ -117,7 +117,7 @@ class LeftSiLUAndMul {
CUTLASS_HOST_DEVICE
FragmentOutput operator()(FragmentAccumulator const &lhs,
FragmentAccumulator const &rhs) const {
// Convert source to internal compute numeric type
// Convert source to interal compute numeric type
NumericArrayConverter<ElementCompute, ElementAccumulator, kCount, Round>
accumulator_to_compute;

View File

@@ -92,7 +92,7 @@ class DualMmaBase {
Shape::kN / WarpGemm::kN,
Shape::kK / WarpGemm::kK>;
/// Number of warp-level GEMM operations
/// Number of warp-level GEMM oeprations
static int const kWarpGemmIterations =
(WarpGemm::kK / Operator0::Policy::MmaShape::kK);

View File

@@ -43,6 +43,7 @@
#include "cutlass/trace.h"
#include "cutlass_extensions/gemm/kernel/gemm_moe_problem_visitor.h"
#include "cutlass_extensions/gemm/threadblock/wint2x_tile_dequanter.h"
#include "cutlass_extensions/tile_interleaved_layout.h"
/////////////////////////////////////////////////////////////////////////////////////////////////
@@ -774,54 +775,17 @@ struct Wint2xMoeFCGemm : public MoeFCGemm<Mma_, Epilogue_, ThreadblockSwizzle_,
template <WintQuantMethod QuantMethod, typename dummy>
struct KernelRunner<QuantMethod, true, dummy> {
using WeightQuantTraits = WintQuantTraits<ElementA, QuantMethod>;
using MmaQuantArguments = typename Mma::QuantParamsAccessor::Arguments;
using QuantArguments = typename WeightQuantTraits::Arguments;
CUTLASS_DEVICE
static MmaQuantArguments prepare_quant_args(
Params const& params, cutlass::gemm::GemmCoord const& threadblock_offset,
int64_t problem_idx, const int32_t gemm_k, const int32_t gemm_n, const int thread_idx) {
// the begin threadblock_offset of scale, which holds the same column id with C, but with no row id
cutlass::MatrixCoord tb_offset_scale{0, threadblock_offset.n()};
cutlass::MatrixCoord tb_offset_local_scale{0, threadblock_offset.n() * 2};
ElementScale* weight_scale_ptr = params.weight_scales + problem_idx * gemm_n;
typename Mma::QuantParamsAccessor::IteratorSuperScale iterator_super_scale(
Mma::QuantParamsAccessor::LayoutSuperScale(gemm_n),
weight_scale_ptr,
{1, gemm_n},
thread_idx,
tb_offset_scale);
int local_scale_pointer_offset = ((ThreadblockShape::kK + 127) / 128) * (gemm_n * 2);
int64_t offset_in_bytes = problem_idx * gemm_k * gemm_n / 128;
uint4b_t *local_scale_ptr = reinterpret_cast<uint4b_t *>(params.local_scale + offset_in_bytes);
typename Mma::QuantParamsAccessor::IteratorLocalScale iterator_local_scale(
Mma::QuantParamsAccessor::LayoutLocalScale(gemm_n * 2),
local_scale_ptr,
{(gemm_k + 127) / 128, gemm_n * 2},
thread_idx,
tb_offset_local_scale);
float* code_scale_ptr = params.code_scale + problem_idx * gemm_n;
typename Mma::QuantParamsAccessor::IteratorCodeScaleZp iterator_code_scale(
Mma::QuantParamsAccessor::LayoutCodeScaleZp(gemm_n),
code_scale_ptr,
{1, gemm_n},
thread_idx,
tb_offset_scale);
float* code_zp_ptr = params.code_zp + problem_idx * gemm_n;
typename Mma::QuantParamsAccessor::IteratorCodeScaleZp iterator_code_zp(
Mma::QuantParamsAccessor::LayoutCodeScaleZp(gemm_n),
code_zp_ptr,
{1, gemm_n},
thread_idx,
tb_offset_scale);
MmaQuantArguments mma_quant_args(
iterator_super_scale, iterator_local_scale, iterator_code_scale, iterator_code_zp, local_scale_pointer_offset);
return mma_quant_args;
static QuantArguments get_quant_args(Params const& params, int32_t problem_idx, const int64_t gemm_k, const int64_t gemm_n) {
QuantArguments quant_args;
if constexpr (QuantMethod == WintQuantMethod::kWeightOnlyInt2) {
quant_args.local_scale_ptr = params.local_scale + problem_idx * gemm_k * gemm_n / 128;
quant_args.code_scale_ptr = params.code_scale + problem_idx * gemm_n;
quant_args.code_zp_ptr = params.code_zp + problem_idx * gemm_n;
}
return quant_args;
}
CUTLASS_DEVICE
@@ -850,6 +814,9 @@ struct Wint2xMoeFCGemm : public MoeFCGemm<Mma_, Epilogue_, ThreadblockSwizzle_,
kInterleave >= 1,
"B must be row major/col major OR col major interleaved.");
// LayoutB should be RowMajor
using TileDequanterB = cutlass::gemm::threadblock::TileDequanter<ElementA, ElementScale, ThreadblockShape::kK, ThreadblockShape::kN, kStages, kThreadCount, QuantMethod>;
//
// Problem visitor.
//
@@ -876,6 +843,12 @@ struct Wint2xMoeFCGemm : public MoeFCGemm<Mma_, Epilogue_, ThreadblockSwizzle_,
int(cta_idx % grid_shape.n()) * Mma::Shape::kN, // NOLINT
0);
// begin address offset for weight_scale.
ElementScale* weight_scale_ptr =
params.weight_scales ? params.weight_scales + problem_idx * problem_size.n() : nullptr;
// the begin threadblock_offset of scale, which holds the same column id with C, but with no row id
cutlass::MatrixCoord tb_offset_scale{0, threadblock_offset.n()};
// Load element pointers. Exchange pointers and strides if working on
// the transpose
int64_t rows_to_jump = 0;
@@ -893,20 +866,42 @@ struct Wint2xMoeFCGemm : public MoeFCGemm<Mma_, Epilogue_, ThreadblockSwizzle_,
// Compute initial location in logical coordinates
// the begin threadblock_offset of A, which holds the same row id with C
cutlass::MatrixCoord tb_offset_A{threadblock_offset.m(), 0};
cutlass::MatrixCoord tb_offset_A{
threadblock_offset.m(),
0,
};
// begin address offset for B for current problem_idx, totally num_experts problems
char* byte_ptr_B = ((char*)params.ptr_B) + // NOLINT
problem_idx * bytes_per_expert_matrix; // NOLINT
ElementB* ptr_B = reinterpret_cast<ElementB*>(byte_ptr_B);
typename LayoutB::LongIndex ldm_B =
platform::is_same<layout::RowMajor, LayoutB>::value
? gemm_n
: gemm_k * kInterleave;
typename LayoutB::LongIndex ldm_B_shared = TileDequanterB::kColumns;
// the begin threadblock_offset of B, which holds the same column id with C
cutlass::MatrixCoord tb_offset_B{0, threadblock_offset.n() / kInterleave};
cutlass::MatrixCoord tb_offset_B{0,
threadblock_offset.n() / kInterleave};
cutlass::MatrixCoord extent_B{problem_size.k() * kInterleave, problem_size.n() / kInterleave};
cutlass::MatrixCoord extent_B_shared{TileDequanterB::kRows, TileDequanterB::kColumns};
MmaElementB* smem_unzip_B_ptr = nullptr;
if constexpr (QuantMethod == WintQuantMethod::kWeightOnlyInt2) {
smem_unzip_B_ptr = shared_storage.main_loop.operand_unzip_B_ptr();
}
QuantArguments quant_args = get_quant_args(params, problem_idx, gemm_k, gemm_n);
TileDequanterB tile_dequanter_B(smem_unzip_B_ptr,
byte_ptr_B,
ldm_B,
extent_B,
tb_offset_B,
weight_scale_ptr,
tb_offset_scale,
quant_args);
MmaElementB* ptr_B = tile_dequanter_B.GetOutPtr();
// Compute position within threadblock
int thread_idx = threadIdx.x;
@@ -919,21 +914,20 @@ struct Wint2xMoeFCGemm : public MoeFCGemm<Mma_, Epilogue_, ThreadblockSwizzle_,
tb_offset_A);
typename Mma::IteratorB iterator_B(
LayoutB(ldm_B),
LayoutB(TileDequanterB::kUseSharedMemory ? ldm_B_shared : ldm_B),
ptr_B,
extent_B,
TileDequanterB::kUseSharedMemory ? extent_B_shared : extent_B,
thread_idx,
tb_offset_B);
MmaQuantArguments mma_quant_args = prepare_quant_args(
params, threadblock_offset, problem_idx, gemm_k, gemm_n, thread_idx);
TileDequanterB::kUseSharedMemory ? cutlass::make_Coord(0, 0) : tb_offset_B);
typename Mma::FragmentC accumulators;
accumulators.clear();
// Broadcast the warp_id computed by lane 0 to ensure dependent code
// is compiled as warp-uniform.
int warp_idx = __shfl_sync(0xffffffff, threadIdx.x / 32, 0);
int lane_idx = threadIdx.x % 32;
//
@@ -956,7 +950,7 @@ struct Wint2xMoeFCGemm : public MoeFCGemm<Mma_, Epilogue_, ThreadblockSwizzle_,
accumulators,
iterator_A,
iterator_B,
mma_quant_args,
tile_dequanter_B,
accumulators);
//

View File

@@ -205,7 +205,7 @@ void generic_moe_gemm_kernelLauncher(const T* A,
threadblock_count,
epilogue_op,
reinterpret_cast<const ElementType*>(A),
reinterpret_cast<const CutlassMmaKernelType*>(B),
reinterpret_cast<const CutlassMmaWeightType*>(B),
reinterpret_cast<const ElementType*>(weight_scales),
reinterpret_cast<const ElementType*>(biases),
reinterpret_cast<ElementType*>(C),

View File

@@ -219,7 +219,7 @@ class EpilogueVisitorPerRowPerColNf4 {
iterator_C_.clear_mask();
}
// NOTE(wangbojun) Currently, this kernel don't hanve implantention for
// adding elementwise beta, we keep this here for future usage beta_ =
// adding elementwise beta, we keep this here for future useage beta_ =
// (params.elementwise.beta_ptr ? *params.elementwise.beta_ptr :
// params.elementwise.beta); if (beta_ == ElementAccumulator()) {
// iterator_C_.clear_mask();

View File

@@ -176,7 +176,7 @@ struct Nf4DefaultIteratorsTensorOp<cutlass::bfloat16_t,
///
/// Satisfies: ReadableTileIterator
///
template <typename ThreadMap_ ///< Thread map (concept: OutputTileThreadMap)
template <typename ThreadMap_ ///< Thread map (conept: OutputTileThreadMap)
>
class SharedLoadIteratorMixed<ThreadMap_, int32_t, 32, 16, 8, 8> {
public:

Some files were not shown because too many files have changed in this diff Show More