Mirror of https://github.com/PaddlePaddle/FastDeploy.git, synced 2025-10-04 08:16:42 +08:00.
Compare commits: copilot/fi...release/2. (23 commits)
| SHA1 |
|---|
| bd30b08521 |
| 1aa16146ba |
| dac0a00d0f |
| c5591c45df |
| 121ac85d7d |
| d233e3c97c |
| 2136990144 |
| b7890cbe8d |
| bc388b65c7 |
| 71af0ca04a |
| d66660a0d1 |
| f0519aec67 |
| 1f5983290c |
| c6a133d573 |
| 4646aff25c |
| a84a98b107 |
| c208086f61 |
| ce1d4944e7 |
| 5439fb6336 |
| a592d17615 |
| eca8fc7ca6 |
| 0463797fc2 |
| 0ab8645fc4 |
6  .github/workflows/Codestyle-Check.yml (vendored)

@@ -2,9 +2,7 @@ name: Codestyle-Check
 
 on:
   pull_request:
-    branches:
-      - develop
-      - 'release/*'
+    branches: ["develop"]
 
 jobs:
   pre-commit:
@@ -13,7 +11,7 @@ jobs:
     runs-on: ubuntu-latest
     env:
       PR_ID: ${{ github.event.pull_request.number }}
-      BRANCH: ${{ github.event.pull_request.base.ref }}
+      BRANCH: develop
 
     steps:
       - name: Cleanup
187  .github/workflows/_accuracy_test.yml (vendored)
@@ -1,187 +0,0 @@ (file deleted)

name: Accuracy Test
description: "Run Accuracy Tests"

on:
  workflow_call:
    inputs:
      DOCKER_IMAGE:
        description: "Build Images"
        required: true
        type: string
        default: "ccr-2vdh3abv-pub.cnc.bj.baidubce.com/paddlepaddle/paddleqa:cuda126-py310"
      FASTDEPLOY_ARCHIVE_URL:
        description: "URL of the compressed FastDeploy code archive."
        required: true
        type: string
      FASTDEPLOY_WHEEL_URL:
        description: "URL of the FastDeploy Wheel."
        required: true
        type: string
      CACHE_DIR:
        description: "Cache Dir Use"
        required: false
        type: string
        default: ""
      MODEL_CACHE_DIR:
        description: "Cache Dir Use"
        required: false
        type: string
        default: ""

jobs:
  accuracy_tests:
    runs-on: [self-hosted, GPU-h20-1Cards]
    timeout-minutes: 60
    steps:
      - name: Code Prepare
        shell: bash
        env:
          docker_image: ${{ inputs.DOCKER_IMAGE }}
          fd_archive_url: ${{ inputs.FASTDEPLOY_ARCHIVE_URL }}
        run: |
          set -x
          REPO="https://github.com/${{ github.repository }}.git"
          FULL_REPO="${{ github.repository }}"
          REPO_NAME="${FULL_REPO##*/}"
          BASE_BRANCH="${{ github.base_ref }}"
          docker pull ${docker_image}
          # Clean the repository directory before starting
          docker run --rm --net=host -v $(pwd):/workspace -w /workspace \
            -e "REPO_NAME=${REPO_NAME}" \
            ${docker_image} /bin/bash -c '
            if [ -d ${REPO_NAME} ]; then
              echo "Directory ${REPO_NAME} exists, removing it..."
              rm -rf ${REPO_NAME}*
            fi
            '

          wget -q ${fd_archive_url}
          tar -xf FastDeploy.tar.gz
          rm -rf FastDeploy.tar.gz
          cd FastDeploy
          git config --global user.name "FastDeployCI"
          git config --global user.email "fastdeploy_ci@example.com"
          git log -n 3 --oneline

      - name: Run FastDeploy Base Tests
        shell: bash
        env:
          docker_image: ${{ inputs.DOCKER_IMAGE }}
          fastdeploy_wheel_url: ${{ inputs.FASTDEPLOY_WHEEL_URL }}
          CACHE_DIR: ${{ inputs.CACHE_DIR }}
          MODEL_CACHE_DIR: ${{ inputs.MODEL_CACHE_DIR }}
        run: |
          runner_name="${{ runner.name }}"
          CARD_ID=$(echo "${runner_name}" | awk -F'-' '{print $NF}')
          DEVICES=$(echo "$CARD_ID" | fold -w1 | paste -sd,)
          DEVICE_PORT=$(echo "$DEVICES" | cut -d',' -f1)

          FLASK_PORT=$((42068 + DEVICE_PORT * 100))
          FD_API_PORT=$((42088 + DEVICE_PORT * 100))
          FD_ENGINE_QUEUE_PORT=$((42058 + DEVICE_PORT * 100))
          FD_METRICS_PORT=$((42078 + DEVICE_PORT * 100))
          FD_CACHE_QUEUE_PORT=$((42098 + DEVICE_PORT * 100))
          echo "Test ENV Parameter:"
          echo "========================================================="
          echo "FLASK_PORT=${FLASK_PORT}"
          echo "FD_API_PORT=${FD_API_PORT}"
          echo "FD_ENGINE_QUEUE_PORT=${FD_ENGINE_QUEUE_PORT}"
          echo "FD_METRICS_PORT=${FD_METRICS_PORT}"
          echo "FD_CACHE_QUEUE_PORT=${FD_CACHE_QUEUE_PORT}"
          echo "DEVICES=${DEVICES}"
          echo "========================================================="

          CACHE_DIR="${CACHE_DIR:-$(dirname "$(dirname "${{ github.workspace }}")")}"
          echo "CACHE_DIR is set to ${CACHE_DIR}"
          if [ ! -f "${CACHE_DIR}/gitconfig" ]; then
            touch "${CACHE_DIR}/gitconfig"
          fi
          if [ ! -d "${MODEL_CACHE_DIR}" ]; then
            echo "Error: MODEL_CACHE_DIR '${MODEL_CACHE_DIR}' does not exist."
            exit 1
          fi

          PORTS=($FLASK_PORT $FD_API_PORT $FD_ENGINE_QUEUE_PORT $FD_METRICS_PORT $FD_CACHE_QUEUE_PORT)
          LOG_FILE="./port_cleanup_$(date +%Y%m%d_%H%M%S).log"
          echo "==== LOG_FILE is ${LOG_FILE} ===="

          echo "==== PORT CLEAN BEFORE TASK RUN ====" | tee -a $LOG_FILE

          for port in "${PORTS[@]}"; do
            PIDS=$(lsof -t -i :$port || true)
            if [ -n "$PIDS" ]; then
              echo "Port $port is occupied by PID(s): $PIDS" | tee -a $LOG_FILE
              echo "$PIDS" | xargs -r kill -9
              echo "Port $port cleared" | tee -a $LOG_FILE
            else
              echo "Port $port is free" | tee -a $LOG_FILE
            fi
          done

          echo "==== PORT CLEAN COMPLETE ====" | tee -a $LOG_FILE

          echo "========================================================="
          echo "Ensuring no stale container named ${runner_name} ..."
          if [ "$(docker ps -a -q -f name=${runner_name})" ]; then
            echo "Removing stale container: ${runner_name}"
            docker rm -f ${runner_name} || true
          fi

          docker run --rm --ipc=host --pid=host --net=host \
            --name ${runner_name} \
            -v $(pwd):/workspace \
            -w /workspace \
            -e fastdeploy_wheel_url=${fastdeploy_wheel_url} \
            -e "FD_API_PORT=${FD_API_PORT}" \
            -e "FD_ENGINE_QUEUE_PORT=${FD_ENGINE_QUEUE_PORT}" \
            -e "FD_METRICS_PORT=${FD_METRICS_PORT}" \
            -e "FD_CACHE_QUEUE_PORT=${FD_CACHE_QUEUE_PORT}" \
            -e "FLASK_PORT=${FLASK_PORT}" \
            -v "${MODEL_CACHE_DIR}:/MODELDATA" \
            -v "${CACHE_DIR}/gitconfig:/etc/gitconfig:ro" \
            -v "${CACHE_DIR}/.cache:/root/.cache" \
            -v "${CACHE_DIR}/ConfigDir:/root/.config" \
            -e TZ="Asia/Shanghai" \
            --gpus '"device='"${DEVICES}"'"' ${docker_image} /bin/bash -xc '
            python -m pip install --pre paddlepaddle-gpu -i https://www.paddlepaddle.org.cn/packages/nightly/cu126/

            pip config set global.index-url https://mirrors.tuna.tsinghua.edu.cn/pypi/web/simple

            python -m pip install ${fastdeploy_wheel_url}
            python -m pip install pytest

            wget https://paddle-qa.bj.bcebos.com/zhengtianyu/tools/llm-deploy-linux-amd64
            chmod +x ./llm-deploy-linux-amd64
            ./llm-deploy-linux-amd64 -python python3.10 \
              -model_name ERNIE-4.5-0.3B-Paddle \
              -model_path /MODELDATA \
              --skip install

            git config --global --add safe.directory /workspace/FastDeploy
            cd FastDeploy
            pushd tests/ce/deploy
            ps -ef | grep "${FD_ENGINE_QUEUE_PORT}" | grep -v grep | awk "{print \$2}" | xargs -r kill -9
            python3.10 deploy.py > dd.log 2>&1 &
            sleep 3
            curl -X POST http://0.0.0.0:${FLASK_PORT}/start \
              -H "Content-Type: application/json" \
              -d "{\"--model\": \"/MODELDATA/ERNIE-4.5-0.3B-Paddle\"}"

            curl -X POST http://localhost:${FLASK_PORT}/wait_for_infer?timeout=90
            popd

            pushd tests/ce/accuracy_cases
            export URL=http://localhost:${FD_API_PORT}/v1/chat/completions
            export TEMPLATE=TOKEN_LOGPROB
            export MODEL_SIZE=0.3B
            TEST_EXIT_CODE=0
            python gsm8k.py || TEST_EXIT_CODE=1
            popd
            echo "TEST_EXIT_CODE=${TEST_EXIT_CODE}" >> /workspace/FastDeploy/exit_code.env
            '
          if [ -f ./FastDeploy/exit_code.env ]; then
            source ./FastDeploy/exit_code.env
            cat ./FastDeploy/exit_code.env >> $GITHUB_ENV
          fi
          echo "TEST_EXIT_CODE=${TEST_EXIT_CODE}"
          exit ${TEST_EXIT_CODE}
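The test workflows in this comparison all derive their per-card service ports from the trailing card index in the self-hosted runner name. A minimal standalone sketch of that arithmetic, using a hypothetical runner name (real jobs get it from `${{ runner.name }}`):

```bash
#!/usr/bin/env bash
# Sketch only: reproduces the port derivation used by the CI workflows above.
runner_name="GPU-h20-1Cards-3"                                # hypothetical runner name
CARD_ID=$(echo "${runner_name}" | awk -F'-' '{print $NF}')    # last dash-separated field -> "3"
DEVICES=$(echo "$CARD_ID" | fold -w1 | paste -sd,)            # split digits into a comma list -> "3"
DEVICE_PORT=$(echo "$DEVICES" | cut -d',' -f1)                # first card index drives the offsets

FLASK_PORT=$((42068 + DEVICE_PORT * 100))                     # -> 42368
FD_API_PORT=$((42088 + DEVICE_PORT * 100))                    # -> 42388
FD_ENGINE_QUEUE_PORT=$((42058 + DEVICE_PORT * 100))
FD_METRICS_PORT=$((42078 + DEVICE_PORT * 100))
FD_CACHE_QUEUE_PORT=$((42098 + DEVICE_PORT * 100))
echo "${FLASK_PORT} ${FD_API_PORT} ${FD_ENGINE_QUEUE_PORT} ${FD_METRICS_PORT} ${FD_CACHE_QUEUE_PORT}"
```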
230  .github/workflows/_base_test.yml (vendored)
@@ -1,230 +0,0 @@ (file deleted)
|
||||
name: Base Test
|
||||
description: "Run Base Tests"
|
||||
|
||||
on:
|
||||
workflow_call:
|
||||
inputs:
|
||||
DOCKER_IMAGE:
|
||||
description: "Build Images"
|
||||
required: true
|
||||
type: string
|
||||
default: "ccr-2vdh3abv-pub.cnc.bj.baidubce.com/paddlepaddle/paddleqa:cuda126-py310"
|
||||
FASTDEPLOY_ARCHIVE_URL:
|
||||
description: "URL of the compressed FastDeploy code archive."
|
||||
required: true
|
||||
type: string
|
||||
FASTDEPLOY_WHEEL_URL:
|
||||
description: "URL of the FastDeploy Wheel."
|
||||
required: true
|
||||
type: string
|
||||
CACHE_DIR:
|
||||
description: "Cache Dir Use"
|
||||
required: false
|
||||
type: string
|
||||
default: ""
|
||||
MODEL_CACHE_DIR:
|
||||
description: "Cache Dir Use"
|
||||
required: false
|
||||
type: string
|
||||
default: ""
|
||||
|
||||
jobs:
|
||||
base_tests:
|
||||
runs-on: [self-hosted, GPU-h20-1Cards]
|
||||
timeout-minutes: 60
|
||||
steps:
|
||||
- name: Code Prepare
|
||||
shell: bash
|
||||
env:
|
||||
docker_image: ${{ inputs.DOCKER_IMAGE }}
|
||||
fd_archive_url: ${{ inputs.FASTDEPLOY_ARCHIVE_URL }}
|
||||
run: |
|
||||
set -x
|
||||
REPO="https://github.com/${{ github.repository }}.git"
|
||||
FULL_REPO="${{ github.repository }}"
|
||||
REPO_NAME="${FULL_REPO##*/}"
|
||||
BASE_BRANCH="${{ github.base_ref }}"
|
||||
docker pull ${docker_image}
|
||||
# Clean the repository directory before starting
|
||||
docker run --rm --net=host -v $(pwd):/workspace -w /workspace \
|
||||
-e "REPO_NAME=${REPO_NAME}" \
|
||||
${docker_image} /bin/bash -c '
|
||||
if [ -d ${REPO_NAME} ]; then
|
||||
echo "Directory ${REPO_NAME} exists, removing it..."
|
||||
rm -rf ${REPO_NAME}*
|
||||
fi
|
||||
'
|
||||
|
||||
wget -q ${fd_archive_url}
|
||||
tar -xf FastDeploy.tar.gz
|
||||
rm -rf FastDeploy.tar.gz
|
||||
cd FastDeploy
|
||||
git config --global user.name "FastDeployCI"
|
||||
git config --global user.email "fastdeploy_ci@example.com"
|
||||
git log -n 3 --oneline
|
||||
|
||||
- name: Run FastDeploy Base Tests
|
||||
shell: bash
|
||||
env:
|
||||
docker_image: ${{ inputs.DOCKER_IMAGE }}
|
||||
fastdeploy_wheel_url: ${{ inputs.FASTDEPLOY_WHEEL_URL }}
|
||||
CACHE_DIR: ${{ inputs.CACHE_DIR }}
|
||||
MODEL_CACHE_DIR: ${{ inputs.MODEL_CACHE_DIR }}
|
||||
run: |
|
||||
runner_name="${{ runner.name }}"
|
||||
CARD_ID=$(echo "${runner_name}" | awk -F'-' '{print $NF}')
|
||||
DEVICES=$(echo "$CARD_ID" | fold -w1 | paste -sd,)
|
||||
DEVICE_PORT=$(echo "$DEVICES" | cut -d',' -f1)
|
||||
|
||||
FLASK_PORT=$((42068 + DEVICE_PORT * 100))
|
||||
FD_API_PORT=$((42088 + DEVICE_PORT * 100))
|
||||
FD_ENGINE_QUEUE_PORT=$((42058 + DEVICE_PORT * 100))
|
||||
FD_METRICS_PORT=$((42078 + DEVICE_PORT * 100))
|
||||
FD_CACHE_QUEUE_PORT=$((42098 + DEVICE_PORT * 100))
|
||||
echo "Test ENV Parameter:"
|
||||
echo "========================================================="
|
||||
echo "FLASK_PORT=${FLASK_PORT}"
|
||||
echo "FD_API_PORT=${FD_API_PORT}"
|
||||
echo "FD_ENGINE_QUEUE_PORT=${FD_ENGINE_QUEUE_PORT}"
|
||||
echo "FD_METRICS_PORT=${FD_METRICS_PORT}"
|
||||
echo "FD_CACHE_QUEUE_PORT=${FD_CACHE_QUEUE_PORT}"
|
||||
echo "DEVICES=${DEVICES}"
|
||||
echo "========================================================="
|
||||
|
||||
CACHE_DIR="${CACHE_DIR:-$(dirname "$(dirname "${{ github.workspace }}")")}"
|
||||
echo "CACHE_DIR is set to ${CACHE_DIR}"
|
||||
if [ ! -f "${CACHE_DIR}/gitconfig" ]; then
|
||||
touch "${CACHE_DIR}/gitconfig"
|
||||
fi
|
||||
if [ ! -d "${MODEL_CACHE_DIR}" ]; then
|
||||
echo "Error: MODEL_CACHE_DIR '${MODEL_CACHE_DIR}' does not exist."
|
||||
exit 1
|
||||
fi
|
||||
|
||||
PORTS=($FLASK_PORT $FD_API_PORT $FD_ENGINE_QUEUE_PORT $FD_METRICS_PORT $FD_CACHE_QUEUE_PORT)
|
||||
LOG_FILE="./port_cleanup_$(date +%Y%m%d_%H%M%S).log"
|
||||
echo "==== LOG_FILE is ${LOG_FILE} ===="
|
||||
|
||||
echo "==== PORT CLEAN BEFORE TASK RUN ====" | tee -a $LOG_FILE
|
||||
|
||||
for port in "${PORTS[@]}"; do
|
||||
PIDS=$(lsof -t -i :$port || true)
|
||||
if [ -n "$PIDS" ]; then
|
||||
echo "Port $port is occupied by PID(s): $PIDS" | tee -a $LOG_FILE
|
||||
echo "$PIDS" | xargs -r kill -9
|
||||
echo "Port $port cleared" | tee -a $LOG_FILE
|
||||
else
|
||||
echo "Port $port is free" | tee -a $LOG_FILE
|
||||
fi
|
||||
done
|
||||
|
||||
echo "==== PORT CLEAN COMPLETE ====" | tee -a $LOG_FILE
|
||||
|
||||
echo "========================================================="
|
||||
echo "Ensuring no stale container named ${runner_name} ..."
|
||||
if [ "$(docker ps -a -q -f name=${runner_name})" ]; then
|
||||
echo "Removing stale container: ${runner_name}"
|
||||
docker rm -f ${runner_name} || true
|
||||
fi
|
||||
|
||||
docker run --rm --ipc=host --pid=host --net=host \
|
||||
--name ${runner_name} \
|
||||
-v $(pwd):/workspace \
|
||||
-w /workspace \
|
||||
-e fastdeploy_wheel_url=${fastdeploy_wheel_url} \
|
||||
-e "FD_API_PORT=${FD_API_PORT}" \
|
||||
-e "FD_ENGINE_QUEUE_PORT=${FD_ENGINE_QUEUE_PORT}" \
|
||||
-e "FD_METRICS_PORT=${FD_METRICS_PORT}" \
|
||||
-e "FLASK_PORT=${FLASK_PORT}" \
|
||||
-e "FD_CACHE_QUEUE_PORT=${FD_CACHE_QUEUE_PORT}" \
|
||||
-v "${MODEL_CACHE_DIR}:/MODELDATA" \
|
||||
-v "${CACHE_DIR}/gitconfig:/etc/gitconfig:ro" \
|
||||
-v "${CACHE_DIR}/.cache:/root/.cache" \
|
||||
-v "${CACHE_DIR}/ConfigDir:/root/.config" \
|
||||
-e TZ="Asia/Shanghai" \
|
||||
--gpus '"device='"${DEVICES}"'"' ${docker_image} /bin/bash -xc '
|
||||
python -m pip install --pre paddlepaddle-gpu -i https://www.paddlepaddle.org.cn/packages/nightly/cu126/
|
||||
|
||||
pip config set global.index-url https://mirrors.tuna.tsinghua.edu.cn/pypi/web/simple
|
||||
|
||||
python -m pip install ${fastdeploy_wheel_url}
|
||||
python -m pip install pytest
|
||||
|
||||
wget https://paddle-qa.bj.bcebos.com/zhengtianyu/tools/llm-deploy-linux-amd64
|
||||
chmod +x ./llm-deploy-linux-amd64
|
||||
./llm-deploy-linux-amd64 -python python3.10 \
|
||||
-model_name ERNIE-4.5-0.3B-Paddle \
|
||||
-model_path /MODELDATA \
|
||||
--skip install
|
||||
|
||||
git config --global --add safe.directory /workspace/FastDeploy
|
||||
cd FastDeploy
|
||||
pushd tests/ce/deploy
|
||||
ps -ef | grep "${FD_ENGINE_QUEUE_PORT}" | grep -v grep | awk "{print \$2}" | xargs -r kill -9
|
||||
python3.10 deploy.py > dd.log 2>&1 &
|
||||
sleep 3
|
||||
curl -X POST http://0.0.0.0:${FLASK_PORT}/start \
|
||||
-H "Content-Type: application/json" \
|
||||
-d "{\"--model\": \"/MODELDATA/ERNIE-4.5-0.3B-Paddle\"}"
|
||||
|
||||
check_service() {
|
||||
local timeout=${1:-90}
|
||||
local url="http://localhost:${FLASK_PORT}/wait_for_infer?timeout=${timeout}"
|
||||
local resp
|
||||
|
||||
resp=$(curl -s -X POST "$url")
|
||||
|
||||
if echo "$resp" | grep -q "服务启动超时"; then
|
||||
exit 8
|
||||
fi
|
||||
}
|
||||
|
||||
check_service 90
|
||||
popd
|
||||
|
||||
pushd tests/ce/server
|
||||
export URL=http://localhost:${FD_API_PORT}/v1/chat/completions
|
||||
export TEMPLATE=TOKEN_LOGPROB
|
||||
TEST_EXIT_CODE=0
|
||||
python -m pytest -sv test_base_chat.py test_compare_top_logprobs.py test_logprobs.py test_params_boundary.py test_seed_usage.py test_stream.py test_evil_cases.py test_completions.py test_return_token_ids.py || TEST_EXIT_CODE=1
|
||||
curl -X POST http://0.0.0.0:${FLASK_PORT}/switch \
|
||||
-H "Content-Type: application/json" \
|
||||
-d "{\"--model\": \"/MODELDATA/ERNIE-4.5-0.3B-Paddle\", \"--early-stop-config\": \"{\\\"enable_early_stop\\\":true, \\\"window_size\\\":6, \\\"threshold\\\":0.93}\"}"
|
||||
check_service 90
|
||||
python -m pytest -sv test_repetition_early_stop.py || TEST_EXIT_CODE=1
|
||||
|
||||
curl -X POST http://0.0.0.0:${FLASK_PORT}/switch \
|
||||
-H "Content-Type: application/json" \
|
||||
-d "{ \"--model\": \"/MODELDATA/ERNIE-4.5-0.3B-Paddle\", \"--max-concurrency\": 5, \"--max-waiting-time\": 1 }"
|
||||
check_service 90
|
||||
python -m pytest -sv test_max_concurrency.py || TEST_EXIT_CODE=1
|
||||
|
||||
curl -X POST http://0.0.0.0:${FLASK_PORT}/switch \
|
||||
-H "Content-Type: application/json" \
|
||||
-d "{ \"--model\": \"/MODELDATA/ERNIE-4.5-0.3B-Paddle\", \"--max-concurrency\": 5000, \"--max-waiting-time\": 1 }"
|
||||
check_service 90
|
||||
python -m pytest -sv test_max_waiting_time.py || TEST_EXIT_CODE=1
|
||||
|
||||
curl -X POST http://0.0.0.0:${FLASK_PORT}/switch \
|
||||
-H "Content-Type: application/json" \
|
||||
-d "{\"--model\": \"/MODELDATA/ernie-4_5-21b-a3b-bf16-paddle\", \"--config\": \"21b_mtp.yaml\", \"--enable-logprob\": \"False\"}"
|
||||
check_service 180
|
||||
export TEMPLATE=TOKEN_NORMAL
|
||||
python -m pytest -sv test_seed_usage.py -k "not test_seed_stream" || TEST_EXIT_CODE=1
|
||||
|
||||
curl -X POST http://0.0.0.0:${FLASK_PORT}/switch \
|
||||
-H "Content-Type: application/json" \
|
||||
-d "{\"--model\": \"/MODELDATA/ernie-4_5-21b-a3b-bf16-paddle\", \"--config\": \"21b_sot.yaml\", \"--enable-logprob\": \"False\"}"
|
||||
check_service 360
|
||||
export TEMPLATE=TOKEN_NORMAL
|
||||
python -m pytest -sv test_seed_usage.py -k "not test_seed_stream" || TEST_EXIT_CODE=1
|
||||
|
||||
popd
|
||||
echo "TEST_EXIT_CODE=${TEST_EXIT_CODE}" >> /workspace/FastDeploy/exit_code.env
|
||||
'
|
||||
if [ -f ./FastDeploy/exit_code.env ]; then
|
||||
source ./FastDeploy/exit_code.env
|
||||
cat ./FastDeploy/exit_code.env >> $GITHUB_ENV
|
||||
fi
|
||||
echo "TEST_EXIT_CODE=${TEST_EXIT_CODE}"
|
||||
exit ${TEST_EXIT_CODE}
|
42  .github/workflows/_build_linux.yml (vendored)

@@ -22,22 +22,12 @@ on:
       description: "Enable nightly build mode (e.g. add date suffix to version)"
       required: false
       type: string
-      default: "OFF"
+      default: "ON"
     FD_VERSION:
       description: "FastDeploy Package Version"
       required: false
       type: string
       default: ""
-    PADDLEVERSION:
-      description: "Paddle Version Build Use"
-      required: false
-      type: string
-      default: ""
-    PADDLE_WHL_URL:
-      description: "Paddle Wheel Package URL"
-      required: false
-      type: string
-      default: ""
     UPLOAD:
       description: "Upload Package"
       required: false
@@ -54,8 +44,7 @@ on:
       value: ${{ jobs.fd-build.outputs.wheel_path }}
 jobs:
   fd-build:
-    runs-on: [self-hosted, GPU-Build]
-    timeout-minutes: 360
+    runs-on: [self-hosted, GPU-h1z1-4Cards]
     outputs:
       wheel_path: ${{ steps.set_output.outputs.wheel_path }}
     steps:
@@ -96,17 +85,13 @@ jobs:
           compile_arch: ${{ inputs.COMPILE_ARCH }}
           fd_version: ${{ inputs.FD_VERSION }}
           CACHE_DIR: ${{ inputs.CACHE_DIR }}
-          BRANCH_REF: ${{ github.ref_name }}
-          PADDLEVERSION: ${{ inputs.PADDLEVERSION }}
-          PADDLE_WHL_URL: ${{ inputs.PADDLE_WHL_URL }}
-          WITH_NIGHTLY_BUILD: ${{ inputs.WITH_NIGHTLY_BUILD }}
         run: |
           set -x
           runner_name="${{ runner.name }}"
-          CARD_ID=$(echo "${runner_name}" | awk -F'-' '{print $NF}')
+          CARD_ID=$(echo "${runner_name}" | cut -d'-' -f2)
           gpu_id=$(echo "$CARD_ID" | fold -w1 | paste -sd,)
 
-          CACHE_DIR="${CACHE_DIR:-$(dirname "$(dirname "${{ github.workspace }}")")}"
+          CACHE_DIR=${CACHE_DIR:-${{ github.workspace }}}
          echo "CACHE_DIR is set to ${CACHE_DIR}"
          if [ ! -f "${CACHE_DIR}/gitconfig" ]; then
            touch "${CACHE_DIR}/gitconfig"
@@ -118,15 +103,11 @@ jobs:
            -v $(pwd):/workspace -w /workspace \
            -v "${CACHE_DIR}/gitconfig:/etc/gitconfig:ro" \
            -v "${CACHE_DIR}/.cache:/root/.cache" \
            -v "${CACHE_DIR}/.ccache:/root/.ccache" \
            -v "${CACHE_DIR}/ConfigDir:/root/.config" \
            -e TZ="Asia/Shanghai" \
            -e "COMPILE_ARCH=${compile_arch}" \
            -e "FD_VERSION=${fd_version}" \
-           -e "WITH_NIGHTLY_BUILD=${WITH_NIGHTLY_BUILD}" \
-           -e "PADDLEVERSION=${PADDLEVERSION}" \
-           -e "PADDLE_WHL_URL=${PADDLE_WHL_URL}" \
-           -e "BRANCH_REF=${BRANCH_REF}" \
            --gpus "\"device=${gpu_id}\"" ${docker_image} /bin/bash -c '
            if [[ -n "${FD_VERSION}" ]]; then
              export FASTDEPLOY_VERSION=${FD_VERSION}
@@ -134,7 +115,6 @@ jobs:
            fi

            git config --global --add safe.directory /workspace/FastDeploy
-           chown -R $(whoami) /workspace/FastDeploy
            cd FastDeploy
            if [[ "${WITH_NIGHTLY_BUILD}" == "ON" ]];then
              GIT_COMMIT_TIME=$(git --no-pager show -s --format=%ci HEAD)
@@ -143,20 +123,14 @@ jobs:
              echo "Date Only: $DATE_ONLY"
              export FASTDEPLOY_VERSION="${FASTDEPLOY_VERSION}.dev${DATE_ONLY}"
            fi
-           # 针对不同分支和tag使用不同的PaddlePaddle安装包 (use a different PaddlePaddle package per branch/tag)
-           if [[ "${PADDLE_WHL_URL}" != "" ]];then
-             python -m pip install ${PADDLE_WHL_URL}
-           elif [[ "${PADDLEVERSION}" != "" ]];then
-             python -m pip install paddlepaddle-gpu==${PADDLEVERSION} -i https://www.paddlepaddle.org.cn/packages/stable/cu126/
-           else
-             python -m pip install --pre paddlepaddle-gpu -i https://www.paddlepaddle.org.cn/packages/nightly/cu126/
-           fi

-           pip config set global.index-url https://mirrors.tuna.tsinghua.edu.cn/pypi/web/simple
+           pip config set global.index-url http://pip.baidu.com/root/baidu/+simple/
+           pip config set install.trusted-host pip.baidu.com
+           pip config set global.extra-index-url https://mirrors.tuna.tsinghua.edu.cn/pypi/web/simple

            python -m pip install --upgrade pip
            python -m pip install -r requirements.txt
            python -m pip install wheel
+           python -m pip install --pre paddlepaddle-gpu -i https://www.paddlepaddle.org.cn/packages/nightly/cu126/
            # 编译RDMA (build the RDMA component)
            export ENABLE_FD_RDMA=1
            bash build.sh 1 python false [${COMPILE_ARCH}]
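In the nightly path of the build job above, the wheel version gets a date suffix taken from the last commit's timestamp. A small sketch of that flow, assuming the timestamp is reduced to YYYYMMDD (the exact reduction step is not shown in the diff, so the awk/tr line here is an assumption):

```bash
# Sketch of the nightly version suffix used when WITH_NIGHTLY_BUILD=ON.
GIT_COMMIT_TIME=$(git --no-pager show -s --format=%ci HEAD)            # e.g. "2025-09-30 14:03:11 +0800"
DATE_ONLY=$(echo "${GIT_COMMIT_TIME}" | awk '{print $1}' | tr -d '-')  # assumption: -> "20250930"
echo "Date Only: $DATE_ONLY"
export FASTDEPLOY_VERSION="${FASTDEPLOY_VERSION}.dev${DATE_ONLY}"      # e.g. "2.1.0.dev20250930"
```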
73  .github/workflows/_ci_image_build.yml (vendored)
@@ -1,73 +0,0 @@ (file deleted)

name: Docker Build
description: "FastDeploy CI Image Build"

on:
  workflow_call:
    inputs:
      CI_DOCKER_IMAGE_NAME:
        description: "Build Images"
        required: true
        type: string
        default: "ccr-2vdh3abv-pub.cnc.bj.baidubce.com/paddlepaddle/paddleqa:cuda126-py310"
      FASTDEPLOY_ARCHIVE_URL:
        description: "URL of the compressed FastDeploy code archive."
        required: true
        type: string
      DOCKER_IMAGE_NAME:
        description: "Build Images"
        required: false
        type: string
        default: "ccr-2vdh3abv-pub.cnc.bj.baidubce.com/paddlepaddle/paddleqa:fastdeploy-ciuse-cuda126-dailyupdate"
    outputs:
      docker_name_precheck:
        description: "Output path of the generated wheel"
        value: ${{ jobs.docker_build.outputs.docker_name_precheck }}

jobs:
  docker_build:
    runs-on: [self-hosted, Docker-Build]
    outputs:
      docker_name_precheck: ${{ steps.docker_build.outputs.docker_name_precheck }}
    steps:
      - name: Docker Build
        id: docker_build
        shell: bash
        env:
          docker_image_name: ${{ inputs.CI_DOCKER_IMAGE_NAME }}
          docker_image: ${{ inputs.DOCKER_IMAGE_NAME }}
          fd_archive_url: ${{ inputs.FASTDEPLOY_ARCHIVE_URL }}
        run: |
          set -x
          REPO="https://github.com/${{ github.repository }}.git"
          FULL_REPO="${{ github.repository }}"
          REPO_NAME="${FULL_REPO##*/}"
          BASE_BRANCH="${{ github.base_ref }}"

          # Clean the repository directory before starting
          docker run --rm --net=host -v $(pwd):/workspace -w /workspace \
            -e "REPO_NAME=${REPO_NAME}" \
            ${docker_image} /bin/bash -c '
            if [ -d ${REPO_NAME} ]; then
              echo "Directory ${REPO_NAME} exists, removing it..."
              rm -rf ${REPO_NAME}*
            fi
            '

          wget -q ${fd_archive_url}
          tar -xf FastDeploy.tar.gz
          rm -rf FastDeploy.tar.gz
          cd FastDeploy
          git config --global user.name "FastDeployCI"
          git config --global user.email "fastdeploy_ci@example.com"
          git log -n 3 --oneline

          # Docker Build
          cd tools/dockerfile/
          set -e
          cp ../../requirements.txt ./
          cp ../../scripts/unittest_requirement.txt ./
          docker build -t ${docker_image_name} -f Dockerfile.ci . \
            --network host \
            --no-cache
          docker push ${docker_image_name}
          echo "docker_name_precheck=${docker_image_name}" >> $GITHUB_OUTPUT
2  .github/workflows/_clone_linux.yml (vendored)

@@ -68,7 +68,7 @@ jobs:
             branch_name=${{ github.ref_name }}
             target_path=paddle-github-action/BRANCH/FastDeploy/${branch_name}/${commit_id}
           fi
-          wget -O bos_tools.py -q --no-proxy --no-check-certificate https://paddle-qa.bj.bcebos.com/CodeSync/develop/PaddlePaddle/PaddleTest/tools/bos_tools.py
+          wget -q --no-proxy --no-check-certificate https://paddle-qa.bj.bcebos.com/CodeSync/develop/PaddlePaddle/PaddleTest/tools/bos_tools.py
           push_file=$(realpath bos_tools.py)
           python -m pip install bce-python-sdk==0.9.29
           ls
185  .github/workflows/_logprob_test_linux.yml (vendored)
@@ -1,185 +0,0 @@ (file deleted)
|
||||
name: Run FastDeploy LogProb Tests
|
||||
description: "Run FastDeploy LogProb Tests"
|
||||
|
||||
on:
|
||||
workflow_call:
|
||||
inputs:
|
||||
DOCKER_IMAGE:
|
||||
description: "Build Images"
|
||||
required: true
|
||||
type: string
|
||||
default: "ccr-2vdh3abv-pub.cnc.bj.baidubce.com/paddlepaddle/paddleqa:cuda126-py310"
|
||||
PADDLETEST_ARCHIVE_URL:
|
||||
description: "URL of the compressed FastDeploy code archive."
|
||||
required: true
|
||||
type: string
|
||||
default: "https://xly-devops.bj.bcebos.com/PaddleTest/PaddleTest.tar.gz"
|
||||
FASTDEPLOY_WHEEL_URL:
|
||||
description: "URL of the FastDeploy Wheel."
|
||||
required: true
|
||||
type: string
|
||||
CACHE_DIR:
|
||||
description: "Cache Dir Use"
|
||||
required: false
|
||||
type: string
|
||||
default: ""
|
||||
MODEL_CACHE_DIR:
|
||||
description: "Cache Dir Use"
|
||||
required: false
|
||||
type: string
|
||||
default: ""
|
||||
|
||||
jobs:
|
||||
run_tests_logprob:
|
||||
runs-on: [self-hosted, GPU-h20-1Cards]
|
||||
steps:
|
||||
- name: Code Prepare
|
||||
shell: bash
|
||||
env:
|
||||
docker_image: ${{ inputs.DOCKER_IMAGE }}
|
||||
paddletest_archive_url: ${{ inputs.PADDLETEST_ARCHIVE_URL }}
|
||||
run: |
|
||||
docker pull ${docker_image}
|
||||
# Clean the repository directory before starting
|
||||
docker run --rm --net=host -v $(pwd):/workspace -w /workspace \
|
||||
-e "REPO_NAME=${REPO_NAME}" \
|
||||
-e "BASE_BRANCH=${BASE_BRANCH}" \
|
||||
${docker_image} /bin/bash -c '
|
||||
rm -rf /workspace/*
|
||||
'
|
||||
wget -q ${paddletest_archive_url}
|
||||
tar -xf PaddleTest.tar.gz
|
||||
rm -rf PaddleTest.tar.gz
|
||||
cd PaddleTest
|
||||
git config --global user.name "FastDeployCI"
|
||||
git config --global user.email "fastdeploy_ci@example.com"
|
||||
git log -n 3 --oneline
|
||||
- name: logprob test
|
||||
shell: bash
|
||||
env:
|
||||
docker_image: ${{ inputs.DOCKER_IMAGE }}
|
||||
fastdeploy_wheel_url: ${{ inputs.FASTDEPLOY_WHEEL_URL }}
|
||||
CACHE_DIR: ${{ inputs.CACHE_DIR }}
|
||||
MODEL_CACHE_DIR: ${{ inputs.MODEL_CACHE_DIR }}
|
||||
run: |
|
||||
runner_name="${{ runner.name }}"
|
||||
CARD_ID=$(echo "${runner_name}" | awk -F'-' '{print $NF}')
|
||||
DEVICES=$(echo "$CARD_ID" | fold -w1 | paste -sd,)
|
||||
DEVICE_PORT=$(echo "$DEVICES" | cut -d',' -f1)
|
||||
|
||||
FLASK_PORT=$((42068 + DEVICE_PORT * 100))
|
||||
FD_API_PORT=$((42088 + DEVICE_PORT * 100))
|
||||
FD_ENGINE_QUEUE_PORT=$((42058 + DEVICE_PORT * 100))
|
||||
FD_METRICS_PORT=$((42078 + DEVICE_PORT * 100))
|
||||
FD_CACHE_QUEUE_PORT=$((42098 + DEVICE_PORT * 100))
|
||||
echo "Test ENV Parameter:"
|
||||
echo "========================================================="
|
||||
echo "FLASK_PORT=${FLASK_PORT}"
|
||||
echo "FD_API_PORT=${FD_API_PORT}"
|
||||
echo "FD_ENGINE_QUEUE_PORT=${FD_ENGINE_QUEUE_PORT}"
|
||||
echo "FD_METRICS_PORT=${FD_METRICS_PORT}"
|
||||
echo "FD_CACHE_QUEUE_PORT=${FD_CACHE_QUEUE_PORT}"
|
||||
echo "DEVICES=${DEVICES}"
|
||||
echo "========================================================="
|
||||
|
||||
CACHE_DIR="${CACHE_DIR:-$(dirname "$(dirname "${{ github.workspace }}")")}"
|
||||
echo "CACHE_DIR is set to ${CACHE_DIR}"
|
||||
if [ ! -f "${CACHE_DIR}/gitconfig" ]; then
|
||||
touch "${CACHE_DIR}/gitconfig"
|
||||
fi
|
||||
if [ ! -d "${MODEL_CACHE_DIR}" ]; then
|
||||
echo "Error: MODEL_CACHE_DIR '${MODEL_CACHE_DIR}' does not exist."
|
||||
exit 1
|
||||
fi
|
||||
|
||||
PORTS=($FLASK_PORT $FD_API_PORT $FD_ENGINE_QUEUE_PORT $FD_METRICS_PORT $FD_CACHE_QUEUE_PORT)
|
||||
LOG_FILE="./port_cleanup_$(date +%Y%m%d_%H%M%S).log"
|
||||
echo "==== LOG_FILE is ${LOG_FILE} ===="
|
||||
|
||||
echo "==== PORT CLEAN BEFORE TASK RUN ====" | tee -a $LOG_FILE
|
||||
|
||||
for port in "${PORTS[@]}"; do
|
||||
PIDS=$(lsof -t -i :$port || true)
|
||||
if [ -n "$PIDS" ]; then
|
||||
echo "Port $port is occupied by PID(s): $PIDS" | tee -a $LOG_FILE
|
||||
echo "$PIDS" | xargs -r kill -9
|
||||
echo "Port $port cleared" | tee -a $LOG_FILE
|
||||
else
|
||||
echo "Port $port is free" | tee -a $LOG_FILE
|
||||
fi
|
||||
done
|
||||
|
||||
echo "==== PORT CLEAN COMPLETE ====" | tee -a $LOG_FILE
|
||||
|
||||
echo "========================================================="
|
||||
echo "Ensuring no stale container named ${runner_name} ..."
|
||||
if [ "$(docker ps -a -q -f name=${runner_name})" ]; then
|
||||
echo "Removing stale container: ${runner_name}"
|
||||
docker rm -f ${runner_name} || true
|
||||
fi
|
||||
docker run --rm --ipc=host --pid=host --net=host \
|
||||
--name ${runner_name} \
|
||||
-v $(pwd):/workspace \
|
||||
-w /workspace \
|
||||
-e fastdeploy_wheel_url=${fastdeploy_wheel_url} \
|
||||
-e "FD_API_PORT=${FD_API_PORT}" \
|
||||
-e "FD_ENGINE_QUEUE_PORT=${FD_ENGINE_QUEUE_PORT}" \
|
||||
-e "FD_METRICS_PORT=${FD_METRICS_PORT}" \
|
||||
-e "FD_CACHE_QUEUE_PORT=${FD_CACHE_QUEUE_PORT}" \
|
||||
-e "FLASK_PORT=${FLASK_PORT}" \
|
||||
-v "${MODEL_CACHE_DIR}:/MODELDATA" \
|
||||
-v "${CACHE_DIR}/gitconfig:/etc/gitconfig:ro" \
|
||||
-v "${CACHE_DIR}/.cache:/root/.cache" \
|
||||
-v "${CACHE_DIR}/ConfigDir:/root/.config" \
|
||||
-e TZ="Asia/Shanghai" \
|
||||
--gpus '"device='"${DEVICES}"'"' ${docker_image} /bin/bash -xc '
|
||||
python -m pip install --pre paddlepaddle-gpu -i https://www.paddlepaddle.org.cn/packages/nightly/cu126/
|
||||
|
||||
pip config set global.index-url https://mirrors.tuna.tsinghua.edu.cn/pypi/web/simple
|
||||
|
||||
python -m pip install ${fastdeploy_wheel_url}
|
||||
|
||||
wget https://paddle-qa.bj.bcebos.com/zhengtianyu/tools/llm-deploy-linux-amd64
|
||||
chmod +x ./llm-deploy-linux-amd64
|
||||
./llm-deploy-linux-amd64 -python python3.10 \
|
||||
-model_name ERNIE-4.5-0.3B-Paddle \
|
||||
-model_path /MODELDATA \
|
||||
--skip install
|
||||
|
||||
cd PaddleTest/framework/ServeTest
|
||||
ps -ef | grep "${FD_ENGINE_QUEUE_PORT}" | grep -v grep | awk "{print \$2}" | xargs -r kill -9
|
||||
python3.10 deploy.py > dd.log 2>&1 &
|
||||
sleep 3
|
||||
curl -X POST http://0.0.0.0:${FLASK_PORT}/start \
|
||||
-H "Content-Type: application/json" \
|
||||
-d "{\"--model\": \"/MODELDATA/ERNIE-4.5-0.3B-Paddle\"}"
|
||||
|
||||
curl -X POST http://localhost:${FLASK_PORT}/wait_for_infer?timeout=90
|
||||
curl -s -o /dev/null -w "%{http_code}" -m 2 "http://0.0.0.0:${FD_API_PORT}/health"
|
||||
curl -X POST "http://0.0.0.0:${FD_API_PORT}/v1/chat/completions" \
|
||||
-H "Content-Type: application/json" \
|
||||
-d "{\"messages\": [{\"role\": \"user\", \"content\": \"1+1=?\"}], \"logprobs\": true}"
|
||||
set +e
|
||||
rm -rf ./baseline_output
|
||||
cp -r baseline/ERNIE-4.5-0.3B-Paddle ./baseline_output
|
||||
LOGPROB_EXIT_CODE=0
|
||||
python3.10 lanucher.py --request_template TOKEN_LOGPROB --url http://localhost:${FD_API_PORT}/v1/chat/completions --case ./cases/demo.yaml --concurrency 1 --name demo --exe logprob || LOGPROB_EXIT_CODE=$?
|
||||
echo "LOGPROB_EXIT_CODE=${LOGPROB_EXIT_CODE}" > /workspace/exit_code.env
|
||||
curl -X POST http://localhost:${FLASK_PORT}/stop
|
||||
sleep 10s
|
||||
cat *result.log
|
||||
exit 0
|
||||
'
|
||||
if [ $? -ne 0 ];then
|
||||
exit 1
|
||||
fi
|
||||
|
||||
if [ -f exit_code.env ]; then
|
||||
cat exit_code.env >> $GITHUB_ENV
|
||||
fi
|
||||
- name: logprob test result
|
||||
if: ${{ env.LOGPROB_EXIT_CODE != 0 }}
|
||||
shell: bash
|
||||
run: |
|
||||
echo "logprob test failed with exit code ${{ env.LOGPROB_EXIT_CODE }}"
|
||||
exit 8
|
148  .github/workflows/_pre_ce_test.yml (vendored)
@@ -1,148 +0,0 @@ (file deleted)
|
||||
name: Pre-CE-Test
|
||||
|
||||
on:
|
||||
workflow_call:
|
||||
inputs:
|
||||
DOCKER_IMAGE:
|
||||
description: "Build Images"
|
||||
required: true
|
||||
type: string
|
||||
default: "ccr-2vdh3abv-pub.cnc.bj.baidubce.com/paddlepaddle/paddle:fastdeploy-ciuse-cuda126"
|
||||
FASTDEPLOY_ARCHIVE_URL:
|
||||
description: "URL of the compressed FastDeploy code archive."
|
||||
required: true
|
||||
type: string
|
||||
FASTDEPLOY_WHEEL_URL:
|
||||
description: "URL of the FastDeploy Wheel."
|
||||
required: true
|
||||
type: string
|
||||
CACHE_DIR:
|
||||
description: "Cache Dir Use"
|
||||
required: false
|
||||
type: string
|
||||
default: ""
|
||||
MODEL_CACHE_DIR:
|
||||
description: "Cache Dir Use"
|
||||
required: false
|
||||
type: string
|
||||
default: ""
|
||||
|
||||
jobs:
|
||||
run_ce_cases:
|
||||
runs-on: [self-hosted, PRE_CE_RUN_2Card]
|
||||
timeout-minutes: 60
|
||||
steps:
|
||||
- name: Print current runner name
|
||||
run: |
|
||||
echo "Current runner name: ${{ runner.name }}"
|
||||
- name: Code Prepare
|
||||
shell: bash
|
||||
env:
|
||||
docker_image: ${{ inputs.DOCKER_IMAGE }}
|
||||
fd_archive_url: ${{ inputs.FASTDEPLOY_ARCHIVE_URL }}
|
||||
run: |
|
||||
set -x
|
||||
REPO="https://github.com/${{ github.repository }}.git"
|
||||
FULL_REPO="${{ github.repository }}"
|
||||
REPO_NAME="${FULL_REPO##*/}"
|
||||
BASE_BRANCH="${{ github.base_ref }}"
|
||||
docker pull ${docker_image}
|
||||
# Clean the repository directory before starting
|
||||
docker run --rm --net=host -v $(pwd):/workspace -w /workspace \
|
||||
-e "REPO_NAME=${REPO_NAME}" \
|
||||
${docker_image} /bin/bash -c '
|
||||
if [ -d ${REPO_NAME} ]; then
|
||||
echo "Directory ${REPO_NAME} exists, removing it..."
|
||||
rm -rf ${REPO_NAME}*
|
||||
fi
|
||||
'
|
||||
|
||||
wget -q ${fd_archive_url}
|
||||
tar -xf FastDeploy.tar.gz
|
||||
rm -rf FastDeploy.tar.gz
|
||||
cd FastDeploy
|
||||
git config --global user.name "FastDeployCI"
|
||||
git config --global user.email "fastdeploy_ci@example.com"
|
||||
git log -n 3 --oneline
|
||||
|
||||
- name: Run CI unittest
|
||||
env:
|
||||
docker_image: ${{ inputs.DOCKER_IMAGE }}
|
||||
fd_wheel_url: ${{ inputs.FASTDEPLOY_WHEEL_URL }}
|
||||
CACHE_DIR: ${{ inputs.CACHE_DIR }}
|
||||
MODEL_CACHE_DIR: ${{ inputs.MODEL_CACHE_DIR }}
|
||||
run: |
|
||||
runner_name="${{ runner.name }}"
|
||||
CARD_ID=$(echo "${runner_name}" | awk -F'-' '{print $NF}')
|
||||
DEVICES=$(echo "$CARD_ID" | fold -w1 | paste -sd,)
|
||||
DEVICE_PORT=$(echo "$DEVICES" | cut -d',' -f1)
|
||||
|
||||
FLASK_PORT=$((42068 + DEVICE_PORT * 100))
|
||||
FD_API_PORT=$((42088 + DEVICE_PORT * 100))
|
||||
FD_ENGINE_QUEUE_PORT=$((42058 + DEVICE_PORT * 100))
|
||||
FD_METRICS_PORT=$((42078 + DEVICE_PORT * 100))
|
||||
FD_CACHE_QUEUE_PORT=$((42098 + DEVICE_PORT * 100))
|
||||
echo "Test ENV Parameter:"
|
||||
echo "========================================================="
|
||||
echo "FLASK_PORT=${FLASK_PORT}"
|
||||
echo "FD_API_PORT=${FD_API_PORT}"
|
||||
echo "FD_ENGINE_QUEUE_PORT=${FD_ENGINE_QUEUE_PORT}"
|
||||
echo "FD_METRICS_PORT=${FD_METRICS_PORT}"
|
||||
echo "FD_CACHE_QUEUE_PORT=${FD_CACHE_QUEUE_PORT}"
|
||||
echo "DEVICES=${DEVICES}"
|
||||
echo "========================================================="
|
||||
|
||||
CACHE_DIR="${CACHE_DIR:-$(dirname "$(dirname "${{ github.workspace }}")")}"
|
||||
echo "CACHE_DIR is set to ${CACHE_DIR}"
|
||||
if [ ! -f "${CACHE_DIR}/gitconfig" ]; then
|
||||
touch "${CACHE_DIR}/gitconfig"
|
||||
fi
|
||||
|
||||
PORTS=($FLASK_PORT $FD_API_PORT $FD_ENGINE_QUEUE_PORT $FD_METRICS_PORT $FD_CACHE_QUEUE_PORT)
|
||||
LOG_FILE="./port_cleanup_$(date +%Y%m%d_%H%M%S).log"
|
||||
echo "==== LOG_FILE is ${LOG_FILE} ===="
|
||||
|
||||
echo "==== PORT CLEAN BEFORE TASK RUN ====" | tee -a $LOG_FILE
|
||||
|
||||
for port in "${PORTS[@]}"; do
|
||||
PIDS=$(lsof -t -i :$port || true)
|
||||
if [ -n "$PIDS" ]; then
|
||||
echo "Port $port is occupied by PID(s): $PIDS" | tee -a $LOG_FILE
|
||||
echo "$PIDS" | xargs -r kill -9
|
||||
echo "Port $port cleared" | tee -a $LOG_FILE
|
||||
else
|
||||
echo "Port $port is free" | tee -a $LOG_FILE
|
||||
fi
|
||||
done
|
||||
|
||||
echo "==== PORT CLEAN COMPLETE ====" | tee -a $LOG_FILE
|
||||
|
||||
echo "========================================================="
|
||||
echo "Ensuring no stale container named ${runner_name} ..."
|
||||
if [ "$(docker ps -a -q -f name=${runner_name})" ]; then
|
||||
echo "Removing stale container: ${runner_name}"
|
||||
docker rm -f ${runner_name} || true
|
||||
fi
|
||||
|
||||
docker run --rm --net=host \
|
||||
--name ${runner_name} \
|
||||
-v $(pwd):/workspace \
|
||||
-w /workspace \
|
||||
-v "${CACHE_DIR}/gitconfig:/etc/gitconfig:ro" \
|
||||
-v "${CACHE_DIR}/.cache:/root/.cache" \
|
||||
-v "${CACHE_DIR}/ConfigDir:/root/.config" \
|
||||
-v "${MODEL_CACHE_DIR}:/ModelData:ro" \
|
||||
-e "MODEL_PATH=/ModelData" \
|
||||
-e "FD_API_PORT=${FD_API_PORT}" \
|
||||
-e "FD_ENGINE_QUEUE_PORT=${FD_ENGINE_QUEUE_PORT}" \
|
||||
-e "FD_METRICS_PORT=${FD_METRICS_PORT}" \
|
||||
-e "FD_CACHE_QUEUE_PORT=${FD_CACHE_QUEUE_PORT}" \
|
||||
-e "FLASK_PORT=${FLASK_PORT}" \
|
||||
-e "fd_wheel_url=${fd_wheel_url}" \
|
||||
--gpus "\"device=${DEVICES}\"" ${docker_image} /bin/bash -c '
|
||||
git config --global --add safe.directory /workspace/FastDeploy
|
||||
cd FastDeploy
|
||||
python -m pip install --pre paddlepaddle-gpu -i https://www.paddlepaddle.org.cn/packages/nightly/cu126/
|
||||
python -m pip install ${fd_wheel_url}
|
||||
bash scripts/run_pre_ce.sh
|
||||
'
|
170  .github/workflows/_stable_test.yml (vendored)
@@ -1,170 +0,0 @@ (file deleted)
|
||||
name: Stable Test
|
||||
description: "Run Stable Tests"
|
||||
|
||||
on:
|
||||
workflow_call:
|
||||
inputs:
|
||||
DOCKER_IMAGE:
|
||||
description: "Build Images"
|
||||
required: true
|
||||
type: string
|
||||
default: "ccr-2vdh3abv-pub.cnc.bj.baidubce.com/paddlepaddle/paddleqa:cuda126-py310"
|
||||
FASTDEPLOY_ARCHIVE_URL:
|
||||
description: "URL of the compressed FastDeploy code archive."
|
||||
required: true
|
||||
type: string
|
||||
FASTDEPLOY_WHEEL_URL:
|
||||
description: "URL of the FastDeploy Wheel."
|
||||
required: true
|
||||
type: string
|
||||
CACHE_DIR:
|
||||
description: "Cache Dir Use"
|
||||
required: false
|
||||
type: string
|
||||
default: ""
|
||||
MODEL_CACHE_DIR:
|
||||
description: "Cache Dir Use"
|
||||
required: false
|
||||
type: string
|
||||
default: ""
|
||||
|
||||
jobs:
|
||||
stable_tests:
|
||||
runs-on: [self-hosted, GPU-h1z1-2Cards]
|
||||
timeout-minutes: 60
|
||||
steps:
|
||||
- name: Code Prepare
|
||||
shell: bash
|
||||
env:
|
||||
docker_image: ${{ inputs.DOCKER_IMAGE }}
|
||||
fd_archive_url: ${{ inputs.FASTDEPLOY_ARCHIVE_URL }}
|
||||
run: |
|
||||
set -x
|
||||
REPO="https://github.com/${{ github.repository }}.git"
|
||||
FULL_REPO="${{ github.repository }}"
|
||||
REPO_NAME="${FULL_REPO##*/}"
|
||||
BASE_BRANCH="${{ github.base_ref }}"
|
||||
docker pull ${docker_image}
|
||||
# Clean the repository directory before starting
|
||||
docker run --rm --net=host -v $(pwd):/workspace -w /workspace \
|
||||
-e "REPO_NAME=${REPO_NAME}" \
|
||||
${docker_image} /bin/bash -c '
|
||||
if [ -d ${REPO_NAME} ]; then
|
||||
echo "Directory ${REPO_NAME} exists, removing it..."
|
||||
rm -rf ${REPO_NAME}*
|
||||
fi
|
||||
'
|
||||
|
||||
wget -q ${fd_archive_url}
|
||||
tar -xf FastDeploy.tar.gz
|
||||
rm -rf FastDeploy.tar.gz
|
||||
cd FastDeploy
|
||||
git config --global user.name "FastDeployCI"
|
||||
git config --global user.email "fastdeploy_ci@example.com"
|
||||
git log -n 3 --oneline
|
||||
|
||||
- name: Run FastDeploy Stable Tests
|
||||
shell: bash
|
||||
env:
|
||||
docker_image: ${{ inputs.DOCKER_IMAGE }}
|
||||
fastdeploy_wheel_url: ${{ inputs.FASTDEPLOY_WHEEL_URL }}
|
||||
CACHE_DIR: ${{ inputs.CACHE_DIR }}
|
||||
MODEL_CACHE_DIR: ${{ inputs.MODEL_CACHE_DIR }}
|
||||
run: |
|
||||
runner_name="${{ runner.name }}"
|
||||
CARD_ID=$(echo "${runner_name}" | awk -F'-' '{print $NF}')
|
||||
DEVICES=$(echo "$CARD_ID" | fold -w1 | paste -sd,)
|
||||
DEVICE_PORT=$(echo "$DEVICES" | cut -d',' -f1)
|
||||
|
||||
FLASK_PORT=$((42068 + DEVICE_PORT * 100))
|
||||
FD_API_PORT=$((42088 + DEVICE_PORT * 100))
|
||||
FD_ENGINE_QUEUE_PORT=$((42058 + DEVICE_PORT * 100))
|
||||
FD_METRICS_PORT=$((42078 + DEVICE_PORT * 100))
|
||||
FD_CACHE_QUEUE_PORT=$((42038 + DEVICE_PORT * 100))
|
||||
FD_INFERENCE_MSG_QUEUE_ID=$(( 42048 + DEVICE_PORT * 100))
|
||||
echo "Test ENV Parameter:"
|
||||
echo "========================================================="
|
||||
echo "FLASK_PORT=${FLASK_PORT}"
|
||||
echo "FD_API_PORT=${FD_API_PORT}"
|
||||
echo "FD_ENGINE_QUEUE_PORT=${FD_ENGINE_QUEUE_PORT}"
|
||||
echo "FD_METRICS_PORT=${FD_METRICS_PORT}"
|
||||
echo "FD_INFERENCE_MSG_QUEUE_ID=${FD_INFERENCE_MSG_QUEUE_ID}"
|
||||
echo "FD_CACHE_QUEUE_PORT=${FD_CACHE_QUEUE_PORT}"
|
||||
echo "DEVICES=${DEVICES}"
|
||||
echo "========================================================="
|
||||
|
||||
CACHE_DIR="${CACHE_DIR:-$(dirname "$(dirname "${{ github.workspace }}")")}"
|
||||
echo "CACHE_DIR is set to ${CACHE_DIR}"
|
||||
if [ ! -f "${CACHE_DIR}/gitconfig" ]; then
|
||||
touch "${CACHE_DIR}/gitconfig"
|
||||
fi
|
||||
if [ ! -d "${MODEL_CACHE_DIR}" ]; then
|
||||
echo "Error: MODEL_CACHE_DIR '${MODEL_CACHE_DIR}' does not exist."
|
||||
exit 1
|
||||
fi
|
||||
|
||||
PORTS=($FLASK_PORT $FD_API_PORT $FD_ENGINE_QUEUE_PORT $FD_METRICS_PORT)
|
||||
LOG_FILE="./port_cleanup_$(date +%Y%m%d_%H%M%S).log"
|
||||
echo "==== LOG_FILE is ${LOG_FILE} ===="
|
||||
|
||||
echo "==== PORT CLEAN BEFORE TASK RUN ====" | tee -a $LOG_FILE
|
||||
|
||||
for port in "${PORTS[@]}"; do
|
||||
PIDS=$(lsof -t -i :$port || true)
|
||||
if [ -n "$PIDS" ]; then
|
||||
echo "Port $port is occupied by PID(s): $PIDS" | tee -a $LOG_FILE
|
||||
echo "$PIDS" | xargs -r kill -9
|
||||
echo "Port $port cleared" | tee -a $LOG_FILE
|
||||
else
|
||||
echo "Port $port is free" | tee -a $LOG_FILE
|
||||
fi
|
||||
done
|
||||
|
||||
echo "==== PORT CLEAN COMPLETE ====" | tee -a $LOG_FILE
|
||||
|
||||
echo "========================================================="
|
||||
echo "Ensuring no stale container named ${runner_name} ..."
|
||||
if [ "$(docker ps -a -q -f name=${runner_name})" ]; then
|
||||
echo "Removing stale container: ${runner_name}"
|
||||
docker rm -f ${runner_name} || true
|
||||
fi
|
||||
|
||||
docker run --rm --ipc=host --pid=host --net=host \
|
||||
--name ${runner_name} \
|
||||
-v $(pwd):/workspace \
|
||||
-w /workspace \
|
||||
-e fastdeploy_wheel_url=${fastdeploy_wheel_url} \
|
||||
-e "FD_API_PORT=${FD_API_PORT}" \
|
||||
-e "FD_ENGINE_QUEUE_PORT=${FD_ENGINE_QUEUE_PORT}" \
|
||||
-e "FD_METRICS_PORT=${FD_METRICS_PORT}" \
|
||||
-e "FLASK_PORT=${FLASK_PORT}" \
|
||||
-e "FD_INFERENCE_MSG_QUEUE_ID=${FD_INFERENCE_MSG_QUEUE_ID}" \
|
||||
-e "FD_CACHE_QUEUE_PORT=${FD_CACHE_QUEUE_PORT}" \
|
||||
-v "${MODEL_CACHE_DIR}:/MODELDATA" \
|
||||
-v "${CACHE_DIR}/gitconfig:/etc/gitconfig:ro" \
|
||||
-v "${CACHE_DIR}/.cache:/root/.cache" \
|
||||
-v "${CACHE_DIR}/ConfigDir:/root/.config" \
|
||||
-e TZ="Asia/Shanghai" \
|
||||
--gpus '"device='"${DEVICES}"'"' ${docker_image} /bin/bash -xc '
|
||||
python -m pip install --pre paddlepaddle-gpu -i https://www.paddlepaddle.org.cn/packages/nightly/cu126/
|
||||
|
||||
pip config set global.index-url https://mirrors.tuna.tsinghua.edu.cn/pypi/web/simple
|
||||
|
||||
python -m pip install ${fastdeploy_wheel_url}
|
||||
python -m pip install pytest
|
||||
|
||||
git config --global --add safe.directory /workspace/FastDeploy
|
||||
cd FastDeploy
|
||||
TEST_EXIT_CODE=0
|
||||
pushd tests/ce/stable_cases
|
||||
bash launch_model.sh /MODELDATA
|
||||
bash run.sh || TEST_EXIT_CODE=1
|
||||
popd
|
||||
echo "TEST_EXIT_CODE=${TEST_EXIT_CODE}" >> /workspace/FastDeploy/exit_code.env
|
||||
'
|
||||
if [ -f ./FastDeploy/exit_code.env ]; then
|
||||
source ./FastDeploy/exit_code.env
|
||||
cat ./FastDeploy/exit_code.env >> $GITHUB_ENV
|
||||
fi
|
||||
echo "TEST_EXIT_CODE=${TEST_EXIT_CODE}"
|
||||
exit ${TEST_EXIT_CODE}
|
319  .github/workflows/_unit_test_coverage.yml (vendored)
@@ -1,319 +0,0 @@ (file deleted)
|
||||
name: Coverage Check
|
||||
description: "Run FastDeploy Unit Tests and Coverage"
|
||||
|
||||
on:
|
||||
workflow_call:
|
||||
inputs:
|
||||
DOCKER_IMAGE:
|
||||
description: "Build Images"
|
||||
required: true
|
||||
type: string
|
||||
default: "ccr-2vdh3abv-pub.cnc.bj.baidubce.com/paddlepaddle/paddleqa:cuda126-py310"
|
||||
FASTDEPLOY_ARCHIVE_URL:
|
||||
description: "URL of the compressed FastDeploy code archive."
|
||||
required: true
|
||||
type: string
|
||||
FASTDEPLOY_WHEEL_URL:
|
||||
description: "URL of the FastDeploy Wheel."
|
||||
required: true
|
||||
type: string
|
||||
CACHE_DIR:
|
||||
description: "Cache Dir Use"
|
||||
required: false
|
||||
type: string
|
||||
default: ""
|
||||
MODEL_CACHE_DIR:
|
||||
description: "Cache Dir Use"
|
||||
required: false
|
||||
type: string
|
||||
default: ""
|
||||
secrets:
|
||||
github-token:
|
||||
required: true
|
||||
|
||||
jobs:
|
||||
check_cov_skip:
|
||||
uses: ./.github/workflows/check-bypass.yml
|
||||
secrets:
|
||||
github-token: ${{ secrets.github-token }}
|
||||
with:
|
||||
workflow-name: coverage
|
||||
|
||||
run_tests_with_coverage:
|
||||
runs-on: [self-hosted, GPU-h1z1-2Cards]
|
||||
timeout-minutes: 90
|
||||
needs: check_cov_skip
|
||||
if: needs.check_cov_skip.outputs.can-skip != 'true'
|
||||
outputs:
|
||||
diff_cov_file_url: ${{ steps.cov_upload.outputs.diff_cov_file_url }}
|
||||
unittest_failed_url: ${{ steps.cov_upload.outputs.unittest_failed_url }}
|
||||
diff_cov_result_json_url: ${{ steps.cov_upload.outputs.diff_cov_result_json_url }}
|
||||
steps:
|
||||
- name: Code Prepare
|
||||
shell: bash
|
||||
env:
|
||||
docker_image: ${{ inputs.DOCKER_IMAGE }}
|
||||
fd_archive_url: ${{ inputs.FASTDEPLOY_ARCHIVE_URL }}
|
||||
run: |
|
||||
set -x
|
||||
REPO="https://github.com/${{ github.repository }}.git"
|
||||
FULL_REPO="${{ github.repository }}"
|
||||
REPO_NAME="${FULL_REPO##*/}"
|
||||
BASE_BRANCH="${{ github.base_ref }}"
|
||||
docker pull ${docker_image}
|
||||
# Clean the repository directory before starting
|
||||
docker run --rm --net=host -v $(pwd):/workspace -w /workspace \
|
||||
-e "REPO_NAME=${REPO_NAME}" \
|
||||
${docker_image} /bin/bash -c '
|
||||
if [ -d ${REPO_NAME} ]; then
|
||||
echo "Directory ${REPO_NAME} exists, removing it..."
|
||||
rm -rf ${REPO_NAME}*
|
||||
fi
|
||||
'
|
||||
|
||||
wget -q ${fd_archive_url}
|
||||
tar -xf FastDeploy.tar.gz
|
||||
rm -rf FastDeploy.tar.gz
|
||||
cd FastDeploy
|
||||
git config --global user.name "FastDeployCI"
|
||||
git config --global user.email "fastdeploy_ci@example.com"
|
||||
git log -n 3 --oneline
|
||||
- name: Run FastDeploy Unit Tests and Coverage
|
||||
shell: bash
|
||||
env:
|
||||
docker_image: ${{ inputs.DOCKER_IMAGE }}
|
||||
fd_wheel_url: ${{ inputs.FASTDEPLOY_WHEEL_URL }}
|
||||
CACHE_DIR: ${{ inputs.CACHE_DIR }}
|
||||
BASE_REF: ${{ github.event.pull_request.base.ref }}
|
||||
MODEL_CACHE_DIR: ${{ inputs.MODEL_CACHE_DIR }}
|
||||
IS_PR: ${{ github.event_name == 'pull_request' }}
|
||||
run: |
|
||||
if [[ "$IS_PR" == "true" ]]; then
|
||||
echo "Running on PR"
|
||||
else
|
||||
echo "Not a PR"
|
||||
fi
|
||||
runner_name="${{ runner.name }}"
|
||||
CARD_ID=$(echo "${runner_name}" | awk -F'-' '{print $NF}')
|
||||
DEVICES=$(echo "$CARD_ID" | fold -w1 | paste -sd,)
|
||||
DEVICE_PORT=$(echo "$DEVICES" | cut -d',' -f1)
|
||||
|
||||
FLASK_PORT=$((42068 + DEVICE_PORT * 100))
|
||||
FD_API_PORT=$((42088 + DEVICE_PORT * 100))
|
||||
FD_ENGINE_QUEUE_PORT=$((42058 + DEVICE_PORT * 100))
|
||||
FD_METRICS_PORT=$((42078 + DEVICE_PORT * 100))
|
||||
FD_CACHE_QUEUE_PORT=$((42098 + DEVICE_PORT * 100))
|
||||
echo "Test ENV Parameter:"
|
||||
echo "========================================================="
|
||||
echo "FLASK_PORT=${FLASK_PORT}"
|
||||
echo "FD_API_PORT=${FD_API_PORT}"
|
||||
echo "FD_ENGINE_QUEUE_PORT=${FD_ENGINE_QUEUE_PORT}"
|
||||
echo "FD_METRICS_PORT=${FD_METRICS_PORT}"
|
||||
echo "FD_CACHE_QUEUE_PORT=${FD_CACHE_QUEUE_PORT}"
|
||||
echo "DEVICES=${DEVICES}"
|
||||
echo "========================================================="
|
||||
|
||||
CACHE_DIR="${CACHE_DIR:-$(dirname "$(dirname "${{ github.workspace }}")")}"
|
||||
echo "CACHE_DIR is set to ${CACHE_DIR}"
|
||||
if [ ! -f "${CACHE_DIR}/gitconfig" ]; then
|
||||
touch "${CACHE_DIR}/gitconfig"
|
||||
fi
|
||||
|
||||
PORTS=($FLASK_PORT $FD_API_PORT $FD_ENGINE_QUEUE_PORT $FD_METRICS_PORT $FD_CACHE_QUEUE_PORT)
|
||||
LOG_FILE="./port_cleanup_$(date +%Y%m%d_%H%M%S).log"
|
||||
echo "==== LOG_FILE is ${LOG_FILE} ===="
|
||||
|
||||
echo "==== PORT CLEAN BEFORE TASK RUN ====" | tee -a $LOG_FILE
|
||||
|
||||
for port in "${PORTS[@]}"; do
|
||||
PIDS=$(lsof -t -i :$port || true)
|
||||
if [ -n "$PIDS" ]; then
|
||||
echo "Port $port is occupied by PID(s): $PIDS" | tee -a $LOG_FILE
|
||||
echo "$PIDS" | xargs -r kill -9
|
||||
echo "Port $port cleared" | tee -a $LOG_FILE
|
||||
else
|
||||
echo "Port $port is free" | tee -a $LOG_FILE
|
||||
fi
|
||||
done
|
||||
|
||||
echo "==== PORT CLEAN COMPLETE ====" | tee -a $LOG_FILE
|
||||
|
||||
echo "========================================================="
|
||||
echo "Ensuring no stale container named ${runner_name} ..."
|
||||
if [ "$(docker ps -a -q -f name=${runner_name})" ]; then
|
||||
echo "Removing stale container: ${runner_name}"
|
||||
docker rm -f ${runner_name} || true
|
||||
fi
|
||||
|
||||
docker run --rm --net=host \
|
||||
--name ${runner_name} \
|
||||
--cap-add=SYS_PTRACE --shm-size=64G \
|
||||
-v $(pwd):/workspace -w /workspace \
|
||||
-v "${CACHE_DIR}/gitconfig:/etc/gitconfig:ro" \
|
||||
-v "${CACHE_DIR}/.cache:/root/.cache" \
|
||||
-v "${CACHE_DIR}/ConfigDir:/root/.config" \
|
||||
-v "${MODEL_CACHE_DIR}:/ModelData:ro" \
|
||||
-e "MODEL_PATH=/ModelData" \
|
||||
-e "FD_API_PORT=${FD_API_PORT}" \
|
||||
-e "FD_ENGINE_QUEUE_PORT=${FD_ENGINE_QUEUE_PORT}" \
|
||||
-e "FD_METRICS_PORT=${FD_METRICS_PORT}" \
|
||||
-e "FLASK_PORT=${FLASK_PORT}" \
|
||||
-e "FD_CACHE_QUEUE_PORT=${FD_CACHE_QUEUE_PORT}" \
|
||||
-e TZ="Asia/Shanghai" \
|
||||
-e "fd_wheel_url=${fd_wheel_url}" \
|
||||
-e "BASE_REF=${BASE_REF}" \
|
||||
-e "IS_PR=${IS_PR}" \
|
||||
--gpus "\"device=${DEVICES}\"" ${docker_image} /bin/bash -c '
|
||||
|
||||
git config --global --add safe.directory /workspace/FastDeploy
|
||||
cd FastDeploy
|
||||
git diff origin/${BASE_REF}..HEAD --unified=0 > diff.txt
|
||||
python -m pip install --pre paddlepaddle-gpu -i https://www.paddlepaddle.org.cn/packages/nightly/cu126/
|
||||
pip config set global.extra-index-url https://mirrors.tuna.tsinghua.edu.cn/pypi/web/simple
|
||||
|
||||
python -m pip install -r scripts/unittest_requirement.txt
|
||||
python -m pip install ${fd_wheel_url}
|
||||
rm -rf fastdeploy
|
||||
# coverage subprocess use
|
||||
python -m pip install ${fd_wheel_url} --no-deps --target=/workspace/FastDeploy
|
||||
export PYTHONPATH=/workspace/FastDeploy/
|
||||
if [ -d "tests/plugins" ]; then
|
||||
cd tests/plugins
|
||||
python setup.py install
|
||||
cd ../..
|
||||
else
|
||||
echo "Warning: tests/plugins directory not found, skipping setup.py install"
|
||||
fi
|
||||
export COVERAGE_FILE=/workspace/FastDeploy/coveragedata/.coverage
|
||||
export COVERAGE_RCFILE=/workspace/FastDeploy/scripts/.coveragerc
|
||||
TEST_EXIT_CODE=0
|
||||
bash scripts/coverage_run.sh || TEST_EXIT_CODE=8
|
||||
echo "TEST_EXIT_CODE=${TEST_EXIT_CODE}" >> exit_code.env
|
||||
coverage combine coveragedata/ || echo "No data to combine"
|
||||
coverage report
|
||||
coverage xml -o python_coverage_all.xml
|
||||
COVERAGE_EXIT_CODE=0
|
||||
if [[ "$IS_PR" == "true" ]]; then
|
||||
diff-cover python_coverage_all.xml --diff-file=diff.txt --fail-under=80 --json-report diff_coverage.json || COVERAGE_EXIT_CODE=9
|
||||
python scripts/generate_diff_coverage_xml.py diff.txt python_coverage_all.xml
|
||||
else
|
||||
echo "Not a PR, skipping diff-cover"
|
||||
fi
|
||||
echo "COVERAGE_EXIT_CODE=${COVERAGE_EXIT_CODE}" >> exit_code.env
|
||||
'
|
||||
if [ -f FastDeploy/exit_code.env ]; then
|
||||
cat FastDeploy/exit_code.env >> $GITHUB_ENV
|
||||
fi
|
||||
- name: Upload unit resule and diff coverage to bos
|
||||
id: cov_upload
|
||||
shell: bash
|
||||
run: |
|
||||
cd FastDeploy
|
||||
commit_id=${{ github.event.pull_request.head.sha }}
|
||||
pr_num=${{ github.event.pull_request.number }}
|
||||
target_path=paddle-github-action/PR/FastDeploy/${pr_num}/${commit_id}/SM${compile_arch//,/_}
|
||||
wget -q --no-proxy --no-check-certificate https://paddle-qa.bj.bcebos.com/CodeSync/develop/PaddlePaddle/PaddleTest/tools/bos_tools.py -O bos_tools.py
|
||||
push_file=$(realpath bos_tools.py)
|
||||
python -m pip install bce-python-sdk==0.9.29
|
||||
diff_cov_file="diff_coverage.xml"
|
||||
if [ -f ${diff_cov_file} ];then
|
||||
python ${push_file} ${diff_cov_file} ${target_path}/CoverageData
|
||||
target_path_stripped="${target_path#paddle-github-action/}"
|
||||
DIFF_COV_FILE_URL=https://paddle-github-action.bj.bcebos.com/${target_path_stripped}/CoverageData/${diff_cov_file}
|
||||
echo "diff_cov_file_url=${DIFF_COV_FILE_URL}" >> $GITHUB_OUTPUT
|
||||
echo "diff_cov_file_url=${DIFF_COV_FILE_URL}" >> $GITHUB_ENV
|
||||
fi
|
||||
diff_cov_result_json="diff_coverage.json"
|
||||
if [ -f ${diff_cov_result_json} ];then
|
||||
python ${push_file} ${diff_cov_result_json} ${target_path}/CoverageData
|
||||
target_path_stripped="${target_path#paddle-github-action/}"
|
||||
DIFF_COV_JSON_URL=https://paddle-github-action.bj.bcebos.com/${target_path_stripped}/CoverageData/${diff_cov_result_json}
|
||||
echo "diff_cov_result_json_url=${DIFF_COV_JSON_URL}" >> $GITHUB_OUTPUT
|
||||
echo "diff_cov_result_json_url=${DIFF_COV_JSON_URL}" >> $GITHUB_ENV
|
||||
fi
|
||||
unittest_result="failed_tests.log"
|
||||
if [ -s ${unittest_result} ];then
|
||||
python ${push_file} ${unittest_result} ${target_path}/UnitTestResult
|
||||
target_path_stripped="${target_path#paddle-github-action/}"
|
||||
UNIT_TEST_RESULT_URL=https://paddle-github-action.bj.bcebos.com/${target_path_stripped}/UnitTestResult/${unittest_result}
|
||||
echo "unittest_failed_url=${UNIT_TEST_RESULT_URL}" >> $GITHUB_OUTPUT
|
||||
echo "unittest_failed_url=${UNIT_TEST_RESULT_URL}" >> $GITHUB_ENV
|
||||
fi
|
||||
- name: Check Unit Test Success
|
||||
shell: bash
|
||||
run: |
|
||||
cd FastDeploy
|
||||
if [ "$TEST_EXIT_CODE" -eq 8 ]; then
|
||||
filename=$(basename "$unittest_failed_url")
|
||||
if [ -z "${unittest_failed_url}" ]; then
|
||||
echo "No diff unit failed file URL provided."
|
||||
else
|
||||
rm -rf "${filename}"
|
||||
wget -O ${filename} ${unittest_failed_url} || echo "Download unittest file failed, but continuing..."
|
||||
fi
|
||||
echo "Unit tests failed (exit code 8)"
|
||||
if [ -f "${filename}" ];then
|
||||
echo "Failed test cases:"
|
||||
cat "${filename}"
|
||||
fi
|
||||
exit "$TEST_EXIT_CODE"
|
||||
fi
|
||||
echo "All tests passed"
|
||||
|
||||
- name: Verify Code Coverage Threshold (80%)
|
||||
if: ${{ github.event_name == 'pull_request' }}
|
||||
shell: bash
|
||||
run: |
|
||||
cd FastDeploy
|
||||
if [ "$COVERAGE_EXIT_CODE" -eq 9 ]; then
|
||||
echo "Coverage generation failed (exit code 9)"
|
||||
filename=$(basename "$diff_cov_result_json_url")
|
||||
if [ -z "${diff_cov_result_json_url}" ]; then
|
||||
echo "No diff cov result file URL provided."
|
||||
else
|
||||
rm -rf "${filename}"
|
||||
wget -O ${filename} ${diff_cov_result_json_url} || echo "Download cov json file failed, but continuing..."
|
||||
fi
|
||||
if [ -f "${filename}" ];then
|
||||
echo "Failed test cases:"
|
||||
if command -v jq >/dev/null 2>&1; then
|
||||
jq . "${filename}"
|
||||
else
|
||||
cat "${filename}"
|
||||
fi
|
||||
fi
|
||||
exit "$COVERAGE_EXIT_CODE"
|
||||
fi
|
||||
echo "coverage passed"
|
||||
exit 0
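The two gate steps above branch on sentinel values that the in-container test run writes to exit_code.env (an earlier step appends that file to $GITHUB_ENV). A minimal sketch of the assumed contract; the test script itself is not part of this diff:

```bash
# Assumed contract between the test run and the gate steps above:
#   TEST_EXIT_CODE=8      -> at least one unit test failed
#   COVERAGE_EXIT_CODE=9  -> diff coverage report missing or below the 80% threshold
echo "TEST_EXIT_CODE=${TEST_EXIT_CODE:-0}"         >> exit_code.env
echo "COVERAGE_EXIT_CODE=${COVERAGE_EXIT_CODE:-0}" >> exit_code.env
```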
|
||||
|
||||
diff_coverage_report:
|
||||
needs: run_tests_with_coverage
|
||||
if: always()
|
||||
runs-on: ubuntu-latest
|
||||
env:
|
||||
fd_archive_url: ${{ inputs.FASTDEPLOY_ARCHIVE_URL }}
|
||||
steps:
|
||||
- name: coverage diff file download
|
||||
shell: bash
|
||||
env:
|
||||
diff_cov_file_url: ${{ needs.run_tests_with_coverage.outputs.diff_cov_file_url }}
|
||||
run: |
|
||||
wget ${fd_archive_url}
|
||||
tar -xf FastDeploy.tar.gz
|
||||
cd FastDeploy
|
||||
if [ -z "${diff_cov_file_url}" ]; then
|
||||
echo "No diff coverage file URL provided."
|
||||
exit 0
|
||||
fi
|
||||
wget "${diff_cov_file_url}" -O ./diff_coverage.xml || echo "Download cov file failed, but continuing..."
|
||||
- name: Upload diff coverage report
|
||||
if: ${{ needs.run_tests_with_coverage.outputs.diff_cov_file_url != null && needs.run_tests_with_coverage.outputs.diff_cov_file_url != '' }}
|
||||
uses: codecov/codecov-action@v5
|
||||
with:
|
||||
files: ./FastDeploy/diff_coverage.xml
|
||||
name: python diff coverage
|
||||
verbose: true
|
||||
disable_search: true
|
||||
commit_parent: false
|
||||
flags: diff
|
.github/workflows/approve.yml (vendored, 42 lines changed)
@@ -1,42 +0,0 @@
|
||||
name: Approval
|
||||
|
||||
on:
|
||||
pull_request:
|
||||
branches:
|
||||
- develop
|
||||
- 'release/*'
|
||||
|
||||
env:
|
||||
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
|
||||
|
||||
jobs:
|
||||
Approval:
|
||||
name: Approval
|
||||
if: ${{ github.repository_owner == 'PaddlePaddle' }}
|
||||
runs-on: ubuntu-latest
|
||||
env:
|
||||
PR_ID: ${{ github.event.pull_request.number }}
|
||||
BRANCH: ${{ github.event.pull_request.base.ref }}
|
||||
steps:
|
||||
- name: Checkout base repo
|
||||
uses: actions/checkout@v4
|
||||
with:
|
||||
ref: ${{ github.event.pull_request.base.ref }}
|
||||
fetch-depth: 1000
|
||||
|
||||
- name: Merge PR to test branch
|
||||
run: |
|
||||
git fetch origin pull/${PR_ID}/merge
|
||||
git checkout -b test FETCH_HEAD
|
||||
git log -n 3 --oneline
|
||||
git remote add upstream https://github.com/PaddlePaddle/FastDeploy.git
|
||||
git fetch upstream $BRANCH
|
||||
|
||||
- name: Setup python3.10
|
||||
uses: actions/setup-python@v5
|
||||
with:
|
||||
python-version: '3.10'
|
||||
|
||||
- name: Run approval check script
|
||||
run: |
|
||||
bash scripts/check_approval.sh
|
.github/workflows/ce_job.yml (vendored, 248 lines changed)
@@ -1,248 +0,0 @@
|
||||
name: CE Compile Job
|
||||
|
||||
on:
|
||||
workflow_dispatch:
|
||||
push:
|
||||
branches:
|
||||
- develop
|
||||
- 'release/*'
|
||||
permissions: read-all
|
||||
|
||||
concurrency:
|
||||
group: CE-Job-${{ github.ref }}-${{ github.sha }}
|
||||
cancel-in-progress: true
|
||||
|
||||
jobs:
|
||||
ce_job_pre_check:
|
||||
runs-on: ubuntu-latest
|
||||
env:
|
||||
COMPILE_BRANCH: ${{ vars.COMPILE_BRANCH }}
|
||||
CE_COMPILE_SELECTION: ${{ vars.CE_COMPILE_SELECTION }}
|
||||
COMPILE_USE_PADDLE_WHL_URL_MAPPINGS: ${{ vars.COMPILE_USE_PADDLE_WHL_URL_MAPPINGS }}
|
||||
outputs:
|
||||
branch_match: ${{ steps.set_output.outputs.branch_match }}
|
||||
compile_use_paddle_whl_url: ${{ steps.set_output.outputs.compile_use_paddle_whl_url }}
|
||||
sm8689_match: ${{ steps.set_output.outputs.sm8689_match }}
|
||||
sm8090_match: ${{ steps.set_output.outputs.sm8090_match }}
|
||||
|
||||
steps:
|
||||
- name: Set Version
|
||||
id: set_output
|
||||
env:
|
||||
COMPILE_BRANCH: ${{ env.COMPILE_BRANCH }}
|
||||
CE_COMPILE_SELECTION: ${{ env.CE_COMPILE_SELECTION }}
|
||||
COMPILE_USE_PADDLE_WHL_URL_MAPPINGS: ${{ env.COMPILE_USE_PADDLE_WHL_URL_MAPPINGS }}
|
||||
GITHUB_REF_NAME: ${{ github.ref_name }}
|
||||
run: |
|
||||
# Select the branches that should trigger compile jobs (done)
# Select which compile tasks (sm8090 or sm8689) to build for a given branch
# Specify the Paddle wheel to use for the branch build; defaults to the latest nightly
|
||||
|
||||
IFS=',' read -ra BRANCHES <<< "$COMPILE_BRANCH"
|
||||
MATCH=false
|
||||
for b in "${BRANCHES[@]}"; do
|
||||
if [[ "$b" == "${GITHUB_REF_NAME}" ]]; then
|
||||
MATCH=true
|
||||
break
|
||||
fi
|
||||
done
|
||||
echo "branch_match=$MATCH" >> $GITHUB_OUTPUT
|
||||
|
||||
# Use the branch-to-task mapping in CE_COMPILE_SELECTION to decide whether this branch builds sm8090 or sm8689
|
||||
for pair in $(echo "$CE_COMPILE_SELECTION" | tr ';' ' '); do
|
||||
branch=$(echo "$pair" | cut -d',' -f1)
|
||||
compile_task_list=$(echo "$pair" | cut -d',' -f2)
|
||||
|
||||
if [[ "$branch" == "$GITHUB_REF_NAME" ]]; then
|
||||
|
||||
# Check whether the task list contains sm8090 or sm8689
|
||||
if [[ "$compile_task_list" == *"sm8090"* ]]; then
|
||||
echo "sm8090_match=true" >> $GITHUB_OUTPUT
|
||||
fi
|
||||
if [[ "$compile_task_list" == *"sm8689"* ]]; then
|
||||
echo "sm8689_match=true" >> $GITHUB_OUTPUT
|
||||
fi
|
||||
break
|
||||
fi
|
||||
done
|
||||
|
||||
# Use the mapping in COMPILE_USE_PADDLE_WHL_URL_MAPPINGS to decide whether to install a pinned Paddle version or a wheel directly from a URL
|
||||
for pair in $(echo $COMPILE_USE_PADDLE_WHL_URL_MAPPINGS | tr ';' ' '); do
|
||||
branch=$(echo "$pair" | cut -d',' -f1)
|
||||
paddle_whl_url=$(echo "$pair" | cut -d',' -f2)
|
||||
if [[ "$branch" == "${{ github.ref_name }}" ]]; then
|
||||
FOUND_PADDLE_URL="$paddle_whl_url"
|
||||
echo "compile_use_paddle_whl_url=${FOUND_PADDLE_URL}" >> $GITHUB_OUTPUT
|
||||
break
|
||||
fi
|
||||
done
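The parsing above implies a simple format for the three repository variables: COMPILE_BRANCH is a comma-separated branch list, while CE_COMPILE_SELECTION and COMPILE_USE_PADDLE_WHL_URL_MAPPINGS are semicolon-separated pairs whose first field is a branch name. Illustrative values only (not the real settings):

```bash
# Hypothetical examples that the loops above would accept.
COMPILE_BRANCH="develop,release/2.1"
CE_COMPILE_SELECTION="develop,sm8090+sm8689;release/2.1,sm8090"
COMPILE_USE_PADDLE_WHL_URL_MAPPINGS="release/2.1,https://example.com/paddlepaddle_gpu-nightly.whl"
```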
|
||||
|
||||
print_ce_job_pre_check_outputs:
|
||||
runs-on: ubuntu-latest
|
||||
needs: ce_job_pre_check
|
||||
steps:
|
||||
- name: Print outputs as JSON
|
||||
run: |
|
||||
echo '${{ toJSON(needs.ce_job_pre_check.outputs) }}'
|
||||
|
||||
|
||||
clone:
|
||||
environment: CodeSync
|
||||
name: FD-Clone-Linux
|
||||
runs-on: ubuntu-latest
|
||||
needs: ce_job_pre_check
|
||||
if: ${{ needs.ce_job_pre_check.outputs.branch_match == 'true' }}
|
||||
outputs:
|
||||
repo_archive_url: ${{ steps.set_output.outputs.repo_archive_url }}
|
||||
steps:
|
||||
- name: Clone FastDeploy
|
||||
uses: actions/checkout@v4
|
||||
with:
|
||||
ref: ${{ github.event_name == 'pull_request'
|
||||
&& github.event.pull_request.base.ref
|
||||
|| github.ref_name }}
|
||||
submodules: 'recursive'
|
||||
fetch-depth: 1000
|
||||
|
||||
- name: Python Setup
|
||||
uses: actions/setup-python@v5
|
||||
with:
|
||||
python-version: '3.10'
|
||||
- name: Code Info Show and Upload
|
||||
id: set_output
|
||||
env:
|
||||
AK: ${{ secrets.BOS_AK }}
|
||||
SK: ${{ secrets.BOS_SK }}
|
||||
run: |
|
||||
git config --unset http.https://github.com/.extraheader
|
||||
git submodule foreach --recursive sh -c "git config --local --unset-all 'http.https://github.com/.extraheader'"
|
||||
git submodule foreach --recursive sh -c "git config remote.origin.fetch '+refs/heads/*:refs/remotes/origin/*'"
|
||||
echo "Current HEAD Log:"
|
||||
git log --oneline -n 5
|
||||
ls
|
||||
cd ..
|
||||
tar -zcf FastDeploy.tar.gz FastDeploy
|
||||
commit_id=${{ github.sha }}
|
||||
branch_name=${{ github.ref_name }}
|
||||
target_path=paddle-qa/BRANCH/FastDeploy/${branch_name}/${commit_id}
|
||||
wget -q --no-proxy --no-check-certificate https://paddle-qa.bj.bcebos.com/CodeSync/develop/PaddlePaddle/PaddleTest/tools/bos_tools.py
|
||||
push_file=$(realpath bos_tools.py)
|
||||
python -m pip install bce-python-sdk==0.9.29
|
||||
ls
|
||||
python ${push_file} FastDeploy.tar.gz ${target_path}
|
||||
target_path_stripped="${target_path#paddle-qa/}"
|
||||
REPO_ARCHIVE_URL=https://paddle-qa.bj.bcebos.com/${target_path_stripped}/FastDeploy.tar.gz
|
||||
echo "repo_archive_url=${REPO_ARCHIVE_URL}" >> $GITHUB_OUTPUT
|
||||
|
||||
resultshow:
|
||||
name: Show Code Archive Output
|
||||
needs: clone
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- name: Print code archive path
|
||||
run: |
|
||||
echo "The code archive is located at: ${{ needs.clone.outputs.repo_archive_url }}"
|
||||
|
||||
build_sm8090:
|
||||
name: BUILD_SM8090
|
||||
needs: [clone, ce_job_pre_check]
|
||||
if: ${{ needs.ce_job_pre_check.outputs.sm8090_match == 'true' }}
|
||||
uses: ./.github/workflows/_build_linux.yml
|
||||
with:
|
||||
DOCKER_IMAGE: ccr-2vdh3abv-pub.cnc.bj.baidubce.com/paddlepaddle/paddleqa:fastdeploy-ciuse-cuda126-dailyupdate
|
||||
FASTDEPLOY_ARCHIVE_URL: ${{ needs.clone.outputs.repo_archive_url }}
|
||||
COMPILE_ARCH: "80,90"
|
||||
WITH_NIGHTLY_BUILD: OFF
|
||||
FD_VERSION: 0.0.0
|
||||
PADDLE_WHL_URL: ${{ needs.ce_job_pre_check.outputs.compile_use_paddle_whl_url }}
|
||||
|
||||
build_sm8689:
|
||||
name: BUILD_SM8689
|
||||
needs: [clone, ce_job_pre_check]
|
||||
if: ${{ needs.ce_job_pre_check.outputs.sm8689_match == 'true' }}
|
||||
uses: ./.github/workflows/_build_linux.yml
|
||||
with:
|
||||
DOCKER_IMAGE: ccr-2vdh3abv-pub.cnc.bj.baidubce.com/paddlepaddle/paddleqa:fastdeploy-ciuse-cuda126-dailyupdate
|
||||
FASTDEPLOY_ARCHIVE_URL: ${{ needs.clone.outputs.repo_archive_url }}
|
||||
COMPILE_ARCH: "86,89"
|
||||
WITH_NIGHTLY_BUILD: OFF
|
||||
FD_VERSION: 0.0.0
|
||||
PADDLE_WHL_URL: ${{ needs.ce_job_pre_check.outputs.compile_use_paddle_whl_url }}
|
||||
|
||||
ce_upload_sm8090:
|
||||
environment: CodeSync
|
||||
name: CE_UPLOAD
|
||||
needs: build_sm8090
|
||||
runs-on: ubuntu-latest
|
||||
env:
|
||||
AK: ${{ secrets.BOS_AK }}
|
||||
SK: ${{ secrets.BOS_SK }}
|
||||
FASTDEPLOY_WHEEL_URL: ${{ needs.build_sm8090.outputs.wheel_path }}
|
||||
COMPILE_ARCH: "80,90"
|
||||
steps:
|
||||
- uses: actions/setup-python@v5
|
||||
with:
|
||||
python-version: '3.10'
|
||||
- name: Wheel Info Show and Upload
|
||||
run: |
|
||||
echo "The wheel is located at: ${{ needs.build_sm8090.outputs.wheel_path }}"
|
||||
wget -q --no-check-certificate ${{ needs.build_sm8090.outputs.wheel_path }}
|
||||
filename=$(basename ${{ needs.build_sm8090.outputs.wheel_path }})
|
||||
|
||||
commit_id=${{ github.sha }}
|
||||
branch_name=${{ github.ref_name }}
|
||||
target_path=paddle-qa/paddle-pipeline/FastDeploy_ActionCE/SM${COMPILE_ARCH//,/_}/${branch_name}/${commit_id}
|
||||
|
||||
wget -q --no-proxy --no-check-certificate https://paddle-qa.bj.bcebos.com/CodeSync/develop/PaddlePaddle/PaddleTest/tools/bos_tools.py
|
||||
push_file=$(realpath bos_tools.py)
|
||||
python -m pip install bce-python-sdk==0.9.29
|
||||
ls
|
||||
python ${push_file} ${filename} ${target_path}
|
||||
target_path_stripped="${target_path#paddle-qa/}"
|
||||
WHEEL_PATH=https://paddle-qa.bj.bcebos.com/${target_path_stripped}/${filename}
|
||||
|
||||
target_path_latest=paddle-qa/paddle-pipeline/FastDeploy_ActionCE/SM${COMPILE_ARCH//,/_}/${branch_name}/latest
|
||||
python ${push_file} ${filename} ${target_path_latest}
|
||||
target_path_stripped_latest="${target_path_latest#paddle-qa/}"
|
||||
WHEEL_PATH_LATEST=https://paddle-qa.bj.bcebos.com/${target_path_stripped_latest}/${filename}
|
||||
echo "commit wheel url is ${WHEEL_PATH}"
|
||||
echo "latest wheel url is ${WHEEL_PATH_LATEST}"
|
||||
|
||||
ce_upload_sm8689:
|
||||
environment: CodeSync
|
||||
name: CE_UPLOAD
|
||||
needs: build_sm8689
|
||||
runs-on: ubuntu-latest
|
||||
env:
|
||||
AK: ${{ secrets.BOS_AK }}
|
||||
SK: ${{ secrets.BOS_SK }}
|
||||
FASTDEPLOY_WHEEL_URL: ${{ needs.build_sm8689.outputs.wheel_path }}
|
||||
COMPILE_ARCH: "86,89"
|
||||
steps:
|
||||
- uses: actions/setup-python@v5
|
||||
with:
|
||||
python-version: '3.10'
|
||||
- name: Wheel Info Show and Upload
|
||||
run: |
|
||||
echo "The wheel is located at: ${{ needs.build_sm8689.outputs.wheel_path }}"
|
||||
wget -q --no-check-certificate ${{ needs.build_sm8689.outputs.wheel_path }}
|
||||
filename=$(basename ${{ needs.build_sm8689.outputs.wheel_path }})
|
||||
|
||||
commit_id=${{ github.sha }}
|
||||
branch_name=${{ github.ref_name }}
|
||||
target_path=paddle-qa/paddle-pipeline/FastDeploy_ActionCE/SM${COMPILE_ARCH//,/_}/${branch_name}/${commit_id}
|
||||
|
||||
wget -q --no-proxy --no-check-certificate https://paddle-qa.bj.bcebos.com/CodeSync/develop/PaddlePaddle/PaddleTest/tools/bos_tools.py
|
||||
push_file=$(realpath bos_tools.py)
|
||||
python -m pip install bce-python-sdk==0.9.29
|
||||
ls
|
||||
python ${push_file} ${filename} ${target_path}
|
||||
target_path_stripped="${target_path#paddle-qa/}"
|
||||
WHEEL_PATH=https://paddle-qa.bj.bcebos.com/${target_path_stripped}/${filename}
|
||||
|
||||
target_path_latest=paddle-qa/paddle-pipeline/FastDeploy_ActionCE/SM${COMPILE_ARCH//,/_}/${branch_name}/latest
|
||||
python ${push_file} ${filename} ${target_path_latest}
|
||||
target_path_stripped_latest="${target_path_latest#paddle-qa/}"
|
||||
WHEEL_PATH_LATEST=https://paddle-qa.bj.bcebos.com/${target_path_stripped_latest}/${filename}
|
||||
echo "commit wheel url is ${WHEEL_PATH}"
|
||||
echo "latest wheel url is ${WHEEL_PATH_LATEST}"
|
.github/workflows/check-bypass.yml (vendored, 51 lines changed)
@@ -1,51 +0,0 @@
|
||||
on:
|
||||
workflow_call:
|
||||
inputs:
|
||||
workflow-name:
|
||||
required: true
|
||||
type: string
|
||||
secrets:
|
||||
github-token:
|
||||
required: true
|
||||
outputs:
|
||||
can-skip:
|
||||
description: "Whether the workflow can be skipped."
|
||||
value: ${{ jobs.check-bypass.outputs.can-skip }}
|
||||
|
||||
jobs:
|
||||
check-bypass:
|
||||
name: Check bypass
|
||||
runs-on: ubuntu-latest
|
||||
permissions:
|
||||
contents: read
|
||||
env:
|
||||
CI_TEAM_MEMBERS: '["yuanlehome","YuanRisheng","Jiang-Jia-Jun","DDDivano","XieYunshen"]'
|
||||
outputs:
|
||||
can-skip: ${{ steps.check-bypass.outputs.can-skip }}
|
||||
steps:
|
||||
- name: Cleanup
|
||||
run: |
|
||||
rm -rf * .[^.]*
|
||||
|
||||
- id: check-bypass
|
||||
name: Check Bypass
|
||||
uses: PFCCLab/ci-bypass@v1
|
||||
with:
|
||||
github-token: ${{ secrets.github-token }}
|
||||
non-pull-request-event-strategy: 'never-skipped'
|
||||
type: 'composite'
|
||||
composite-rule: |
|
||||
{
|
||||
"any": [
|
||||
{
|
||||
"type": "labeled",
|
||||
"label": ["skip-ci: ${{ inputs.workflow-name }}", "skip-ci: all"],
|
||||
"username": ${{ env.CI_TEAM_MEMBERS }}
|
||||
},
|
||||
{
|
||||
"type": "commented",
|
||||
"comment-pattern": [".*/skip-ci ${{ inputs.workflow-name }}.*", ".*/skip-ci all.*"],
|
||||
"username": ${{ env.CI_TEAM_MEMBERS }}
|
||||
}
|
||||
]
|
||||
}
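In practice, a user listed in CI_TEAM_MEMBERS can therefore bypass a calling workflow either by adding a "skip-ci: &lt;workflow-name&gt;" or "skip-ci: all" label to the pull request, or by commenting "/skip-ci &lt;workflow-name&gt;" (or "/skip-ci all") on it; the workflow name is whatever the caller passes as the workflow-name input.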
|
.github/workflows/ci.yml (vendored, new file, 109 lines)
@@ -0,0 +1,109 @@
|
||||
name: CI
|
||||
|
||||
on:
|
||||
pull_request:
|
||||
branches:
|
||||
- develop
|
||||
- 'release/*'
|
||||
workflow_dispatch:
|
||||
|
||||
concurrency:
|
||||
group: ${{ github.event.pull_request.number }}
|
||||
cancel-in-progress: true
|
||||
|
||||
jobs:
|
||||
build:
|
||||
runs-on: [self-hosted, GPU-L20-4Card]
|
||||
steps:
|
||||
- name: Print current runner name
|
||||
run: |
|
||||
echo "Current runner name: ${{ runner.name }}"
|
||||
# actions/checkout cannot be used here because the system git version is lower than 2.23.
|
||||
# - name: Checkout code
|
||||
# uses: actions/checkout@v4
|
||||
|
||||
- name: Code Checkout
|
||||
env:
|
||||
docker_image: ccr-2vdh3abv-pub.cnc.bj.baidubce.com/paddlepaddle/paddle:fastdeploy-ciuse-cuda126
|
||||
run: |
|
||||
REPO="https://github.com/${{ github.repository }}.git"
|
||||
FULL_REPO="${{ github.repository }}"
|
||||
REPO_NAME="${FULL_REPO##*/}"
|
||||
BASE_BRANCH="${{ github.base_ref }}"
|
||||
# Clean the repository directory before starting
|
||||
docker run --rm --net=host -v $(pwd):/workspace -w /workspace \
|
||||
-e "REPO_NAME=${REPO_NAME}" \
|
||||
-e "BASE_BRANCH=${BASE_BRANCH}" \
|
||||
${docker_image} /bin/bash -c '
|
||||
if [ -d ${REPO_NAME} ]; then
|
||||
echo "Directory ${REPO_NAME} exists, removing it..."
|
||||
rm -rf ${REPO_NAME}
|
||||
fi
|
||||
'
|
||||
git config --global user.name "FastDeployCI"
|
||||
git config --global user.email "fastdeploy_ci@example.com"
|
||||
git clone ${REPO} ${REPO_NAME} -b ${BASE_BRANCH}
|
||||
cd FastDeploy
|
||||
if [ "${{ github.event_name }}" = "pull_request" ]; then
|
||||
git fetch origin pull/${{ github.event.pull_request.number }}/head:pr/${{ github.event.pull_request.number }}
|
||||
git merge pr/${{ github.event.pull_request.number }}
|
||||
git log -n 3 --oneline
|
||||
else
|
||||
git checkout ${{ github.sha }}
|
||||
git log -n 3 --oneline
|
||||
fi
|
||||
|
||||
- name: Run CI unittest
|
||||
env:
|
||||
docker_image: ccr-2vdh3abv-pub.cnc.bj.baidubce.com/paddlepaddle/paddle:fastdeploy-ciuse-cuda126
|
||||
run: |
|
||||
runner_name="${{ runner.name }}"
|
||||
last_char="${runner_name: -1}"
|
||||
|
||||
if [ "${last_char}" = "1" ]; then
|
||||
gpu_id=2
|
||||
DEVICES="2,3"
|
||||
else
|
||||
gpu_id=0
|
||||
DEVICES="0,1"
|
||||
fi
|
||||
|
||||
FLASK_PORT=$((41068 + gpu_id * 100))
|
||||
FD_API_PORT=$((41088 + gpu_id * 100))
|
||||
FD_ENGINE_QUEUE_PORT=$((41058 + gpu_id * 100))
|
||||
FD_METRICS_PORT=$((41078 + gpu_id * 100))
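# Worked example of the port arithmetic above (gpu_id is 2 when the runner name ends in "1", else 0):
#   gpu_id=0 -> FLASK_PORT=41068, FD_API_PORT=41088, FD_ENGINE_QUEUE_PORT=41058, FD_METRICS_PORT=41078
#   gpu_id=2 -> FLASK_PORT=41268, FD_API_PORT=41288, FD_ENGINE_QUEUE_PORT=41258, FD_METRICS_PORT=41278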
|
||||
|
||||
PORTS=($FLASK_PORT $FD_API_PORT $FD_ENGINE_QUEUE_PORT $FD_METRICS_PORT)
|
||||
LOG_FILE="./port_cleanup_$(date +%Y%m%d_%H%M%S).log"
|
||||
echo "==== LOG_FILE is ${LOG_FILE} ===="
|
||||
|
||||
echo "==== PORT CLEAN BEFORE TASK RUN ====" | tee -a $LOG_FILE
|
||||
|
||||
for port in "${PORTS[@]}"; do
|
||||
PIDS=$(lsof -t -i :$port || true)
|
||||
if [ -n "$PIDS" ]; then
|
||||
echo "Port $port is occupied by PID(s): $PIDS" | tee -a $LOG_FILE
|
||||
echo "$PIDS" | xargs -r kill -9
|
||||
echo "Port $port cleared" | tee -a $LOG_FILE
|
||||
else
|
||||
echo "Port $port is free" | tee -a $LOG_FILE
|
||||
fi
|
||||
done
|
||||
echo "==== PORT CLEAN COMPLETE ====" | tee -a $LOG_FILE
|
||||
|
||||
PARENT_DIR=$(dirname "$WORKSPACE")
|
||||
echo "PARENT_DIR:$PARENT_DIR"
|
||||
docker run --rm --net=host -v $(pwd):/workspace -w /workspace \
|
||||
-v "/ssd4/GithubActions/gitconfig:/etc/gitconfig:ro" \
|
||||
-v "/ssd4/GithubActions/ModelData:/ModelData:ro" \
|
||||
-v "/ssd4/GithubActions/CacheDir:/root/.cache" \
|
||||
-v "/ssd4/GithubActions/ConfigDir:/root/.config" \
|
||||
-e "MODEL_PATH=/ModelData" \
|
||||
-e "FD_API_PORT=${FD_API_PORT}" \
|
||||
-e "FD_ENGINE_QUEUE_PORT=${FD_ENGINE_QUEUE_PORT}" \
|
||||
-e "FD_METRICS_PORT=${FD_METRICS_PORT}" \
|
||||
--gpus '"device='"${DEVICES}"'"' ${docker_image} /bin/bash -c "
|
||||
git config --global --add safe.directory /workspace/FastDeploy
|
||||
cd FastDeploy
|
||||
bash scripts/run_ci.sh
|
||||
"
|
.github/workflows/ci_gcu.yml (vendored, 98 lines changed)
@@ -1,98 +0,0 @@
|
||||
name: CI_GCU
|
||||
|
||||
on:
|
||||
pull_request:
|
||||
branches:
|
||||
- develop
|
||||
- 'release/*'
|
||||
workflow_dispatch:
|
||||
|
||||
concurrency:
|
||||
group: ${{ github.event.pull_request.number }}-gcu-ci
|
||||
cancel-in-progress: true
|
||||
|
||||
jobs:
|
||||
CI_GCU:
|
||||
runs-on:
|
||||
group: GCU
|
||||
steps:
|
||||
- name: Print current runner name
|
||||
run: |
|
||||
echo "Current runner name: ${{ runner.name }}"
|
||||
|
||||
- name: Code Checkout
|
||||
env:
|
||||
docker_image: ccr-2vdh3abv-pub.cnc.bj.baidubce.com/device/paddle-gcu:topsrider3.5.102-ubuntu20-x86_64-gcc84
|
||||
run: |
|
||||
REPO="https://github.com/${{ github.repository }}.git"
|
||||
FULL_REPO="${{ github.repository }}"
|
||||
REPO_NAME="${FULL_REPO##*/}"
|
||||
BASE_BRANCH="${{ github.base_ref }}"
|
||||
# Clean the repository directory before starting
|
||||
docker run --rm --net=host -v $(pwd):/workspace \
|
||||
-v ${{ github.workspace }}/../../..:${{ github.workspace }}/../../.. \
|
||||
-w /workspace \
|
||||
-e "REPO_NAME=${REPO_NAME}" \
|
||||
-e "BASE_BRANCH=${BASE_BRANCH}" \
|
||||
${docker_image} /bin/bash -c '
|
||||
if [ -d ${REPO_NAME} ]; then
|
||||
echo "Directory ${REPO_NAME} exists, removing it..."
|
||||
rm -rf ${REPO_NAME}
|
||||
fi
|
||||
'
|
||||
git config --global user.name "FastDeployCI"
|
||||
git config --global user.email "fastdeploy_ci@example.com"
|
||||
source ${{ github.workspace }}/../../../proxy
|
||||
git clone ${REPO} ${REPO_NAME} -b ${BASE_BRANCH}
|
||||
cd FastDeploy
|
||||
if [ "${{ github.event_name }}" = "pull_request" ]; then
|
||||
git fetch origin pull/${{ github.event.pull_request.number }}/head:pr/${{ github.event.pull_request.number }}
|
||||
git merge pr/${{ github.event.pull_request.number }}
|
||||
git log -n 3 --oneline
|
||||
else
|
||||
git checkout ${{ github.sha }}
|
||||
git log -n 3 --oneline
|
||||
fi
|
||||
echo "Copy models..."
|
||||
sudo mkdir -p ci_models && sudo cp -r /work/deps/ERNIE-4.5-21B-A3B-Paddle ci_models
|
||||
echo "Copy models done."
|
||||
|
||||
- name: Run CI unittest
|
||||
env:
|
||||
docker_image: ccr-2vdh3abv-pub.cnc.bj.baidubce.com/device/paddle-gcu:topsrider3.5.102-ubuntu20-x86_64-gcc84
|
||||
run: |
|
||||
runner_name="${{ runner.name }}"
|
||||
last_char="${runner_name: -1}"
|
||||
|
||||
if [[ "$last_char" =~ [0-3] ]]; then
|
||||
gcu_id="$last_char"
|
||||
else
|
||||
gcu_id="0"
|
||||
fi
|
||||
FD_API_PORT=$((9180 + gcu_id * 100))
|
||||
FD_ENGINE_QUEUE_PORT=$((9150 + gcu_id * 100))
|
||||
FD_METRICS_PORT=$((9170 + gcu_id * 100))
|
||||
|
||||
PARENT_DIR=$(dirname "$WORKSPACE")
|
||||
echo "PARENT_DIR:$PARENT_DIR"
|
||||
echo "Install drivers..."
|
||||
cd /work/deps
|
||||
sudo bash TopsRider_i3x_*_deb_amd64.run --driver --no-auto-load -y
|
||||
cd -
|
||||
echo "Create docker..."
|
||||
docker run --rm --network=host --ipc=host --privileged \
|
||||
-v $(pwd):/workspace \
|
||||
-v /home:/home \
|
||||
-v /work:/work \
|
||||
-w /workspace \
|
||||
-e "MODEL_PATH=./ci_models" \
|
||||
-e "http_proxy=$(git config --global --get http.proxy)" \
|
||||
-e "https_proxy=$(git config --global --get https.proxy)" \
|
||||
-e "FD_API_PORT=${FD_API_PORT}" \
|
||||
-e "FD_ENGINE_QUEUE_PORT=${FD_ENGINE_QUEUE_PORT}" \
|
||||
-e "FD_METRICS_PORT=${FD_METRICS_PORT}" \
|
||||
${docker_image} /bin/bash -c "
|
||||
git config --global --add safe.directory /workspace/FastDeploy
|
||||
cd FastDeploy
|
||||
bash scripts/run_ci_gcu.sh
|
||||
"
|
.github/workflows/ci_iluvatar.yml (vendored, 3 lines changed)
@@ -11,8 +11,7 @@ concurrency:
|
||||
|
||||
jobs:
|
||||
CI_ILUVATAR:
|
||||
runs-on:
|
||||
group: IXUCA
|
||||
runs-on: [self-hosted, IXUCA]
|
||||
steps:
|
||||
- name: Print current runner name
|
||||
run: |
|
||||
|
.github/workflows/ci_image_update.yml (vendored, 174 lines changed)
@@ -1,174 +0,0 @@
|
||||
name: CI Images Build
|
||||
|
||||
on:
|
||||
workflow_dispatch:
|
||||
schedule:
|
||||
- cron: '0 18 * * *' # 2:00 AM China Standard Time (UTC+8)
|
||||
|
||||
permissions: read-all
|
||||
|
||||
concurrency:
|
||||
group: CI-Images-Build-${{ github.ref }}-${{ github.sha }}
|
||||
cancel-in-progress: true
|
||||
|
||||
|
||||
jobs:
|
||||
clone:
|
||||
environment: CodeSync
|
||||
name: FD-Clone-Linux
|
||||
runs-on: ubuntu-latest
|
||||
outputs:
|
||||
repo_archive_url: ${{ steps.set_output.outputs.repo_archive_url }}
|
||||
steps:
|
||||
- name: Clone FastDeploy
|
||||
uses: actions/checkout@v4
|
||||
with:
|
||||
ref: ${{ github.ref_name }}
|
||||
submodules: 'recursive'
|
||||
fetch-depth: 1000
|
||||
|
||||
- name: Python Setup
|
||||
uses: actions/setup-python@v5
|
||||
with:
|
||||
python-version: '3.10'
|
||||
- name: Code Info Show and Upload
|
||||
id: set_output
|
||||
env:
|
||||
AK: ${{ secrets.BOS_AK }}
|
||||
SK: ${{ secrets.BOS_SK }}
|
||||
run: |
|
||||
git config --unset http.https://github.com/.extraheader
|
||||
git submodule foreach --recursive sh -c "git config --local --unset-all 'http.https://github.com/.extraheader'"
|
||||
git submodule foreach --recursive sh -c "git config remote.origin.fetch '+refs/heads/*:refs/remotes/origin/*'"
|
||||
echo "Current HEAD Log:"
|
||||
git log --oneline -n 5
|
||||
ls
|
||||
cd ..
|
||||
tar -zcf FastDeploy.tar.gz FastDeploy
|
||||
if [[ "${{ github.ref_type }}" == "tag" ]]; then
|
||||
commit_id=${{ github.sha }}
|
||||
tag_name=${{ github.ref_name }}
|
||||
target_path=paddle-qa/TAG/FastDeploy/${tag_name}/${commit_id}
|
||||
else
|
||||
commit_id=${{ github.sha }}
|
||||
branch_name=${{ github.ref_name }}
|
||||
target_path=paddle-qa/BRANCH/FastDeploy/${branch_name}/${commit_id}
|
||||
fi
|
||||
wget -q --no-proxy --no-check-certificate https://paddle-qa.bj.bcebos.com/CodeSync/develop/PaddlePaddle/PaddleTest/tools/bos_tools.py
|
||||
push_file=$(realpath bos_tools.py)
|
||||
python -m pip install bce-python-sdk==0.9.29
|
||||
ls
|
||||
python ${push_file} FastDeploy.tar.gz ${target_path}
|
||||
target_path_stripped="${target_path#paddle-qa/}"
|
||||
REPO_ARCHIVE_URL=https://paddle-qa.bj.bcebos.com/${target_path_stripped}/FastDeploy.tar.gz
|
||||
echo "repo_archive_url=${REPO_ARCHIVE_URL}" >> $GITHUB_OUTPUT
|
||||
|
||||
resultshow:
|
||||
name: Show Code Archive Output
|
||||
needs: clone
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- name: Print code archive path
|
||||
run: |
|
||||
echo "The code archive is located at: ${{ needs.clone.outputs.repo_archive_url }}"
|
||||
|
||||
ci_image_build:
|
||||
name: CI Images Build
|
||||
needs: clone
|
||||
uses: ./.github/workflows/_ci_image_build.yml
|
||||
with:
|
||||
CI_DOCKER_IMAGE_NAME: ccr-2vdh3abv-pub.cnc.bj.baidubce.com/paddlepaddle/paddleqa:fastdeploy-ciuse-cuda126-dailyupdate-precheck
|
||||
FASTDEPLOY_ARCHIVE_URL: ${{ needs.clone.outputs.repo_archive_url }}
|
||||
|
||||
|
||||
build_sm8090:
|
||||
name: BUILD_SM8090
|
||||
needs: [clone, ci_image_build]
|
||||
uses: ./.github/workflows/_build_linux.yml
|
||||
with:
|
||||
DOCKER_IMAGE: ${{ needs.ci_image_build.outputs.docker_name_precheck }}
|
||||
FASTDEPLOY_ARCHIVE_URL: ${{ needs.clone.outputs.repo_archive_url }}
|
||||
COMPILE_ARCH: "90"
|
||||
WITH_NIGHTLY_BUILD: ${{ needs.publish_pre_check.outputs.with_nightly_build }}
|
||||
FD_VERSION: ${{ needs.publish_pre_check.outputs.fd_version }}
|
||||
PADDLEVERSION: ${{ needs.publish_pre_check.outputs.compile_use_paddle_version }}
|
||||
PADDLE_WHL_URL: ${{ needs.publish_pre_check.outputs.compile_use_paddle_whl_url }}
|
||||
|
||||
|
||||
unittest_coverage:
|
||||
name: Run FastDeploy Unit Tests and Coverage
|
||||
needs: [clone,build_sm8090,ci_image_build]
|
||||
uses: ./.github/workflows/_unit_test_coverage.yml
|
||||
with:
|
||||
DOCKER_IMAGE: ${{ needs.ci_image_build.outputs.docker_name_precheck }}
|
||||
FASTDEPLOY_ARCHIVE_URL: ${{ needs.clone.outputs.repo_archive_url }}
|
||||
FASTDEPLOY_WHEEL_URL: ${{ needs.build_sm8090.outputs.wheel_path }}
|
||||
MODEL_CACHE_DIR: "/ssd2/actions-runner/ModelData"
|
||||
secrets:
|
||||
github-token: ${{ secrets.GITHUB_TOKEN }}
|
||||
|
||||
logprob_test:
|
||||
name: Run FastDeploy LogProb Tests
|
||||
needs: [build_sm8090,ci_image_build]
|
||||
uses: ./.github/workflows/_logprob_test_linux.yml
|
||||
with:
|
||||
DOCKER_IMAGE: ${{ needs.ci_image_build.outputs.docker_name_precheck }}
|
||||
PADDLETEST_ARCHIVE_URL: "https://xly-devops.bj.bcebos.com/PaddleTest/PaddleTest.tar.gz"
|
||||
FASTDEPLOY_WHEEL_URL: ${{ needs.build_sm8090.outputs.wheel_path }}
|
||||
MODEL_CACHE_DIR: "/ssd2/actions-runner/ModelData"
|
||||
|
||||
pre_ce_test:
|
||||
name: Run a subset of CE model tasks in CI
|
||||
needs: [clone,build_sm8090,ci_image_build]
|
||||
uses: ./.github/workflows/_pre_ce_test.yml
|
||||
with:
|
||||
DOCKER_IMAGE: ${{ needs.ci_image_build.outputs.docker_name_precheck }}
|
||||
FASTDEPLOY_ARCHIVE_URL: ${{ needs.clone.outputs.repo_archive_url }}
|
||||
FASTDEPLOY_WHEEL_URL: ${{ needs.build_sm8090.outputs.wheel_path }}
|
||||
MODEL_CACHE_DIR: "/ssd2/actions-runner/ModelData"
|
||||
|
||||
base_test:
|
||||
name: Run Base Tests
|
||||
needs: [clone,build_sm8090,ci_image_build]
|
||||
uses: ./.github/workflows/_base_test.yml
|
||||
with:
|
||||
DOCKER_IMAGE: ${{ needs.ci_image_build.outputs.docker_name_precheck }}
|
||||
FASTDEPLOY_ARCHIVE_URL: ${{ needs.clone.outputs.repo_archive_url }}
|
||||
FASTDEPLOY_WHEEL_URL: ${{ needs.build_sm8090.outputs.wheel_path }}
|
||||
MODEL_CACHE_DIR: "/ssd2/actions-runner/ModelData"
|
||||
|
||||
accuracy_test:
|
||||
name: Run Accuracy Tests
|
||||
needs: [clone,build_sm8090,ci_image_build]
|
||||
uses: ./.github/workflows/_accuracy_test.yml
|
||||
with:
|
||||
DOCKER_IMAGE: ${{ needs.ci_image_build.outputs.docker_name_precheck }}
|
||||
FASTDEPLOY_ARCHIVE_URL: ${{ needs.clone.outputs.repo_archive_url }}
|
||||
FASTDEPLOY_WHEEL_URL: ${{ needs.build_sm8090.outputs.wheel_path }}
|
||||
MODEL_CACHE_DIR: "/ssd2/actions-runner/ModelData"
|
||||
|
||||
stable_test:
|
||||
name: Run Stable Tests
|
||||
needs: [clone,build_sm8090,ci_image_build]
|
||||
uses: ./.github/workflows/_stable_test.yml
|
||||
with:
|
||||
DOCKER_IMAGE: ${{ needs.ci_image_build.outputs.docker_name_precheck }}
|
||||
FASTDEPLOY_ARCHIVE_URL: ${{ needs.clone.outputs.repo_archive_url }}
|
||||
FASTDEPLOY_WHEEL_URL: ${{ needs.build_sm8090.outputs.wheel_path }}
|
||||
MODEL_CACHE_DIR: "/ssd2/actions-runner/ModelData"
|
||||
|
||||
|
||||
publish_pre_check:
|
||||
name: Publish Docker Images Pre Check
|
||||
needs: [ci_image_build, unittest_coverage,logprob_test,pre_ce_test,base_test,accuracy_test,stable_test]
|
||||
runs-on: [self-hosted, Docker-Build]
|
||||
steps:
|
||||
- name: Images Uploading
|
||||
env:
|
||||
images_name: ${{ needs.ci_image_build.outputs.docker_name_precheck }}
|
||||
ci_image_name: "ccr-2vdh3abv-pub.cnc.bj.baidubce.com/paddlepaddle/paddleqa:fastdeploy-ciuse-cuda126-dailyupdate"
|
||||
run: |
|
||||
echo "images_name=${images_name}"
|
||||
docker images ${ci_image_name}
|
||||
docker tag ${images_name} ${ci_image_name}
|
||||
docker push ${ci_image_name}
|
.github/workflows/ci_xpu.yml (vendored, 5 lines changed)
@@ -24,7 +24,7 @@ jobs:
|
||||
|
||||
- name: Code Checkout
|
||||
env:
|
||||
docker_image: ccr-2vdh3abv-pub.cnc.bj.baidubce.com/paddlepaddle/fastdeploy-xpu:2.1.0
|
||||
docker_image: ccr-2vdh3abv-pub.cnc.bj.baidubce.com/paddlepaddle/fastdeploy-xpu:2.0.0
|
||||
run: |
|
||||
REPO="https://github.com/${{ github.repository }}.git"
|
||||
FULL_REPO="${{ github.repository }}"
|
||||
@@ -55,7 +55,7 @@ jobs:
|
||||
|
||||
- name: Run CI unittest
|
||||
env:
|
||||
docker_image: ccr-2vdh3abv-pub.cnc.bj.baidubce.com/paddlepaddle/fastdeploy-xpu:2.1.0
|
||||
docker_image: ccr-2vdh3abv-pub.cnc.bj.baidubce.com/paddlepaddle/fastdeploy-xpu:2.0.0
|
||||
run: |
|
||||
runner_name="${{ runner.name }}"
|
||||
last_char="${runner_name: -1}"
|
||||
@@ -77,7 +77,6 @@ jobs:
|
||||
-e "MODEL_PATH=/ssd3/model" \
|
||||
-e "http_proxy=$(git config --global --get http.proxy)" \
|
||||
-e "https_proxy=$(git config --global --get https.proxy)" \
|
||||
-e "no_proxy=bcebos.com,mirrors.tuna.tsinghua.edu.cn,127.0.0.1,localhost" \
|
||||
-e "FD_API_PORT=${FD_API_PORT}" \
|
||||
-e "FD_ENGINE_QUEUE_PORT=${FD_ENGINE_QUEUE_PORT}" \
|
||||
-e "FD_METRICS_PORT=${FD_METRICS_PORT}" \
|
||||
|
.github/workflows/gh-pages.yml (vendored, 2 lines changed)
@@ -15,7 +15,7 @@ jobs:
|
||||
- uses: actions/setup-python@v5
|
||||
with:
|
||||
python-version: 3.x
|
||||
- run: pip install mkdocs-material mkdocs-get-deps mkdocs-material-extensions mkdocs-multilang mkdocs-static-i18n
|
||||
- run: pip install mkdocs-material mkdocs-get-deps mkdocs-material-extensions mkdocs-multilang
|
||||
- name: Deploy to GitHub Pages
|
||||
env:
|
||||
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
|
||||
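For a local preview of the documentation site this job deploys, the same mkdocs packages can be installed and served locally (a sketch; it assumes a mkdocs.yml at the repository root, which the deploy job implies):

```bash
# After installing the mkdocs packages listed in the job above:
mkdocs serve   # local preview, served at http://127.0.0.1:8000 by default
```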
|
.github/workflows/pr_build_and_test.yml (vendored, 64 lines changed)
@@ -19,7 +19,7 @@ jobs:
|
||||
needs: clone
|
||||
uses: ./.github/workflows/_build_linux.yml
|
||||
with:
|
||||
DOCKER_IMAGE: ccr-2vdh3abv-pub.cnc.bj.baidubce.com/paddlepaddle/paddleqa:fastdeploy-ciuse-cuda126-dailyupdate
|
||||
DOCKER_IMAGE: ccr-2vdh3abv-pub.cnc.bj.baidubce.com/paddlepaddle/paddleqa:cuda126-py310
|
||||
FASTDEPLOY_ARCHIVE_URL: ${{ needs.clone.outputs.repo_archive_url }}
|
||||
COMPILE_ARCH: "90"
|
||||
WITH_NIGHTLY_BUILD: "OFF"
|
||||
@@ -33,65 +33,3 @@ jobs:
|
||||
- name: Print wheel path
|
||||
run: |
|
||||
echo "The built wheel is located at: ${{ needs.build.outputs.wheel_path }}"
|
||||
|
||||
unittest_coverage:
|
||||
name: Run FastDeploy Unit Tests and Coverage
|
||||
needs: [clone,build]
|
||||
uses: ./.github/workflows/_unit_test_coverage.yml
|
||||
with:
|
||||
DOCKER_IMAGE: ccr-2vdh3abv-pub.cnc.bj.baidubce.com/paddlepaddle/paddleqa:fastdeploy-ciuse-cuda126-dailyupdate
|
||||
FASTDEPLOY_ARCHIVE_URL: ${{ needs.clone.outputs.repo_archive_url }}
|
||||
FASTDEPLOY_WHEEL_URL: ${{ needs.build.outputs.wheel_path }}
|
||||
MODEL_CACHE_DIR: "/ssd2/actions-runner/ModelData"
|
||||
secrets:
|
||||
github-token: ${{ secrets.GITHUB_TOKEN }}
|
||||
|
||||
logprob_test:
|
||||
name: Run FastDeploy LogProb Tests
|
||||
needs: [build]
|
||||
uses: ./.github/workflows/_logprob_test_linux.yml
|
||||
with:
|
||||
DOCKER_IMAGE: ccr-2vdh3abv-pub.cnc.bj.baidubce.com/paddlepaddle/paddleqa:fastdeploy-ciuse-cuda126-dailyupdate
|
||||
PADDLETEST_ARCHIVE_URL: "https://xly-devops.bj.bcebos.com/PaddleTest/PaddleTest.tar.gz"
|
||||
FASTDEPLOY_WHEEL_URL: ${{ needs.build.outputs.wheel_path }}
|
||||
MODEL_CACHE_DIR: "/ssd2/actions-runner/ModelData"
|
||||
|
||||
pre_ce_test:
|
||||
name: Run a subset of CE model tasks in CI
|
||||
needs: [clone,build]
|
||||
uses: ./.github/workflows/_pre_ce_test.yml
|
||||
with:
|
||||
DOCKER_IMAGE: ccr-2vdh3abv-pub.cnc.bj.baidubce.com/paddlepaddle/paddleqa:fastdeploy-ciuse-cuda126-dailyupdate
|
||||
FASTDEPLOY_ARCHIVE_URL: ${{ needs.clone.outputs.repo_archive_url }}
|
||||
FASTDEPLOY_WHEEL_URL: ${{ needs.build.outputs.wheel_path }}
|
||||
MODEL_CACHE_DIR: "/ssd2/actions-runner/ModelData"
|
||||
|
||||
base_test:
|
||||
name: Run Base Tests
|
||||
needs: [clone,build]
|
||||
uses: ./.github/workflows/_base_test.yml
|
||||
with:
|
||||
DOCKER_IMAGE: ccr-2vdh3abv-pub.cnc.bj.baidubce.com/paddlepaddle/paddleqa:fastdeploy-ciuse-cuda126-dailyupdate
|
||||
FASTDEPLOY_ARCHIVE_URL: ${{ needs.clone.outputs.repo_archive_url }}
|
||||
FASTDEPLOY_WHEEL_URL: ${{ needs.build.outputs.wheel_path }}
|
||||
MODEL_CACHE_DIR: "/ssd2/actions-runner/ModelData"
|
||||
|
||||
accuracy_test:
|
||||
name: Run Accuracy Tests
|
||||
needs: [clone,build]
|
||||
uses: ./.github/workflows/_accuracy_test.yml
|
||||
with:
|
||||
DOCKER_IMAGE: ccr-2vdh3abv-pub.cnc.bj.baidubce.com/paddlepaddle/paddleqa:fastdeploy-ciuse-cuda126-dailyupdate
|
||||
FASTDEPLOY_ARCHIVE_URL: ${{ needs.clone.outputs.repo_archive_url }}
|
||||
FASTDEPLOY_WHEEL_URL: ${{ needs.build.outputs.wheel_path }}
|
||||
MODEL_CACHE_DIR: "/ssd2/actions-runner/ModelData"
|
||||
|
||||
stable_test:
|
||||
name: Run Stable Tests
|
||||
needs: [clone,build]
|
||||
uses: ./.github/workflows/_stable_test.yml
|
||||
with:
|
||||
DOCKER_IMAGE: ccr-2vdh3abv-pub.cnc.bj.baidubce.com/paddlepaddle/paddleqa:fastdeploy-ciuse-cuda126-dailyupdate
|
||||
FASTDEPLOY_ARCHIVE_URL: ${{ needs.clone.outputs.repo_archive_url }}
|
||||
FASTDEPLOY_WHEEL_URL: ${{ needs.build.outputs.wheel_path }}
|
||||
MODEL_CACHE_DIR: "/ssd2/actions-runner/ModelData"
|
||||
|
.github/workflows/publish_job.yml (vendored, 331 lines changed)
@@ -1,331 +0,0 @@
|
||||
name: Publish Job
|
||||
|
||||
on:
|
||||
workflow_dispatch:
|
||||
schedule:
|
||||
- cron: '0 18 * * *' # 2:00 AM China Standard Time (UTC+8)
|
||||
push:
|
||||
# branches:
|
||||
# - develop
|
||||
tags:
|
||||
- '*'
|
||||
|
||||
permissions: read-all
|
||||
|
||||
concurrency:
|
||||
group: Publish-Job-${{ github.ref }}-${{ github.sha }}
|
||||
cancel-in-progress: true
|
||||
|
||||
|
||||
jobs:
|
||||
publish_pre_check:
|
||||
runs-on: ubuntu-latest
|
||||
if: |
|
||||
github.event.repository.fork == false &&
|
||||
(
|
||||
(github.event_name == 'schedule' && github.ref_name == 'develop') ||
|
||||
(github.event_name == 'push' && github.ref_type == 'tag') ||
|
||||
((github.event_name == 'workflow_dispatch') &&
|
||||
(github.ref_name == 'develop' || github.ref_type == 'tag'))
|
||||
)
|
||||
env:
|
||||
TAG_VERSION_MAPPINGS: ${{ vars.TAG_VERSION_MAPPINGS }}
|
||||
FD_VERSION_DEV: ${{ vars.FD_VERSION_DEV }}
|
||||
COMPILE_USE_PADDLE_WHL_URL_MAPPINGS: ${{ vars.COMPILE_USE_PADDLE_WHL_URL_MAPPINGS }}
|
||||
outputs:
|
||||
compile_use_paddle_version: ${{ steps.set_output.outputs.compile_use_paddle_version }}
|
||||
compile_continue: ${{ steps.set_output.outputs.compile_continue }}
|
||||
fd_version: ${{ steps.set_output.outputs.fd_version }}
|
||||
with_nightly_build: ${{ steps.set_output.outputs.with_nightly_build }}
|
||||
compile_use_paddle_whl_url: ${{ steps.set_output.outputs.compile_use_paddle_whl_url }}
|
||||
|
||||
steps:
|
||||
- name: Get tag version
|
||||
if: github.ref_type == 'tag'
|
||||
run: |
|
||||
TAG_NAME="${GITHUB_REF##*/}" # Extract the tag name, e.g. v2.1.0
|
||||
TAG_VERSION="${TAG_NAME#v}" # Strip the leading "v" prefix
|
||||
echo "FD_VERSION=$TAG_VERSION" >> $GITHUB_ENV
|
||||
|
||||
- name: Check FD version to Paddle version mapping
|
||||
if: github.ref_type == 'tag'
|
||||
env:
|
||||
TARGET_FD: ${{ env.FD_VERSION }}
|
||||
run: |
|
||||
FOUND_PADDLE=""
|
||||
# Iterate over the FD-to-Paddle version mappings
|
||||
for pair in $(echo $TAG_VERSION_MAPPINGS | tr ';' ' '); do
|
||||
fd=$(echo "$pair" | cut -d',' -f1)
|
||||
paddle=$(echo "$pair" | cut -d',' -f2)
|
||||
if [[ "$fd" == "$TARGET_FD" ]]; then
|
||||
FOUND_PADDLE="$paddle"
|
||||
break
|
||||
fi
|
||||
done
|
||||
|
||||
if [[ -z "$FOUND_PADDLE" ]]; then
|
||||
echo "No Paddle version found for FD $TARGET_FD"
|
||||
else
|
||||
echo "FD $TARGET_FD maps to Paddle $FOUND_PADDLE"
|
||||
echo "PADDLE_VERSION=$FOUND_PADDLE" >> $GITHUB_ENV
|
||||
fi
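The loop above expects TAG_VERSION_MAPPINGS to be a semicolon-separated list of "&lt;fd_version&gt;,&lt;paddle_version&gt;" pairs. An illustrative value (versions are placeholders, not the real mapping):

```bash
# Hypothetical mapping; each pair is "<fd_version>,<paddle_version>".
TAG_VERSION_MAPPINGS="2.1.0,3.1.0;2.2.0,3.1.1"
```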
|
||||
- name: Set Version
|
||||
id: set_output
|
||||
env:
|
||||
PADDLE_VERSION: ${{ env.PADDLE_VERSION }}
|
||||
FD_VERSION: ${{ env.FD_VERSION }}
|
||||
run: |
|
||||
if [[ "${{ github.ref_type }}" == "tag" ]]; then
|
||||
if [[ -z "$PADDLE_VERSION" ]]; then
|
||||
compile_continue=false
|
||||
else
|
||||
compile_use_paddle_version=$PADDLE_VERSION
|
||||
compile_continue=true
|
||||
fi
|
||||
fd_version=$FD_VERSION
|
||||
fi
|
||||
if [[ "${{ github.ref_name }}" == "develop" ]];then
|
||||
compile_continue=true
|
||||
compile_use_paddle_version=""
|
||||
fd_version=${FD_VERSION_DEV}
|
||||
with_nightly_build=ON
|
||||
fi
|
||||
# Todo
|
||||
# Use the mapping in COMPILE_USE_PADDLE_WHL_URL_MAPPINGS to decide whether to install a pinned Paddle version or a wheel directly from a URL
|
||||
for pair in $(echo $COMPILE_USE_PADDLE_WHL_URL_MAPPINGS | tr ';' ' '); do
|
||||
branch=$(echo "$pair" | cut -d',' -f1)
|
||||
paddle_whl_url=$(echo "$pair" | cut -d',' -f2)
|
||||
if [[ "$branch" == "${{ github.ref_name }}" ]]; then
|
||||
FOUND_PADDLE_URL="$paddle_whl_url"
|
||||
echo "compile_use_paddle_whl_url=${FOUND_PADDLE_URL}" >> $GITHUB_OUTPUT
|
||||
compile_continue=true
|
||||
break
|
||||
fi
|
||||
done
|
||||
echo "compile_continue=${compile_continue}" >> $GITHUB_OUTPUT
|
||||
echo "compile_use_paddle_version=${compile_use_paddle_version}" >> $GITHUB_OUTPUT
|
||||
echo "fd_version=${fd_version}" >> $GITHUB_OUTPUT
|
||||
echo "with_nightly_build=${with_nightly_build:-OFF}" >> $GITHUB_OUTPUT
|
||||
|
||||
print_publish_pre_check_outputs:
|
||||
runs-on: ubuntu-latest
|
||||
needs: publish_pre_check
|
||||
steps:
|
||||
- name: Print outputs as JSON
|
||||
run: |
|
||||
echo '${{ toJSON(needs.publish_pre_check.outputs) }}'
|
||||
|
||||
clone:
|
||||
environment: CodeSync
|
||||
name: FD-Clone-Linux
|
||||
runs-on: ubuntu-latest
|
||||
needs: publish_pre_check
|
||||
if: ${{ needs.publish_pre_check.outputs.compile_continue == 'true' }}
|
||||
outputs:
|
||||
repo_archive_url: ${{ steps.set_output.outputs.repo_archive_url }}
|
||||
steps:
|
||||
- name: Clone FastDeploy
|
||||
uses: actions/checkout@v4
|
||||
with:
|
||||
ref: ${{ github.ref_name }}
|
||||
submodules: 'recursive'
|
||||
fetch-depth: 1000
|
||||
|
||||
- name: Python Setup
|
||||
uses: actions/setup-python@v5
|
||||
with:
|
||||
python-version: '3.10'
|
||||
- name: Code Info Show and Upload
|
||||
id: set_output
|
||||
env:
|
||||
AK: ${{ secrets.BOS_AK }}
|
||||
SK: ${{ secrets.BOS_SK }}
|
||||
run: |
|
||||
git config --unset http.https://github.com/.extraheader
|
||||
git submodule foreach --recursive sh -c "git config --local --unset-all 'http.https://github.com/.extraheader'"
|
||||
git submodule foreach --recursive sh -c "git config remote.origin.fetch '+refs/heads/*:refs/remotes/origin/*'"
|
||||
echo "Current HEAD Log:"
|
||||
git log --oneline -n 5
|
||||
ls
|
||||
cd ..
|
||||
tar -zcf FastDeploy.tar.gz FastDeploy
|
||||
if [[ "${{ github.ref_type }}" == "tag" ]]; then
|
||||
commit_id=${{ github.sha }}
|
||||
tag_name=${{ github.ref_name }}
|
||||
target_path=paddle-qa/TAG/FastDeploy/${tag_name}/${commit_id}
|
||||
else
|
||||
commit_id=${{ github.sha }}
|
||||
branch_name=${{ github.ref_name }}
|
||||
target_path=paddle-qa/BRANCH/FastDeploy/${branch_name}/${commit_id}
|
||||
fi
|
||||
wget -q --no-proxy --no-check-certificate https://paddle-qa.bj.bcebos.com/CodeSync/develop/PaddlePaddle/PaddleTest/tools/bos_tools.py
|
||||
push_file=$(realpath bos_tools.py)
|
||||
python -m pip install bce-python-sdk==0.9.29
|
||||
ls
|
||||
python ${push_file} FastDeploy.tar.gz ${target_path}
|
||||
target_path_stripped="${target_path#paddle-qa/}"
|
||||
REPO_ARCHIVE_URL=https://paddle-qa.bj.bcebos.com/${target_path_stripped}/FastDeploy.tar.gz
|
||||
echo "repo_archive_url=${REPO_ARCHIVE_URL}" >> $GITHUB_OUTPUT
|
||||
|
||||
resultshow:
|
||||
name: Show Code Archive Output
|
||||
needs: clone
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- name: Print code archive path
|
||||
run: |
|
||||
echo "The code archive is located at: ${{ needs.clone.outputs.repo_archive_url }}"
|
||||
|
||||
build_sm8090:
|
||||
name: BUILD_SM8090
|
||||
needs: [clone, publish_pre_check]
|
||||
uses: ./.github/workflows/_build_linux.yml
|
||||
with:
|
||||
DOCKER_IMAGE: ccr-2vdh3abv-pub.cnc.bj.baidubce.com/paddlepaddle/paddleqa:fastdeploy-ciuse-cuda126-dailyupdate
|
||||
FASTDEPLOY_ARCHIVE_URL: ${{ needs.clone.outputs.repo_archive_url }}
|
||||
COMPILE_ARCH: "80,90"
|
||||
WITH_NIGHTLY_BUILD: ${{ needs.publish_pre_check.outputs.with_nightly_build }}
|
||||
FD_VERSION: ${{ needs.publish_pre_check.outputs.fd_version }}
|
||||
PADDLEVERSION: ${{ needs.publish_pre_check.outputs.compile_use_paddle_version }}
|
||||
PADDLE_WHL_URL: ${{ needs.publish_pre_check.outputs.compile_use_paddle_whl_url }}
|
||||
|
||||
build_sm8689:
|
||||
name: BUILD_SM8689
|
||||
needs: [clone, publish_pre_check]
|
||||
uses: ./.github/workflows/_build_linux.yml
|
||||
with:
|
||||
DOCKER_IMAGE: ccr-2vdh3abv-pub.cnc.bj.baidubce.com/paddlepaddle/paddleqa:fastdeploy-ciuse-cuda126-dailyupdate
|
||||
FASTDEPLOY_ARCHIVE_URL: ${{ needs.clone.outputs.repo_archive_url }}
|
||||
COMPILE_ARCH: "86,89"
|
||||
WITH_NIGHTLY_BUILD: ${{ needs.publish_pre_check.outputs.with_nightly_build }}
|
||||
FD_VERSION: ${{ needs.publish_pre_check.outputs.fd_version }}
|
||||
PADDLEVERSION: ${{ needs.publish_pre_check.outputs.compile_use_paddle_version }}
|
||||
PADDLE_WHL_URL: ${{ needs.publish_pre_check.outputs.compile_use_paddle_whl_url }}
|
||||
|
||||
paddle_pypi_upload_sm8090:
|
||||
environment: PaddleSourceUpload
|
||||
name: PADDLE_PYPI_UPLOAD_8090
|
||||
needs: build_sm8090
|
||||
runs-on: ubuntu-latest
|
||||
env:
|
||||
AK: ${{ secrets.BOS_AK }}
|
||||
SK: ${{ secrets.BOS_SK }}
|
||||
FASTDEPLOY_WHEEL_URL: ${{ needs.build_sm8090.outputs.wheel_path }}
|
||||
COMPILE_ARCH: "80,90"
|
||||
steps:
|
||||
- uses: actions/setup-python@v5
|
||||
with:
|
||||
python-version: '3.10'
|
||||
- name: Wheel Info Show and Upload
|
||||
if: github.ref_name == 'develop' || github.ref_type == 'tag'
|
||||
run: |
|
||||
echo "The wheel is located at: ${FASTDEPLOY_WHEEL_URL}"
|
||||
wget -q --no-check-certificate ${FASTDEPLOY_WHEEL_URL}
|
||||
filename=$(basename ${FASTDEPLOY_WHEEL_URL})
|
||||
if [[ "${{ github.ref_name }}" == "develop" ]];then
|
||||
target_path=paddle-whl/nightly/fastdeploy-gpu-${COMPILE_ARCH//,/_}/fastdeploy-gpu
|
||||
elif [[ "${{ github.ref_type }}" == "tag" ]]; then
|
||||
target_path=paddle-whl/stable/fastdeploy-gpu-${COMPILE_ARCH//,/_}/fastdeploy-gpu
|
||||
else
|
||||
echo "Not develop or tag, do nothing"
|
||||
fi
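# Resulting upload path for COMPILE_ARCH="80,90" (per the branches above):
#   develop branch     -> paddle-whl/nightly/fastdeploy-gpu-80_90/fastdeploy-gpu
#   tag build (stable) -> paddle-whl/stable/fastdeploy-gpu-80_90/fastdeploy-gpu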
|
||||
wget -q --no-proxy --no-check-certificate https://paddle-qa.bj.bcebos.com/CodeSync/develop/PaddlePaddle/PaddleTest/tools/bos_tools.py
|
||||
push_file=$(realpath bos_tools.py)
|
||||
python -m pip install bce-python-sdk==0.9.29
|
||||
ls
|
||||
python ${push_file} ${filename} ${target_path}
|
||||
|
||||
paddle_pypi_upload_sm8689:
|
||||
environment: PaddleSourceUpload
|
||||
name: PADDLE_PYPI_UPLOAD_8689
|
||||
needs: build_sm8689
|
||||
runs-on: ubuntu-latest
|
||||
env:
|
||||
AK: ${{ secrets.BOS_AK }}
|
||||
SK: ${{ secrets.BOS_SK }}
|
||||
FASTDEPLOY_WHEEL_URL: ${{ needs.build_sm8689.outputs.wheel_path }}
|
||||
COMPILE_ARCH: "86,89"
|
||||
steps:
|
||||
- uses: actions/setup-python@v5
|
||||
with:
|
||||
python-version: '3.10'
|
||||
- name: Wheel Info Show and Upload
|
||||
if: github.ref_name == 'develop' || github.ref_type == 'tag'
|
||||
run: |
|
||||
echo "The wheel is located at: ${FASTDEPLOY_WHEEL_URL}"
|
||||
wget -q --no-check-certificate ${FASTDEPLOY_WHEEL_URL}
|
||||
filename=$(basename ${FASTDEPLOY_WHEEL_URL})
|
||||
if [[ "${{ github.ref_name }}" == "develop" ]];then
|
||||
target_path=paddle-whl/nightly/fastdeploy-gpu-${COMPILE_ARCH//,/_}/fastdeploy-gpu
|
||||
elif [[ "${{ github.ref_type }}" == "tag" ]]; then
|
||||
target_path=paddle-whl/stable/fastdeploy-gpu-${COMPILE_ARCH//,/_}/fastdeploy-gpu
|
||||
else
|
||||
echo "Not develop or tag, do nothing"
|
||||
fi
|
||||
wget -q --no-proxy --no-check-certificate https://paddle-qa.bj.bcebos.com/CodeSync/develop/PaddlePaddle/PaddleTest/tools/bos_tools.py
|
||||
push_file=$(realpath bos_tools.py)
|
||||
python -m pip install bce-python-sdk==0.9.29
|
||||
ls
|
||||
python ${push_file} ${filename} ${target_path}
|
||||
|
||||
unittest_coverage:
|
||||
name: Run FastDeploy Unit Tests and Coverage
|
||||
needs: [clone,build_sm8090]
|
||||
uses: ./.github/workflows/_unit_test_coverage.yml
|
||||
with:
|
||||
DOCKER_IMAGE: ccr-2vdh3abv-pub.cnc.bj.baidubce.com/paddlepaddle/paddleqa:fastdeploy-ciuse-cuda126-dailyupdate
|
||||
FASTDEPLOY_ARCHIVE_URL: ${{ needs.clone.outputs.repo_archive_url }}
|
||||
FASTDEPLOY_WHEEL_URL: ${{ needs.build_sm8090.outputs.wheel_path }}
|
||||
MODEL_CACHE_DIR: "/ssd2/actions-runner/ModelData"
|
||||
secrets:
|
||||
github-token: ${{ secrets.GITHUB_TOKEN }}
|
||||
|
||||
logprob_test:
|
||||
name: Run FastDeploy LogProb Tests
|
||||
needs: [build_sm8090]
|
||||
uses: ./.github/workflows/_logprob_test_linux.yml
|
||||
with:
|
||||
DOCKER_IMAGE: ccr-2vdh3abv-pub.cnc.bj.baidubce.com/paddlepaddle/paddleqa:fastdeploy-ciuse-cuda126-dailyupdate
|
||||
PADDLETEST_ARCHIVE_URL: "https://xly-devops.bj.bcebos.com/PaddleTest/PaddleTest.tar.gz"
|
||||
FASTDEPLOY_WHEEL_URL: ${{ needs.build_sm8090.outputs.wheel_path }}
|
||||
MODEL_CACHE_DIR: "/ssd2/actions-runner/ModelData"
|
||||
|
||||
pre_ce_test:
|
||||
name: Run a subset of CE model tasks in CI
|
||||
needs: [clone,build_sm8090]
|
||||
uses: ./.github/workflows/_pre_ce_test.yml
|
||||
with:
|
||||
DOCKER_IMAGE: ccr-2vdh3abv-pub.cnc.bj.baidubce.com/paddlepaddle/paddleqa:fastdeploy-ciuse-cuda126-dailyupdate
|
||||
FASTDEPLOY_ARCHIVE_URL: ${{ needs.clone.outputs.repo_archive_url }}
|
||||
FASTDEPLOY_WHEEL_URL: ${{ needs.build_sm8090.outputs.wheel_path }}
|
||||
MODEL_CACHE_DIR: "/ssd2/actions-runner/ModelData"
|
||||
|
||||
base_test:
|
||||
name: Run Base Tests
|
||||
needs: [clone,build_sm8090]
|
||||
uses: ./.github/workflows/_base_test.yml
|
||||
with:
|
||||
DOCKER_IMAGE: ccr-2vdh3abv-pub.cnc.bj.baidubce.com/paddlepaddle/paddleqa:fastdeploy-ciuse-cuda126-dailyupdate
|
||||
FASTDEPLOY_ARCHIVE_URL: ${{ needs.clone.outputs.repo_archive_url }}
|
||||
FASTDEPLOY_WHEEL_URL: ${{ needs.build_sm8090.outputs.wheel_path }}
|
||||
MODEL_CACHE_DIR: "/ssd2/actions-runner/ModelData"
|
||||
|
||||
accuracy_test:
|
||||
name: Run Accuracy Tests
|
||||
needs: [clone,build_sm8090]
|
||||
uses: ./.github/workflows/_accuracy_test.yml
|
||||
with:
|
||||
DOCKER_IMAGE: ccr-2vdh3abv-pub.cnc.bj.baidubce.com/paddlepaddle/paddleqa:fastdeploy-ciuse-cuda126-dailyupdate
|
||||
FASTDEPLOY_ARCHIVE_URL: ${{ needs.clone.outputs.repo_archive_url }}
|
||||
FASTDEPLOY_WHEEL_URL: ${{ needs.build_sm8090.outputs.wheel_path }}
|
||||
MODEL_CACHE_DIR: "/ssd2/actions-runner/ModelData"
|
||||
|
||||
stable_test:
|
||||
name: Run Stable Tests
|
||||
needs: [clone,build_sm8090]
|
||||
uses: ./.github/workflows/_stable_test.yml
|
||||
with:
|
||||
DOCKER_IMAGE: ccr-2vdh3abv-pub.cnc.bj.baidubce.com/paddlepaddle/paddleqa:fastdeploy-ciuse-cuda126-dailyupdate
|
||||
FASTDEPLOY_ARCHIVE_URL: ${{ needs.clone.outputs.repo_archive_url }}
|
||||
FASTDEPLOY_WHEEL_URL: ${{ needs.build_sm8090.outputs.wheel_path }}
|
||||
MODEL_CACHE_DIR: "/ssd2/actions-runner/ModelData"
|
.gitignore (vendored, 14 lines changed)
@@ -121,7 +121,7 @@ dmypy.json
|
||||
FETCH_HEAD
|
||||
|
||||
#log
|
||||
log/
|
||||
log*/
|
||||
|
||||
checkpoints/
|
||||
checkpoints_origin/
|
||||
@@ -156,12 +156,6 @@ nohup.out
|
||||
custom_ops/gpu_ops/fp8_deep_gemm/deep_gemm/include/cutlass
|
||||
custom_ops/gpu_ops/fp8_deep_gemm/deep_gemm/include/cute
|
||||
|
||||
#marlin_kernel
|
||||
custom_ops/gpu_ops/moe/moe_wna16_marlin_utils/kernel_*.cu
|
||||
|
||||
#machete_kernel
|
||||
custom_ops/gpu_ops/machete/generated
|
||||
|
||||
# buff
|
||||
custom_ops/tmp*
|
||||
|
||||
@@ -170,9 +164,3 @@ build
|
||||
.ccls-cache
|
||||
|
||||
third_party
|
||||
|
||||
custom_ops/gpu_ops/w4afp8_gemm/w4afp8_gemm_*.cu
|
||||
custom_ops/gpu_ops/w4afp8_gemm/w4afp8_gemm_template.h
|
||||
|
||||
custom_ops/gpu_ops/wfp8afp8_sparse_gemm/wfp8Afp8_sparse_gemm_*.cu
|
||||
custom_ops/gpu_ops/wfp8afp8_sparse_gemm/wfp8Afp8_sparse_gemm_template.h
|
||||
|
.gitmodules (vendored, 10 lines changed)
@@ -1,10 +0,0 @@
|
||||
[submodule "custom_ops/third_party/DeepGEMM"]
|
||||
path = custom_ops/third_party/DeepGEMM
|
||||
url = https://github.com/deepseek-ai/DeepGEMM.git
|
||||
ignore = all
|
||||
[submodule "custom_ops/third_party/cutlass"]
|
||||
path = custom_ops/third_party/cutlass
|
||||
url = https://github.com/NVIDIA/cutlass.git
|
||||
[submodule "custom_ops/third_party/nlohmann_json"]
|
||||
path = custom_ops/third_party/nlohmann_json
|
||||
url = https://github.com/nlohmann/json.git
|
README.md (25 lines changed)
@@ -1,4 +1,3 @@
|
||||
English | [简体中文](README_CN.md)
|
||||
<p align="center">
|
||||
<a href="https://github.com/PaddlePaddle/FastDeploy/releases"><img src="https://github.com/user-attachments/assets/42b0039f-39e3-4279-afda-6d1865dfbffb" width="500"></a>
|
||||
</p>
|
||||
@@ -23,12 +22,11 @@ English | [简体中文](README_CN.md)
|
||||
</p>
|
||||
|
||||
--------------------------------------------------------------------------------
|
||||
# FastDeploy : Inference and Deployment Toolkit for LLMs and VLMs based on PaddlePaddle
|
||||
# FastDeploy 2.0: Inference and Deployment Toolkit for LLMs and VLMs based on PaddlePaddle
|
||||
|
||||
## News
|
||||
**[2025-09] 🔥 FastDeploy v2.2 is newly released!** It now offers compatibility with models in the HuggingFace ecosystem, further optimizes performance, and adds support for [baidu/ERNIE-4.5-21B-A3B-Thinking](https://huggingface.co/baidu/ERNIE-4.5-21B-A3B-Thinking)!
|
||||
|
||||
**[2025-08] 🔥 Released FastDeploy v2.1:** A brand-new KV Cache scheduling strategy has been introduced, and expanded support for PD separation and CUDA Graph across more models. Enhanced hardware support has been added for platforms like Kunlun and Hygon, along with comprehensive optimizations to improve the performance of both the service and inference engine.
|
||||
**[2025-07] 《FastDeploy2.0推理部署实测》专题活动已上线!** 完成文心4.5系列开源模型的推理部署等任务,即可获得骨瓷马克杯等FastDeploy2.0官方周边及丰富奖金!🎁 欢迎大家体验反馈~ 📌[报名地址](https://www.wjx.top/vm/meSsp3L.aspx#) 📌[活动详情](https://github.com/PaddlePaddle/FastDeploy/discussions/2728)
|
||||
|
||||
**[2025-07] The FastDeploy 2.0 Inference Deployment Challenge is now live!** Complete the inference deployment task for the ERNIE 4.5 series open-source models to win official FastDeploy 2.0 merch and generous prizes! 🎁 You're welcome to try it out and share your feedback! 📌[Sign up here](https://www.wjx.top/vm/meSsp3L.aspx#) 📌[Event details](https://github.com/PaddlePaddle/FastDeploy/discussions/2728)
|
||||
|
||||
@@ -52,16 +50,14 @@ English | [简体中文](README_CN.md)
|
||||
|
||||
## Installation
|
||||
|
||||
FastDeploy supports inference deployment on **NVIDIA GPUs**, **Kunlunxin XPUs**, **Iluvatar GPUs**, **Enflame GCUs**, **Hygon DCUs** and other hardware. For detailed installation instructions:
|
||||
FastDeploy supports inference deployment on **NVIDIA GPUs**, **Kunlunxin XPUs**, **Iluvatar GPUs**, **Enflame GCUs**, and other hardware. For detailed installation instructions:
|
||||
|
||||
- [NVIDIA GPU](./docs/get_started/installation/nvidia_gpu.md)
|
||||
- [Kunlunxin XPU](./docs/get_started/installation/kunlunxin_xpu.md)
|
||||
- [Iluvatar GPU](./docs/get_started/installation/iluvatar_gpu.md)
|
||||
- [Enflame GCU](./docs/get_started/installation/Enflame_gcu.md)
|
||||
- [Hygon DCU](./docs/get_started/installation/hygon_dcu.md)
|
||||
- [MetaX GPU](./docs/get_started/installation/metax_gpu.md)
|
||||
|
||||
**Note:** We are actively working on expanding hardware support. Additional hardware platforms including Ascend NPU are currently under development and testing. Stay tuned for updates!
|
||||
**Note:** We are actively working on expanding hardware support. Additional hardware platforms including Ascend NPU, Hygon DCU, and MetaX GPU are currently under development and testing. Stay tuned for updates!
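The wheel upload paths in the publish job earlier in this diff use the package name fastdeploy-gpu, so a GPU install typically boils down to the sketch below; the exact package index URL depends on the GPU architecture and is given in the linked NVIDIA GPU installation guide (left as a placeholder here):

```bash
# Sketch only; take <index-url> from docs/get_started/installation/nvidia_gpu.md.
python -m pip install fastdeploy-gpu --extra-index-url <index-url>
```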
|
||||
|
||||
## Get Started
|
||||
|
||||
@@ -71,12 +67,19 @@ Learn how to use FastDeploy through our documentation:
|
||||
- [ERNIE-4.5-VL Multimodal Model Deployment](./docs/get_started/ernie-4.5-vl.md)
|
||||
- [Offline Inference Development](./docs/offline_inference.md)
|
||||
- [Online Service Deployment](./docs/online_serving/README.md)
|
||||
- [Best Practices](./docs/best_practices/README.md)
|
||||
- [Full Supported Models List](./docs/supported_models.md)
|
||||
|
||||
## Supported Models
|
||||
|
||||
Learn how to download models, enable loading models in the torch format, and more:
|
||||
- [Full Supported Models List](./docs/supported_models.md)
|
||||
| Model | Data Type | PD Disaggregation | Chunked Prefill | Prefix Caching | MTP | CUDA Graph | Maximum Context Length |
|
||||
|:--- | :------- | :---------- | :-------- | :-------- | :----- | :----- | :----- |
|
||||
|ERNIE-4.5-300B-A47B | BF16/WINT4/WINT8/W4A8C8/WINT2/FP8 | ✅| ✅ | ✅|✅(WINT4)| WIP |128K |
|
||||
|ERNIE-4.5-300B-A47B-Base| BF16/WINT4/WINT8 | ✅| ✅ | ✅|✅(WINT4)| WIP | 128K |
|
||||
|ERNIE-4.5-VL-424B-A47B | BF16/WINT4/WINT8 | WIP | ✅ | WIP | ❌ | WIP |128K |
|
||||
|ERNIE-4.5-VL-28B-A3B | BF16/WINT4/WINT8 | ❌ | ✅ | WIP | ❌ | WIP |128K |
|
||||
|ERNIE-4.5-21B-A3B | BF16/WINT4/WINT8/FP8 | ❌ | ✅ | ✅ | WIP | ✅|128K |
|
||||
|ERNIE-4.5-21B-A3B-Base | BF16/WINT4/WINT8/FP8 | ❌ | ✅ | ✅ | WIP | ✅|128K |
|
||||
|ERNIE-4.5-0.3B | BF16/WINT8/FP8 | ❌ | ✅ | ✅ | ❌ | ✅| 128K |
|
||||
|
||||
## Advanced Usage
|
||||
|
||||
|
README_CN.md (89 lines changed)
@@ -1,89 +0,0 @@
|
||||
[English](README.md) | 简体中文
|
||||
<p align="center">
|
||||
<a href="https://github.com/PaddlePaddle/FastDeploy/releases"><img src="https://github.com/user-attachments/assets/42b0039f-39e3-4279-afda-6d1865dfbffb" width="500"></a>
|
||||
</p>
|
||||
<p align="center">
|
||||
<a href=""><img src="https://img.shields.io/badge/python-3.10-aff.svg"></a>
|
||||
<a href=""><img src="https://img.shields.io/badge/os-linux-pink.svg"></a>
|
||||
<a href="https://github.com/PaddlePaddle/FastDeploy/graphs/contributors"><img src="https://img.shields.io/github/contributors/PaddlePaddle/FastDeploy?color=9ea"></a>
|
||||
<a href="https://github.com/PaddlePaddle/FastDeploy/commits"><img src="https://img.shields.io/github/commit-activity/m/PaddlePaddle/FastDeploy?color=3af"></a>
|
||||
<a href="https://github.com/PaddlePaddle/FastDeploy/issues"><img src="https://img.shields.io/github/issues/PaddlePaddle/FastDeploy?color=9cc"></a>
|
||||
<a href="https://github.com/PaddlePaddle/FastDeploy/stargazers"><img src="https://img.shields.io/github/stars/PaddlePaddle/FastDeploy?color=ccf"></a>
|
||||
|
||||
</p>
|
||||
|
||||
<p align="center">
|
||||
<a href="https://trendshift.io/repositories/4046" target="_blank"><img src="https://trendshift.io/api/badge/repositories/4046" alt="PaddlePaddle%2FFastDeploy | Trendshift" style="width: 250px; height: 55px;" width="250" height="55"/></a></br>
|
||||
<a href="https://paddlepaddle.github.io/FastDeploy/zh/get_started/installation/nvidia_gpu/"><b> 安装指导 </b></a>
|
||||
|
|
||||
<a href="https://paddlepaddle.github.io/FastDeploy/zh/get_started/quick_start"><b> 快速入门 </b></a>
|
||||
|
|
||||
<a href="https://paddlepaddle.github.io/FastDeploy/zh/supported_models/"><b> 支持模型列表 </b></a>
|
||||
|
||||
</p>
|
||||
|
||||
--------------------------------------------------------------------------------
|
||||
# FastDeploy :基于飞桨的大语言模型与视觉语言模型推理部署工具包
|
||||
|
||||
## 最新活动
|
||||
**[2025-09] 🔥 FastDeploy v2.2 全新发布**: HuggingFace生态模型兼容,性能进一步优化,更新增对[baidu/ERNIE-21B-A3B-Thinking](https://huggingface.co/baidu/ERNIE-4.5-21B-A3B-Thinking)支持!
|
||||
|
||||
**[2025-08] FastDeploy v2.1 发布**:全新的KV Cache调度策略,更多模型支持PD分离和CUDA Graph,昆仑、海光等更多硬件支持增强,全方面优化服务和推理引擎的性能。
|
||||
|
||||
**[2025-07] 《FastDeploy2.0推理部署实测》专题活动已上线!** 完成文心4.5系列开源模型的推理部署等任务,即可获得骨瓷马克杯等FastDeploy2.0官方周边及丰富奖金!🎁 欢迎大家体验反馈~ 📌[报名地址](https://www.wjx.top/vm/meSsp3L.aspx#) 📌[活动详情](https://github.com/PaddlePaddle/FastDeploy/discussions/2728)
|
||||
|
||||
## 关于
|
||||
|
||||
**FastDeploy** 是基于飞桨(PaddlePaddle)的大语言模型(LLM)与视觉语言模型(VLM)推理部署工具包,提供**开箱即用的生产级部署方案**,核心技术特性包括:
|
||||
|
||||
- 🚀 **负载均衡式PD分解**:工业级解决方案,支持上下文缓存与动态实例角色切换,在保障SLO达标和吞吐量的同时优化资源利用率
|
||||
- 🔄 **统一KV缓存传输**:轻量级高性能传输库,支持智能NVLink/RDMA选择
|
||||
- 🤝 **OpenAI API服务与vLLM兼容**:单命令部署,兼容[vLLM](https://github.com/vllm-project/vllm/)接口
|
||||
- 🧮 **全量化格式支持**:W8A16、W8A8、W4A16、W4A8、W2A16、FP8等
|
||||
- ⏩ **高级加速技术**:推测解码、多令牌预测(MTP)及分块预填充
|
||||
- 🖥️ **多硬件支持**:NVIDIA GPU、昆仑芯XPU、海光DCU、昇腾NPU、天数智芯GPU、燧原GCU、沐曦GPU等
|
||||
|
||||
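To make the OpenAI-compatible serving item above concrete, here is a minimal request sketch. It assumes a FastDeploy server is already running locally; the base URL, API key, and model name are placeholders that depend on how your own instance was launched.

```python
# Minimal sketch of querying a FastDeploy OpenAI-compatible endpoint.
# base_url, api_key and the model name are placeholders (assumptions), not
# values mandated by FastDeploy itself.
from openai import OpenAI

client = OpenAI(base_url="http://localhost:8000/v1", api_key="EMPTY")
response = client.chat.completions.create(
    model="baidu/ERNIE-4.5-0.3B-Paddle",  # placeholder model id
    messages=[{"role": "user", "content": "Briefly introduce PaddlePaddle."}],
    max_tokens=64,
)
print(response.choices[0].message.content)
```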
## Requirements

- OS: Linux
- Python: 3.10 ~ 3.12

## Installation

FastDeploy supports inference deployment on **NVIDIA GPU**, **Kunlunxin XPU**, **Iluvatar GPU**, **Enflame GCU**, **Hygon DCU**, and other hardware. Detailed installation instructions:

- [NVIDIA GPU](./docs/zh/get_started/installation/nvidia_gpu.md)
- [Kunlunxin XPU](./docs/zh/get_started/installation/kunlunxin_xpu.md)
- [Iluvatar CoreX](./docs/zh/get_started/installation/iluvatar_gpu.md)
- [Enflame S60](./docs/zh/get_started/installation/Enflame_gcu.md)
- [Hygon DCU](./docs/zh/get_started/installation/hygon_dcu.md)
- [MetaX GPU](./docs/zh/get_started/installation/metax_gpu.md)

**Note:** We are actively expanding hardware support. Other hardware platforms, including Ascend NPU, are currently under development and testing. Stay tuned for updates!

## Get Started

Learn how to use FastDeploy through our documentation:
- [Quick Deployment in 10 Minutes](./docs/zh/get_started/quick_start.md)
- [ERNIE-4.5 Deployment](./docs/zh/get_started/ernie-4.5.md)
- [ERNIE-4.5-VL Deployment](./docs/zh/get_started/ernie-4.5-vl.md)
- [Offline Inference](./docs/zh/offline_inference.md)
- [Online Serving](./docs/zh/online_serving/README.md)
- [Best Practices](./docs/zh/best_practices/README.md)

## Supported Models

Learn how to download models, enable torch-format support, and more through our documentation:
- [Supported Models List](./docs/zh/supported_models.md)

## Advanced Usage

- [Quantization](./docs/zh/quantization/README.md)
- [Disaggregated Deployment](./docs/zh/features/disaggregated.md)
- [Speculative Decoding](./docs/zh/features/speculative_decoding.md)
- [Prefix Caching](./docs/zh/features/prefix_caching.md)
- [Chunked Prefill](./docs/zh/features/chunked_prefill.md)

## Acknowledgements

FastDeploy is licensed under the [Apache-2.0 License](./LICENSE). During development, portions of [vLLM](https://github.com/vllm-project/vllm) code were referenced and incorporated to maintain interface compatibility, for which we express our sincere gratitude.
@@ -361,7 +361,8 @@ async def benchmark(

    if not test_output.success:
        raise ValueError(
            f"Initial test run failed - Please make sure that 1. benchmark arguments are correctly specified and 2. the http_proxy and https_proxy are turned off. Error: {test_output.error}"
            "Initial test run failed - Please make sure benchmark arguments "
            f"are correctly specified. Error: {test_output.error}"
        )
    else:
        print("Initial test run completed. Starting main benchmark run...")
@@ -6,4 +6,3 @@ tensor_parallel_size: 8
max_num_batched_tokens: 4096
max_num_partial_prefills: 3
max_long_partial_prefills: 3
quantization: wint4

@@ -1,6 +0,0 @@
tensor_parallel_size: 1
max_model_len: 131072
max_num_seqs: 32
quantization: wint4
max_num_batched_tokens: 8192
plas_attention_config: '{"plas_encoder_top_k_left": 50, "plas_encoder_top_k_right": 60, "plas_decoder_top_k_left": 100, "plas_decoder_top_k_right": 120}'
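The `plas_attention_config` value above is a JSON string rather than nested YAML. A minimal sketch of how such a value decodes into a plain dict (the loading code here is illustrative only, not FastDeploy's own parser):

```python
# Illustrative only: decode the JSON-encoded attention config shown above.
import json

raw = ('{"plas_encoder_top_k_left": 50, "plas_encoder_top_k_right": 60, '
       '"plas_decoder_top_k_left": 100, "plas_decoder_top_k_right": 120}')
plas_cfg = json.loads(raw)
assert plas_cfg["plas_decoder_top_k_right"] == 120
```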
@@ -6,4 +6,3 @@ tensor_parallel_size: 8
max_num_batched_tokens: 4096
max_num_partial_prefills: 3
max_long_partial_prefills: 3
quantization: wint8

@@ -1,5 +0,0 @@
max_model_len: 32768
max_num_seqs: 256
kv_cache_ratio: 0.75
tensor_parallel_size: 4
gpu_memory_utilization: 0.9

@@ -13,4 +13,3 @@ pd_comm_port: "2334"
max_num_batched_tokens: 384
max_num_partial_prefills: 3
max_long_partial_prefills: 3
quantization: wint4

@@ -10,4 +10,3 @@ engine_worker_queue_port: 6677
cache_transfer_protocol: "rdma,ipc"
rdma_comm_ports: "7675,7676,7677,7678"
pd_comm_port: "2333"
quantization: wint4

@@ -1,6 +0,0 @@
num_gpu_blocks_override: 1024
max_model_len: 8192
max_num_seqs: 64
data_parallel_size: 8
tensor_parallel_size: 1
enable_expert_parallel: True

@@ -1,11 +0,0 @@
enable_mm: True
max_model_len: 131072
max_num_seqs: 56
gpu_memory_utilization: 0.8
kv_cache_ratio: 0.8
tensor_parallel_size: 8
quantization: wint4
limit_mm_per_prompt: '{"image": 100, "video": 100}'
enable_chunked_prefill: True
max_num_batched_tokens: 384
reasoning_parser: ernie-45-vl
@@ -1,7 +1,7 @@
enable_mm: True
max_model_len: 32768
max_num_seqs: 36
gpu_memory_utilization: 0.9
gpu_memory_utilization: 0.95
kv_cache_ratio: 0.8
tensor_parallel_size: 8
quantization: wint8

@@ -1,7 +1,7 @@
enable_mm: True
max_model_len: 32768
max_num_seqs: 36
gpu_memory_utilization: 0.85
gpu_memory_utilization: 0.8
kv_cache_ratio: 0.8
tensor_parallel_size: 8
quantization: wint8

@@ -1,9 +0,0 @@
enable_mm: True
max_model_len: 32768
max_num_seqs: 128
gpu_memory_utilization: 0.9
kv_cache_ratio: 0.71
tensor_parallel_size: 1
enable_chunked_prefill: True
max_num_batched_tokens: 384
reasoning_parser: ernie-45-vl

@@ -1,10 +0,0 @@
enable_mm: True
max_model_len: 32768
max_num_seqs: 128
gpu_memory_utilization: 0.9
kv_cache_ratio: 0.71
tensor_parallel_size: 1
enable_chunked_prefill: True
max_num_batched_tokens: 384
quantization: wint4
reasoning_parser: ernie-45-vl

@@ -1,10 +0,0 @@
enable_mm: True
max_model_len: 32768
max_num_seqs: 128
gpu_memory_utilization: 0.9
kv_cache_ratio: 0.71
tensor_parallel_size: 1
enable_chunked_prefill: True
max_num_batched_tokens: 384
quantization: wint8
reasoning_parser: ernie-45-vl

@@ -1 +0,0 @@
max_tokens: 131071

@@ -1 +0,0 @@
max_tokens: 12288
@@ -1,8 +0,0 @@
top_p: 0.95
temperature: 0.6
metadata:
  min_tokens: 1
  max_tokens: 131071
repetition_penalty: 1.0
frequency_penalty: 0
presence_penalty: 0

@@ -1,10 +0,0 @@
reasoning-parser: ernie_x1
tool_call_parser: ernie_x1
tensor_parallel_size: 4
max_model_len: 65536
max_num_seqs: 128
enable_prefix_caching: True
enable_chunked_prefill: True
gpu_memory_utilization: 0.85
use_cudagraph: True
enable_custom_all_reduce: True

@@ -1,6 +0,0 @@
tensor_parallel_size: 1
max_model_len: 131072
max_num_seqs: 32
reasoning_parser: ernie_x1
tool_call_parser: ernie_x1
load_choices: "default_v1"
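These deployment and benchmark YAML files are plain key-value configs. A small sketch of inspecting one before use; the file name below is a placeholder and PyYAML is assumed to be installed:

```python
# Sketch: load one of the YAML configs above and inspect a few keys.
# "ernie_x1_tp1.yaml" is a placeholder path, not a file name taken from the repo.
import yaml

with open("ernie_x1_tp1.yaml") as f:
    cfg = yaml.safe_load(f)

print(cfg.get("tensor_parallel_size"), cfg.get("quantization"), cfg.get("max_model_len"))
```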
build.sh

@@ -34,6 +34,7 @@ EGG_DIR="fastdeploy.egg-info"

# custom_ops directory config
OPS_SRC_DIR="custom_ops"
OPS_TMP_DIR_BASE="tmp_base"
OPS_TMP_DIR="tmp"

# command line log config

@@ -70,20 +71,25 @@ function copy_ops(){
PY_VERSION="py${PY_MAIN_VERSION}.${PY_SUB_VERSION}"
SYSTEM_VERSION=`${python} -c "import platform; print(platform.system().lower())"`
PROCESSOR_VERSION=`${python} -c "import platform; print(platform.processor())"`
WHEEL_BASE_NAME="fastdeploy_base_ops-${OPS_VERSION}-${PY_VERSION}-${SYSTEM_VERSION}-${PROCESSOR_VERSION}.egg"
WHEEL_NAME="fastdeploy_ops-${OPS_VERSION}-${PY_VERSION}-${SYSTEM_VERSION}-${PROCESSOR_VERSION}.egg"
WHEEL_CPU_NAME="fastdeploy_cpu_ops-${OPS_VERSION}-${PY_VERSION}-${SYSTEM_VERSION}-${PROCESSOR_VERSION}.egg"
is_rocm=`$python -c "import paddle; print(paddle.is_compiled_with_rocm())"`
if [ "$is_rocm" = "True" ]; then
DEVICE_TYPE="rocm"
mkdir -p ../fastdeploy/model_executor/ops/base
cp -r ./${OPS_TMP_DIR_BASE}/${WHEEL_BASE_NAME}/* ../fastdeploy/model_executor/ops/base
cp -r ./${OPS_TMP_DIR}/${WHEEL_NAME}/* ../fastdeploy/model_executor/ops/gpu
echo -e "ROCM ops have been copy to fastdeploy"
echo -e "BASE and ROCM ops have been copy to fastdeploy"
return
fi
mkdir -p ../fastdeploy/model_executor/ops/base
is_cuda=`$python -c "import paddle; print(paddle.is_compiled_with_cuda())"`
if [ "$is_cuda" = "True" ]; then
DEVICE_TYPE="gpu"
cp -r ./${OPS_TMP_DIR_BASE}/${WHEEL_BASE_NAME}/* ../fastdeploy/model_executor/ops/base
cp -r ./${OPS_TMP_DIR}/${WHEEL_NAME}/* ../fastdeploy/model_executor/ops/gpu
echo -e "CUDA ops have been copy to fastdeploy"
echo -e "BASE and CUDA ops have been copy to fastdeploy"
return
fi

@@ -106,8 +112,9 @@ function copy_ops(){
if_corex=`$python -c "import paddle; print(paddle.is_compiled_with_custom_device(\"iluvatar_gpu\"))"`
if [ "$if_corex" = "True" ]; then
DEVICE_TYPE="iluvatar-gpu"
cp -r ./${OPS_TMP_DIR_BASE}/${WHEEL_BASE_NAME}/* ../fastdeploy/model_executor/ops/base
cp -r ./${OPS_TMP_DIR}/${WHEEL_NAME}/* ../fastdeploy/model_executor/ops/iluvatar
echo -e "Iluvatar ops have been copy to fastdeploy"
echo -e "BASE and Iluvatar ops have been copy to fastdeploy"
return
fi

@@ -119,33 +126,27 @@ function copy_ops(){
return
fi

is_maca=`$python -c "import paddle; print(paddle.device.is_compiled_with_custom_device('metax_gpu'))"`
if [ "$is_maca" = "True" ]; then
DEVICE_TYPE="metax_gpu"
mkdir -p ../fastdeploy/model_executor/ops/base
cp -r ./${OPS_TMP_DIR_BASE}/${WHEEL_BASE_NAME}/* ../fastdeploy/model_executor/ops/base
cp -r ./${OPS_TMP_DIR}/${WHEEL_NAME}/* ../fastdeploy/model_executor/ops/gpu
echo -e "MACA ops have been copy to fastdeploy"
return
fi

DEVICE_TYPE="cpu"
cp -r ./${OPS_TMP_DIR_BASE}/${WHEEL_BASE_NAME}/* ../fastdeploy/model_executor/ops/base
cd ../../../../
cp -r ${OPS_TMP_DIR}/${WHEEL_CPU_NAME}/* ../fastdeploy/model_executor/ops/cpu
echo -e "CPU ops have been copy to fastdeploy"
echo -e "BASE and CPU ops have been copy to fastdeploy"
return
}
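The `copy_ops()` function above decides where to copy the built op wheels by probing how Paddle was compiled. A Python sketch of that probe order follows; every call shown is already used verbatim in the script, while the consolidated helper function itself is only illustrative:

```python
# Illustrative consolidation of the device probes used by copy_ops() above.
import paddle

def detect_device_type() -> str:
    if paddle.is_compiled_with_rocm():
        return "rocm"
    if paddle.is_compiled_with_cuda():
        return "gpu"
    if paddle.is_compiled_with_custom_device("iluvatar_gpu"):
        return "iluvatar-gpu"
    if paddle.device.is_compiled_with_custom_device("metax_gpu"):
        return "metax_gpu"
    return "cpu"

print(detect_device_type())
```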
function build_and_install_ops() {
cd $OPS_SRC_DIR
export no_proxy=bcebos.com,paddlepaddle.org.cn,${no_proxy}
echo -e "${BLUE}[build]${NONE} build and install fastdeploy_base_ops..."
${python} setup_ops_base.py install --install-lib ${OPS_TMP_DIR_BASE}
find ${OPS_TMP_DIR_BASE} -type f -name "*.o" -exec rm -f {} \;
echo -e "${BLUE}[build]${NONE} build and install fastdeploy_ops..."
TMP_DIR_REAL_PATH=`readlink -f ${OPS_TMP_DIR}`
is_xpu=`$python -c "import paddle; print(paddle.is_compiled_with_xpu())"`
if [ "$is_xpu" = "True" ]; then
cd xpu_ops
cd xpu_ops/src
bash build.sh ${TMP_DIR_REAL_PATH}
cd ..
cd ../..
elif [ "$FD_CPU_USE_BF16" == "true" ]; then
if [ "$FD_BUILDING_ARCS" == "" ]; then
FD_CPU_USE_BF16=True ${python} setup_ops.py install --install-lib ${OPS_TMP_DIR}

@@ -212,6 +213,7 @@ function cleanup() {
fi

rm -rf $OPS_SRC_DIR/$BUILD_DIR $OPS_SRC_DIR/$EGG_DIR
rm -rf $OPS_SRC_DIR/$OPS_TMP_DIR_BASE
rm -rf $OPS_SRC_DIR/$OPS_TMP_DIR
}
@@ -84,6 +84,7 @@ std::vector<paddle::Tensor> GetPaddingOffset(const paddle::Tensor &input_ids,
|
||||
seq_length,
|
||||
bsz);
|
||||
return {x_remove_padding,
|
||||
cum_offsets_out,
|
||||
padding_offset,
|
||||
cu_seqlens_q,
|
||||
cu_seqlens_k};
|
||||
@@ -96,7 +97,7 @@ std::vector<std::vector<int64_t>> GetPaddingOffsetInferShape(
|
||||
const std::vector<int64_t> &seq_len_shape) {
|
||||
int64_t bsz = seq_len_shape[0];
|
||||
int64_t seq_len = input_ids_shape[1];
|
||||
return {{-1}, {-1}, {bsz + 1}, {bsz + 1}};
|
||||
return {{-1}, {bsz}, {-1}, {bsz + 1}, {bsz + 1}};
|
||||
}
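The shape changes above revolve around `cu_seqlens_q`, the cumulative per-sequence token counts. A small NumPy sketch of how the batch size and the depadded token index fall out of it (the names here are illustrative, not part of the op's API):

```python
# Illustrative: relation between per-sequence lengths, cu_seqlens_q, batch size,
# and the depadded token index used as cu_seqlens_q[bi] + seq_id in the kernels above.
import numpy as np

seq_lens = np.array([3, 1, 4])                              # tokens per sequence
cu_seqlens_q = np.concatenate(([0], np.cumsum(seq_lens)))   # [0, 3, 4, 8]

bsz = cu_seqlens_q.shape[0] - 1                             # mirrors cu_seqlens_q_shape[0] - 1
assert bsz == len(seq_lens)

bi, seq_id = 2, 1
ori_token_idx = cu_seqlens_q[bi] + seq_id                   # 4 + 1 = 5
print(bsz, ori_token_idx)
```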
|
||||
|
||||
std::vector<paddle::DataType> GetPaddingOffsetInferDtype(
|
||||
@@ -105,6 +106,7 @@ std::vector<paddle::DataType> GetPaddingOffsetInferDtype(
|
||||
const paddle::DataType &token_num_dtype,
|
||||
const paddle::DataType &seq_len_dtype) {
|
||||
return {input_ids_dtype,
|
||||
seq_len_dtype,
|
||||
seq_len_dtype,
|
||||
seq_len_dtype,
|
||||
seq_len_dtype};
|
||||
@@ -113,6 +115,7 @@ std::vector<paddle::DataType> GetPaddingOffsetInferDtype(
|
||||
PD_BUILD_STATIC_OP(get_padding_offset_cpu)
|
||||
.Inputs({"input_ids", "cum_offsets", "token_num", "seq_len"})
|
||||
.Outputs({"x_remove_padding",
|
||||
"cum_offsets_out",
|
||||
"padding_offset",
|
||||
"cu_seqlens_q",
|
||||
"cu_seqlens_k"})
|
||||
|
@@ -1,4 +1,4 @@
|
||||
// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
|
||||
// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
@@ -19,11 +19,10 @@
|
||||
#define PD_BUILD_STATIC_OP(name) PD_BUILD_OP(static_op_##name)
|
||||
#endif
|
||||
|
||||
|
||||
template <typename T>
|
||||
void RebuildPaddingCPUImpl(T *output_data,
|
||||
const T *input_data,
|
||||
const int *cu_seqlens_q_data,
|
||||
const int *cum_offsets_data,
|
||||
const int *seq_len_this_time_data,
|
||||
const int *seq_lens_decoder_data,
|
||||
const int *seq_lens_encoder_data,
|
||||
@@ -41,12 +40,11 @@ void RebuildPaddingCPUImpl(T *output_data,
|
||||
if (seq_lens_decoder_data[bi] == 0 && seq_lens_encoder_data[bi] == 0) {
|
||||
continue;
|
||||
}
|
||||
|
||||
if (seq_lens_encoder_data[bi] > 0) {
|
||||
seq_id = seq_lens_encoder_data[bi] - 1;
|
||||
}
|
||||
|
||||
const int ori_token_idx = cu_seqlens_q_data[bi] + seq_id;
|
||||
const int ori_token_idx =
|
||||
bi * max_input_length - cum_offsets_data[bi] + seq_id;
|
||||
const int src_offset = ori_token_idx * dim_embed + bias_idx;
|
||||
|
||||
output_data[i] = input_data[src_offset];
|
||||
@@ -56,7 +54,7 @@ void RebuildPaddingCPUImpl(T *output_data,
|
||||
template <typename T>
|
||||
void RebuildAppendPaddingCPUImpl(T *output_data,
|
||||
const T *input_data,
|
||||
const int *cu_seqlens_q_data,
|
||||
const int *cum_offsets_data,
|
||||
const int *seq_len_this_time_data,
|
||||
const int *seq_lens_decoder_data,
|
||||
const int *seq_lens_encoder_data,
|
||||
@@ -71,32 +69,30 @@ void RebuildAppendPaddingCPUImpl(T *output_data,
|
||||
int bi = ori_token_id / max_input_length;
|
||||
if (seq_len_this_time_data[bi] == 0 ||
|
||||
(seq_lens_decoder_data[bi] == 0 &&
|
||||
seq_lens_encoder_data[bi] == 0)) {
|
||||
continue;
|
||||
}
|
||||
seq_lens_encoder_data[bi] == 0)) {
|
||||
continue;
|
||||
}
|
||||
int seq_id = 0;
|
||||
|
||||
if (seq_lens_encoder_data[bi] > 0) {
|
||||
seq_id = seq_lens_encoder_data[bi] - 1;
|
||||
}
|
||||
int input_token_id = cu_seqlens_q_data[bi] + seq_id;
|
||||
int input_token_id = ori_token_id - cum_offsets_data[bi] + seq_id;
|
||||
int bias_idx = i % dim_embed;
|
||||
int src_offset = input_token_id * dim_embed + bias_idx;
|
||||
|
||||
output_data[i] = input_data[src_offset];
|
||||
}
|
||||
}
|
||||
|
||||
std::vector<paddle::Tensor> RebuildPaddingCPU(
|
||||
const paddle::Tensor &tmp_out,
|
||||
const paddle::Tensor &cu_seqlens_q,
|
||||
const paddle::Tensor &cum_offsets,
|
||||
const paddle::Tensor &seq_len_this_time,
|
||||
const paddle::Tensor &seq_lens_decoder,
|
||||
const paddle::Tensor &seq_lens_encoder,
|
||||
const paddle::optional<paddle::Tensor> &output_padding_offset,
|
||||
int max_input_length) {
|
||||
auto tmp_out_cpu = tmp_out.copy_to(paddle::CPUPlace(), true);
|
||||
auto cu_seqlens_q_cpu = cu_seqlens_q.copy_to(paddle::CPUPlace(), true);
|
||||
auto cum_offsets_cpu = cum_offsets.copy_to(paddle::CPUPlace(), true);
|
||||
auto seq_len_this_time_cpu =
|
||||
seq_len_this_time.copy_to(paddle::CPUPlace(), true);
|
||||
auto seq_lens_decoder_cpu =
|
||||
@@ -111,7 +107,7 @@ std::vector<paddle::Tensor> RebuildPaddingCPU(
|
||||
|
||||
int token_num = tmp_out_cpu.shape()[0];
|
||||
int dim_embed = tmp_out_cpu.shape()[1];
|
||||
int bsz = cu_seqlens_q_cpu.shape()[0] - 1;
|
||||
int bsz = cum_offsets_cpu.shape()[0];
|
||||
|
||||
paddle::Tensor out;
|
||||
if (output_padding_offset_cpu) {
|
||||
@@ -132,7 +128,7 @@ std::vector<paddle::Tensor> RebuildPaddingCPU(
|
||||
{bsz, dim_embed}, 0, tmp_out_cpu.dtype(), paddle::CPUPlace());
|
||||
}
|
||||
|
||||
const int *cu_seqlens_q_data = cu_seqlens_q_cpu.data<int>();
|
||||
const int *cum_offsets_data = cum_offsets_cpu.data<int>();
|
||||
const int *seq_len_this_time_data = seq_len_this_time_cpu.data<int>();
|
||||
const int *seq_lens_decoder_data = seq_lens_decoder_cpu.data<int>();
|
||||
const int *seq_lens_encoder_data = seq_lens_encoder_cpu.data<int>();
|
||||
@@ -145,7 +141,7 @@ std::vector<paddle::Tensor> RebuildPaddingCPU(
|
||||
case paddle::DataType::FLOAT32:
|
||||
RebuildAppendPaddingCPUImpl<float>(out.data<float>(),
|
||||
tmp_out_cpu.data<float>(),
|
||||
cu_seqlens_q_data,
|
||||
cum_offsets_data,
|
||||
seq_len_this_time_data,
|
||||
seq_lens_decoder_data,
|
||||
seq_lens_encoder_data,
|
||||
@@ -158,7 +154,7 @@ std::vector<paddle::Tensor> RebuildPaddingCPU(
|
||||
RebuildAppendPaddingCPUImpl<paddle::float16>(
|
||||
out.data<paddle::float16>(),
|
||||
tmp_out_cpu.data<paddle::float16>(),
|
||||
cu_seqlens_q_data,
|
||||
cum_offsets_data,
|
||||
seq_len_this_time_data,
|
||||
seq_lens_decoder_data,
|
||||
seq_lens_encoder_data,
|
||||
@@ -171,7 +167,7 @@ std::vector<paddle::Tensor> RebuildPaddingCPU(
|
||||
RebuildAppendPaddingCPUImpl<paddle::bfloat16>(
|
||||
out.data<paddle::bfloat16>(),
|
||||
tmp_out_cpu.data<paddle::bfloat16>(),
|
||||
cu_seqlens_q_data,
|
||||
cum_offsets_data,
|
||||
seq_len_this_time_data,
|
||||
seq_lens_decoder_data,
|
||||
seq_lens_encoder_data,
|
||||
@@ -190,7 +186,7 @@ std::vector<paddle::Tensor> RebuildPaddingCPU(
|
||||
case paddle::DataType::FLOAT32:
|
||||
RebuildPaddingCPUImpl<float>(out.data<float>(),
|
||||
tmp_out_cpu.data<float>(),
|
||||
cu_seqlens_q_data,
|
||||
cum_offsets_data,
|
||||
seq_len_this_time_data,
|
||||
seq_lens_decoder_data,
|
||||
seq_lens_encoder_data,
|
||||
@@ -202,7 +198,7 @@ std::vector<paddle::Tensor> RebuildPaddingCPU(
|
||||
RebuildPaddingCPUImpl<paddle::float16>(
|
||||
out.data<paddle::float16>(),
|
||||
tmp_out_cpu.data<paddle::float16>(),
|
||||
cu_seqlens_q_data,
|
||||
cum_offsets_data,
|
||||
seq_len_this_time_data,
|
||||
seq_lens_decoder_data,
|
||||
seq_lens_encoder_data,
|
||||
@@ -211,10 +207,11 @@ std::vector<paddle::Tensor> RebuildPaddingCPU(
|
||||
elem_nums);
|
||||
break;
|
||||
case paddle::DataType::BFLOAT16:
|
||||
|
||||
RebuildPaddingCPUImpl<paddle::bfloat16>(
|
||||
out.data<paddle::bfloat16>(),
|
||||
tmp_out_cpu.data<paddle::bfloat16>(),
|
||||
cu_seqlens_q_data,
|
||||
cum_offsets_data,
|
||||
seq_len_this_time_data,
|
||||
seq_lens_decoder_data,
|
||||
seq_lens_encoder_data,
|
||||
@@ -233,7 +230,7 @@ std::vector<paddle::Tensor> RebuildPaddingCPU(
|
||||
|
||||
std::vector<std::vector<int64_t>> RebuildPaddingInferShape(
|
||||
const std::vector<int64_t> &tmp_out_shape,
|
||||
const std::vector<int64_t> &cu_seqlens_q_shape,
|
||||
const std::vector<int64_t> &cum_offsets_shape,
|
||||
const std::vector<int64_t> &seq_len_this_time_shape,
|
||||
const std::vector<int64_t> &seq_lens_decoder_shape,
|
||||
const std::vector<int64_t> &seq_lens_encoder_shape,
|
||||
@@ -242,14 +239,14 @@ std::vector<std::vector<int64_t>> RebuildPaddingInferShape(
|
||||
if (output_padding_offset_shape) {
|
||||
return {{-1, dim_embed}};
|
||||
} else {
|
||||
int64_t bsz = cu_seqlens_q_shape[0] - 1;
|
||||
int64_t bsz = cum_offsets_shape[0];
|
||||
return {{bsz, dim_embed}};
|
||||
}
|
||||
}
|
||||
|
||||
std::vector<paddle::DataType> RebuildPaddingInferDtype(
|
||||
const paddle::DataType &tmp_out_dtype,
|
||||
const paddle::DataType &cu_seqlens_q_dtype,
|
||||
const paddle::DataType &cum_offsets_dtype,
|
||||
const paddle::DataType &seq_len_this_time_dtype,
|
||||
const paddle::DataType &seq_lens_decoder_dtype,
|
||||
const paddle::DataType &seq_lens_encoder_dtype,
|
||||
@@ -259,7 +256,7 @@ std::vector<paddle::DataType> RebuildPaddingInferDtype(
|
||||
|
||||
PD_BUILD_STATIC_OP(rebuild_padding_cpu)
|
||||
.Inputs({"tmp_out",
|
||||
"cu_seqlens_q",
|
||||
"cum_offsets",
|
||||
"seq_len_this_time",
|
||||
"seq_lens_decoder",
|
||||
"seq_lens_encoder",
|
||||
|
@@ -14,7 +14,7 @@
|
||||
|
||||
#include "paddle/extension.h"
|
||||
|
||||
void set_value_by_flags_and_idx(const bool *stop_flags,
|
||||
void set_value_by_flag_and_id(const bool *stop_flags,
|
||||
int64_t *pre_ids_all,
|
||||
const int64_t *input_ids,
|
||||
const int *seq_lens_encoder,
|
||||
@@ -50,7 +50,7 @@ void SetValueByFlagsAndIdx(const paddle::Tensor &pre_ids_all,
|
||||
int length = pre_ids_all_shape[1];
|
||||
int length_input_ids = input_ids.shape()[1];
|
||||
|
||||
set_value_by_flags_and_idx(stop_flags.data<bool>(),
|
||||
set_value_by_flag_and_id(stop_flags.data<bool>(),
|
||||
const_cast<int64_t *>(pre_ids_all.data<int64_t>()),
|
||||
input_ids.data<int64_t>(),
|
||||
seq_lens_encoder.data<int>(),
|
||||
|
@@ -46,7 +46,7 @@ void update_inputs_kernel(bool *not_need_stop,
|
||||
not_need_stop[0] = stop_sum < stop_nums[0];
|
||||
}
|
||||
|
||||
void UpdateInputs(const paddle::Tensor &stop_flags,
|
||||
void UpdateInputes(const paddle::Tensor &stop_flags,
|
||||
const paddle::Tensor ¬_need_stop,
|
||||
const paddle::Tensor &seq_lens_this_time,
|
||||
const paddle::Tensor &seq_lens_encoder,
|
||||
@@ -90,4 +90,4 @@ PD_BUILD_STATIC_OP(update_inputs_cpu)
|
||||
{"seq_lens_encoder", "seq_lens_encoder_out"},
|
||||
{"seq_lens_decoder", "seq_lens_decoder_out"},
|
||||
{"input_ids", "input_ids_out"}})
|
||||
.SetKernelFn(PD_KERNEL(UpdateInputs));
|
||||
.SetKernelFn(PD_KERNEL(UpdateInputes));
|
||||
|
@@ -38,7 +38,7 @@ class type2value<phi::dtype::float16> {
|
||||
|
||||
|
||||
template <paddle::DataType D>
|
||||
void AppendAttentionKernel(
|
||||
std::vector<paddle::Tensor> AppendAttentionKernel(
|
||||
const AppendAttnMetaData& meta_data,
|
||||
const paddle::Tensor& qkv,
|
||||
const paddle::Tensor& key_cache,
|
||||
@@ -60,7 +60,6 @@ void AppendAttentionKernel(
|
||||
const paddle::Tensor& decoder_num_blocks,
|
||||
const paddle::Tensor& set_max_lengths,
|
||||
const paddle::Tensor& max_len_kv,
|
||||
paddle::Tensor& fmha_out,
|
||||
const paddle::optional<paddle::Tensor>& rotary_embs,
|
||||
const paddle::optional<paddle::Tensor>& attn_mask,
|
||||
const paddle::optional<paddle::Tensor>& qkv_bias,
|
||||
@@ -73,11 +72,7 @@ void AppendAttentionKernel(
|
||||
const paddle::optional<paddle::Tensor>& cache_v_zp,
|
||||
const paddle::optional<paddle::Tensor>& out_linear_shifts,
|
||||
const paddle::optional<paddle::Tensor>& out_linear_smooths,
|
||||
const paddle::optional<paddle::Tensor>& mask_offset,
|
||||
const paddle::optional<paddle::Tensor>& kv_signal_data,
|
||||
const paddle::optional<paddle::Tensor>& q_norm_weight,
|
||||
const paddle::optional<paddle::Tensor>& k_norm_weight,
|
||||
const float rms_norm_eps,
|
||||
const std::string& cache_quant_type_str,
|
||||
const bool use_neox_rotary_style,
|
||||
const bool rope_3d,
|
||||
@@ -123,6 +118,27 @@ void AppendAttentionKernel(
|
||||
} else {
|
||||
qkv_out = qkv;
|
||||
}
|
||||
paddle::Tensor fmha_out;
|
||||
if (out_linear_in_scale > 0.0) {
|
||||
if (fabs(quant_max_bound - 127.0f) < 0.000001) {
|
||||
fmha_out = GetEmptyTensor(
|
||||
{meta_data.token_nums, meta_data.q_num_heads * meta_data.head_dims},
|
||||
paddle::DataType::INT8,
|
||||
qkv.place());
|
||||
} else if (fabs(quant_max_bound - 448.0f) < 0.000001) {
|
||||
fmha_out = GetEmptyTensor(
|
||||
{meta_data.token_nums, meta_data.q_num_heads * meta_data.head_dims},
|
||||
paddle::DataType::FLOAT8_E4M3FN,
|
||||
qkv.place());
|
||||
}else{
|
||||
PD_THROW("Only supported attr of quant_max_bound in ['127', '448'].");
|
||||
}
|
||||
} else {
|
||||
fmha_out = GetEmptyTensor(
|
||||
{meta_data.token_nums, meta_data.q_num_heads * meta_data.head_dims},
|
||||
D,
|
||||
qkv.place());
|
||||
}
|
||||
|
||||
auto dispatch_CascadeAppendAttentionKernel = [&](auto temp_args,
|
||||
const paddle::Tensor& lambda_batch_ids,
|
||||
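The `fmha_out` allocation above keys the output dtype off `quant_max_bound`: 127 is the int8 maximum and 448 the float8 E4M3 maximum. A compact sketch of that selection rule (Python is used here only for illustration):

```python
# Illustrative restatement of the fmha_out dtype selection shown above.
def fmha_out_dtype(out_linear_in_scale: float, quant_max_bound: float, compute_dtype: str) -> str:
    if out_linear_in_scale > 0.0:
        if abs(quant_max_bound - 127.0) < 1e-6:
            return "int8"            # INT8 quantized output
        if abs(quant_max_bound - 448.0) < 1e-6:
            return "float8_e4m3fn"   # FP8 (E4M3) quantized output
        raise ValueError("Only supported attr of quant_max_bound in ['127', '448'].")
    return compute_dtype             # unquantized: bf16 / fp16 template dtype

print(fmha_out_dtype(1.0, 448.0, "bf16"))  # float8_e4m3fn
```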
@@ -140,8 +156,8 @@ void AppendAttentionKernel(
|
||||
key_cache,
|
||||
value_cache,
|
||||
attn_mask,
|
||||
cache_quant_type_str == "block_wise_fp8" ? cache_k_quant_scales : cache_k_dequant_scales,
|
||||
cache_quant_type_str == "block_wise_fp8" ? cache_v_quant_scales : cache_v_dequant_scales,
|
||||
cache_k_dequant_scales,
|
||||
cache_v_dequant_scales,
|
||||
cache_k_zp,
|
||||
cache_v_zp,
|
||||
out_linear_shifts,
|
||||
@@ -207,10 +223,7 @@ void AppendAttentionKernel(
|
||||
main_stream,
|
||||
&qkv_out,
|
||||
const_cast<paddle::Tensor*>(&key_cache),
|
||||
const_cast<paddle::Tensor*>(&value_cache),
|
||||
q_norm_weight,
|
||||
k_norm_weight,
|
||||
rms_norm_eps);
|
||||
const_cast<paddle::Tensor*>(&value_cache));
|
||||
};
|
||||
|
||||
if (qkv_out_scales) {
|
||||
@@ -257,6 +270,54 @@ void AppendAttentionKernel(
|
||||
if (speculate_decoder) {
|
||||
if (qkv_out_scales) {
|
||||
SpeculateWriteCacheWithRoPEKernel<data_t, int>(
|
||||
meta_data,
|
||||
qkv, // [token_num, num_heads, head_dim]
|
||||
seq_lens_decoder,
|
||||
seq_lens_encoder,
|
||||
batch_id_per_token,
|
||||
cu_seqlens_q,
|
||||
block_tables,
|
||||
rotary_embs,
|
||||
qkv_out_scales,
|
||||
qkv_bias,
|
||||
cache_k_quant_scales,
|
||||
cache_v_quant_scales,
|
||||
cache_k_zp,
|
||||
cache_v_zp,
|
||||
cache_quant_type_str,
|
||||
use_neox_rotary_style,
|
||||
max_input_length,
|
||||
exec_stream,
|
||||
&qkv_out,
|
||||
const_cast<paddle::Tensor*>(&key_cache),
|
||||
const_cast<paddle::Tensor*>(&value_cache));
|
||||
} else {
|
||||
SpeculateWriteCacheWithRoPEKernel<data_t, data_t>(
|
||||
meta_data,
|
||||
qkv_out, // [token_num, num_heads, head_dim]
|
||||
seq_lens_decoder,
|
||||
seq_lens_encoder,
|
||||
batch_id_per_token,
|
||||
cu_seqlens_q,
|
||||
block_tables,
|
||||
rotary_embs,
|
||||
qkv_out_scales,
|
||||
qkv_bias,
|
||||
cache_k_quant_scales,
|
||||
cache_v_quant_scales,
|
||||
cache_k_zp,
|
||||
cache_v_zp,
|
||||
cache_quant_type_str,
|
||||
use_neox_rotary_style,
|
||||
max_input_length,
|
||||
exec_stream,
|
||||
&qkv_out,
|
||||
const_cast<paddle::Tensor*>(&key_cache),
|
||||
const_cast<paddle::Tensor*>(&value_cache));
|
||||
}
|
||||
} else {
|
||||
if (qkv_out_scales) {
|
||||
DecoderWriteCacheWithRoPEKernel<data_t, int>(
|
||||
meta_data,
|
||||
qkv, // [token_num, num_heads, head_dim]
|
||||
seq_lens_decoder,
|
||||
@@ -278,12 +339,9 @@ void AppendAttentionKernel(
|
||||
exec_stream,
|
||||
&qkv_out,
|
||||
const_cast<paddle::Tensor*>(&key_cache),
|
||||
const_cast<paddle::Tensor*>(&value_cache),
|
||||
q_norm_weight,
|
||||
k_norm_weight,
|
||||
rms_norm_eps);
|
||||
const_cast<paddle::Tensor*>(&value_cache));
|
||||
} else {
|
||||
SpeculateWriteCacheWithRoPEKernel<data_t, data_t>(
|
||||
DecoderWriteCacheWithRoPEKernel<data_t, data_t>(
|
||||
meta_data,
|
||||
qkv_out, // [token_num, num_heads, head_dim]
|
||||
seq_lens_decoder,
|
||||
@@ -305,64 +363,7 @@ void AppendAttentionKernel(
|
||||
exec_stream,
|
||||
&qkv_out,
|
||||
const_cast<paddle::Tensor*>(&key_cache),
|
||||
const_cast<paddle::Tensor*>(&value_cache),
|
||||
q_norm_weight,
|
||||
k_norm_weight,
|
||||
rms_norm_eps);
|
||||
}
|
||||
} else {
|
||||
if (qkv_out_scales) {
|
||||
DecoderWriteCacheWithRoPEKernel<data_t, int>(
|
||||
meta_data,
|
||||
qkv, // [token_num, num_heads, head_dim]
|
||||
seq_lens_decoder,
|
||||
seq_lens_encoder,
|
||||
cu_seqlens_q,
|
||||
block_tables,
|
||||
rotary_embs,
|
||||
qkv_out_scales,
|
||||
qkv_bias,
|
||||
cache_k_quant_scales,
|
||||
cache_v_quant_scales,
|
||||
cache_k_zp,
|
||||
cache_v_zp,
|
||||
cache_quant_type_str,
|
||||
use_neox_rotary_style,
|
||||
rope_3d,
|
||||
max_input_length,
|
||||
exec_stream,
|
||||
&qkv_out,
|
||||
const_cast<paddle::Tensor*>(&key_cache),
|
||||
const_cast<paddle::Tensor*>(&value_cache),
|
||||
q_norm_weight,
|
||||
k_norm_weight,
|
||||
rms_norm_eps);
|
||||
} else {
|
||||
DecoderWriteCacheWithRoPEKernel<data_t, data_t>(
|
||||
meta_data,
|
||||
qkv_out, // [token_num, num_heads, head_dim]
|
||||
seq_lens_decoder,
|
||||
seq_lens_encoder,
|
||||
cu_seqlens_q,
|
||||
block_tables,
|
||||
rotary_embs,
|
||||
qkv_out_scales,
|
||||
qkv_bias,
|
||||
cache_k_quant_scales,
|
||||
cache_v_quant_scales,
|
||||
cache_k_zp,
|
||||
cache_v_zp,
|
||||
cache_quant_type_str,
|
||||
use_neox_rotary_style,
|
||||
rope_3d,
|
||||
max_input_length,
|
||||
exec_stream,
|
||||
&qkv_out,
|
||||
const_cast<paddle::Tensor*>(&key_cache),
|
||||
const_cast<paddle::Tensor*>(&value_cache),
|
||||
q_norm_weight,
|
||||
k_norm_weight,
|
||||
rms_norm_eps);
|
||||
const_cast<paddle::Tensor*>(&value_cache));
|
||||
}
|
||||
}
|
||||
|
||||
@@ -391,6 +392,8 @@ void AppendAttentionKernel(
|
||||
cudaStreamWaitEvent(main_stream, decoder_event);
|
||||
}
|
||||
}
|
||||
|
||||
return {fmha_out, qkv_out};
|
||||
}
|
||||
|
||||
std::vector<paddle::Tensor> AppendAttention(
|
||||
@@ -426,11 +429,7 @@ std::vector<paddle::Tensor> AppendAttention(
|
||||
const paddle::optional<paddle::Tensor>& cache_v_zp,
|
||||
const paddle::optional<paddle::Tensor>& out_linear_shifts,
|
||||
const paddle::optional<paddle::Tensor>& out_linear_smooths,
|
||||
const paddle::optional<paddle::Tensor>& mask_offset,
|
||||
const paddle::optional<paddle::Tensor>& kv_signal_data,
|
||||
const paddle::optional<paddle::Tensor>& q_norm_weight,
|
||||
const paddle::optional<paddle::Tensor>& k_norm_weight,
|
||||
const float rms_norm_eps,
|
||||
const std::string& compute_dtype,
|
||||
const std::string& cache_quant_type_str,
|
||||
const bool use_neox_rotary_style,
|
||||
@@ -465,60 +464,8 @@ std::vector<paddle::Tensor> AppendAttention(
|
||||
meta_data.block_size = key_cache.dims()[2];
|
||||
meta_data.batch_size = seq_lens_this_time.dims()[0];
|
||||
|
||||
// template dtype generation
|
||||
phi::DataType dtype_id;
|
||||
switch (qkv.dtype()) {
|
||||
case paddle::DataType::FLOAT16: {dtype_id = phi::DataType::FLOAT16; break;}
|
||||
case paddle::DataType::BFLOAT16: {dtype_id = phi::DataType::BFLOAT16; break;}
|
||||
case paddle::DataType::INT32: {
|
||||
if (compute_dtype == "bf16") {
|
||||
dtype_id = phi::DataType::BFLOAT16;
|
||||
break;
|
||||
} else if (compute_dtype == "fp16") {
|
||||
dtype_id = phi::DataType::FLOAT16;
|
||||
break;
|
||||
} else {
|
||||
PD_THROW("Only supported attr of compute_dtype in ['fp16', 'bf16'].");
|
||||
break;
|
||||
}
|
||||
}
|
||||
default: {
|
||||
PD_THROW(
|
||||
"NOT supported data type. "
|
||||
"Only float16 and bfloat16 are supported. ");
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
// fmha_out generation, rewrite from AppendAttentionKernel
|
||||
paddle::Tensor fmha_out;
|
||||
if (out_linear_in_scale > 0.0) {
|
||||
if (fabs(quant_max_bound - 127.0f) < 0.000001) {
|
||||
fmha_out = GetEmptyTensor(
|
||||
{meta_data.token_nums, meta_data.q_num_heads * meta_data.head_dims},
|
||||
paddle::DataType::INT8,
|
||||
qkv.place());
|
||||
} else if (fabs(quant_max_bound - 448.0f) < 0.000001) {
|
||||
fmha_out = GetEmptyTensor(
|
||||
{meta_data.token_nums, meta_data.q_num_heads * meta_data.head_dims},
|
||||
paddle::DataType::FLOAT8_E4M3FN,
|
||||
qkv.place());
|
||||
} else{
|
||||
PD_THROW("Only supported attr of quant_max_bound in ['127', '448'].");
|
||||
}
|
||||
} else {
|
||||
fmha_out = GetEmptyTensor(
|
||||
{meta_data.token_nums, meta_data.q_num_heads * meta_data.head_dims},
|
||||
dtype_id,
|
||||
qkv.place());
|
||||
}
|
||||
|
||||
if (mask_offset) {
|
||||
meta_data.mask_offset = mask_offset.get().data<int>();
|
||||
}
|
||||
|
||||
auto dispatch_by_template = [&](auto temp_args) -> void {
|
||||
AppendAttentionKernel<type2value<decltype(temp_args)>::value>(
|
||||
auto dispatch_by_template = [&](auto temp_args) -> std::vector<paddle::Tensor> {
|
||||
return AppendAttentionKernel<type2value<decltype(temp_args)>::value>(
|
||||
meta_data,
|
||||
qkv,
|
||||
key_cache,
|
||||
@@ -540,7 +487,6 @@ std::vector<paddle::Tensor> AppendAttention(
|
||||
decoder_num_blocks,
|
||||
set_max_lengths,
|
||||
max_len_kv,
|
||||
fmha_out,
|
||||
rotary_embs,
|
||||
attn_mask,
|
||||
qkv_bias,
|
||||
@@ -553,11 +499,7 @@ std::vector<paddle::Tensor> AppendAttention(
|
||||
cache_v_zp,
|
||||
out_linear_shifts,
|
||||
out_linear_smooths,
|
||||
mask_offset,
|
||||
kv_signal_data,
|
||||
q_norm_weight,
|
||||
k_norm_weight,
|
||||
rms_norm_eps,
|
||||
cache_quant_type_str,
|
||||
use_neox_rotary_style,
|
||||
rope_3d,
|
||||
@@ -572,198 +514,35 @@ std::vector<paddle::Tensor> AppendAttention(
|
||||
speculate_max_draft_token_num,
|
||||
causal,
|
||||
speculate_decoder);
|
||||
};
|
||||
|
||||
};
|
||||
|
||||
phi::dtype::float16 fp16_dtype;
|
||||
phi::dtype::bfloat16 bp16_dtype;
|
||||
switch (dtype_id){
|
||||
case phi::DataType::FLOAT16: {
|
||||
dispatch_by_template(fp16_dtype);
|
||||
return {fmha_out};
|
||||
}
|
||||
case phi::DataType::BFLOAT16: {
|
||||
dispatch_by_template(bp16_dtype);
|
||||
return {fmha_out};
|
||||
}
|
||||
default:
|
||||
PD_THROW(
|
||||
|
||||
switch (qkv.dtype()) {
|
||||
case paddle::DataType::FLOAT16: return dispatch_by_template(fp16_dtype);
|
||||
case paddle::DataType::BFLOAT16: return dispatch_by_template(bp16_dtype);
|
||||
case paddle::DataType::INT32: {
|
||||
if (compute_dtype == "bf16") {
|
||||
return dispatch_by_template(bp16_dtype);
|
||||
} else if (compute_dtype == "fp16") {
|
||||
return dispatch_by_template(fp16_dtype);
|
||||
} else {
|
||||
PD_THROW("Only supported attr of compute_dtype in ['fp16', 'bf16'].");
|
||||
break;
|
||||
}
|
||||
}
|
||||
default: {
|
||||
PD_THROW(
|
||||
"NOT supported data type. "
|
||||
"Only float16 and bfloat16 are supported. ");
|
||||
break;
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
return {paddle::Tensor{}};
|
||||
}
|
||||
|
||||
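`AppendAttention` above dispatches the kernel template from the `qkv` dtype, falling back to the `compute_dtype` attribute when `qkv` arrives pre-quantized as INT32. A short sketch of that rule (illustrative only):

```python
# Illustrative restatement of the dtype dispatch used by AppendAttention above.
def pick_template_dtype(qkv_dtype: str, compute_dtype: str) -> str:
    if qkv_dtype in ("float16", "bfloat16"):
        return qkv_dtype
    if qkv_dtype == "int32":                      # pre-quantized input
        mapping = {"fp16": "float16", "bf16": "bfloat16"}
        if compute_dtype in mapping:
            return mapping[compute_dtype]
        raise ValueError("Only supported attr of compute_dtype in ['fp16', 'bf16'].")
    raise ValueError("NOT supported data type. Only float16 and bfloat16 are supported.")

print(pick_template_dtype("int32", "bf16"))       # bfloat16
```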
void AppendAttentionWithOutput(
|
||||
const paddle::Tensor& qkv,
|
||||
const paddle::Tensor& key_cache,
|
||||
const paddle::Tensor& value_cache,
|
||||
const paddle::Tensor& seq_lens_encoder,
|
||||
const paddle::Tensor& seq_lens_decoder,
|
||||
const paddle::Tensor& seq_lens_this_time,
|
||||
const paddle::Tensor& batch_id_per_token,
|
||||
const paddle::Tensor& cu_seqlens_q,
|
||||
const paddle::Tensor& block_tables,
|
||||
const paddle::Tensor& encoder_batch_ids,
|
||||
const paddle::Tensor& encoder_tile_ids_per_batch,
|
||||
const paddle::Tensor& encoder_num_blocks,
|
||||
const paddle::Tensor& kv_batch_ids,
|
||||
const paddle::Tensor& kv_tile_ids_per_batch,
|
||||
const paddle::Tensor& kv_num_blocks,
|
||||
const paddle::Tensor& decoder_batch_ids,
|
||||
const paddle::Tensor& decoder_tile_ids_per_batch,
|
||||
const paddle::Tensor& decoder_num_blocks,
|
||||
const paddle::Tensor& set_max_lengths,
|
||||
const paddle::Tensor& max_len_kv,
|
||||
paddle::Tensor& fmha_out,
|
||||
const paddle::optional<paddle::Tensor>& rotary_embs,
|
||||
const paddle::optional<paddle::Tensor>& attn_mask,
|
||||
const paddle::optional<paddle::Tensor>& qkv_bias,
|
||||
const paddle::optional<paddle::Tensor>& qkv_out_scales,
|
||||
const paddle::optional<paddle::Tensor>& cache_k_quant_scales,
|
||||
const paddle::optional<paddle::Tensor>& cache_v_quant_scales,
|
||||
const paddle::optional<paddle::Tensor>& cache_k_dequant_scales,
|
||||
const paddle::optional<paddle::Tensor>& cache_v_dequant_scales,
|
||||
const paddle::optional<paddle::Tensor>& cache_k_zp,
|
||||
const paddle::optional<paddle::Tensor>& cache_v_zp,
|
||||
const paddle::optional<paddle::Tensor>& out_linear_shifts,
|
||||
const paddle::optional<paddle::Tensor>& out_linear_smooths,
|
||||
const paddle::optional<paddle::Tensor>& mask_offset,
|
||||
const paddle::optional<paddle::Tensor>& kv_signal_data,
|
||||
const paddle::optional<paddle::Tensor>& q_norm_weight,
|
||||
const paddle::optional<paddle::Tensor>& k_norm_weight,
|
||||
const float rms_norm_eps,
|
||||
const std::string& compute_dtype,
|
||||
const std::string& cache_quant_type_str,
|
||||
const bool use_neox_rotary_style,
|
||||
const bool rope_3d,
|
||||
const int max_input_length,
|
||||
const float quant_max_bound,
|
||||
const float quant_min_bound,
|
||||
const float out_linear_in_scale,
|
||||
const int encoder_block_shape_q,
|
||||
const int decoder_block_shape_q,
|
||||
const int max_partition_size,
|
||||
const int encoder_max_partition_size,
|
||||
const int speculate_max_draft_token_num,
|
||||
const bool causal,
|
||||
const bool speculate_decoder) {
|
||||
AppendAttnMetaData meta_data;
|
||||
|
||||
const auto& qkv_dims = qkv.dims();
|
||||
const auto& key_cache_dims = key_cache.dims();
|
||||
meta_data.token_nums = qkv_dims[0];
|
||||
meta_data.kv_num_heads = key_cache_dims[1];
|
||||
meta_data.head_dims = key_cache_dims[3];
|
||||
// TODO: trick method support c4, add attr head_dims in the future
|
||||
if (cache_quant_type_str == "cache_int4_zp") {
|
||||
meta_data.head_dims *= 2;
|
||||
}
|
||||
const int total_num_head =
|
||||
qkv_dims[qkv_dims.size() - 1] / meta_data.head_dims;
|
||||
meta_data.q_num_heads = total_num_head - 2 * meta_data.kv_num_heads;
|
||||
|
||||
meta_data.max_blocks_per_seq = block_tables.dims()[1];
|
||||
meta_data.block_size = key_cache.dims()[2];
|
||||
meta_data.batch_size = seq_lens_this_time.dims()[0];
|
||||
|
||||
if (mask_offset) {
|
||||
meta_data.mask_offset = mask_offset.get().data<int>();
|
||||
}
|
||||
|
||||
auto dispatch_by_template = [&](auto temp_args) -> void {
|
||||
AppendAttentionKernel<type2value<decltype(temp_args)>::value>(
|
||||
meta_data,
|
||||
qkv,
|
||||
key_cache,
|
||||
value_cache,
|
||||
seq_lens_encoder,
|
||||
seq_lens_decoder,
|
||||
seq_lens_this_time,
|
||||
batch_id_per_token,
|
||||
cu_seqlens_q,
|
||||
block_tables,
|
||||
encoder_batch_ids,
|
||||
encoder_tile_ids_per_batch,
|
||||
encoder_num_blocks,
|
||||
kv_batch_ids,
|
||||
kv_tile_ids_per_batch,
|
||||
kv_num_blocks,
|
||||
decoder_batch_ids,
|
||||
decoder_tile_ids_per_batch,
|
||||
decoder_num_blocks,
|
||||
set_max_lengths,
|
||||
max_len_kv,
|
||||
fmha_out,
|
||||
rotary_embs,
|
||||
attn_mask,
|
||||
qkv_bias,
|
||||
qkv_out_scales,
|
||||
cache_k_quant_scales,
|
||||
cache_v_quant_scales,
|
||||
cache_k_dequant_scales,
|
||||
cache_v_dequant_scales,
|
||||
cache_k_zp,
|
||||
cache_v_zp,
|
||||
out_linear_shifts,
|
||||
out_linear_smooths,
|
||||
mask_offset,
|
||||
kv_signal_data,
|
||||
q_norm_weight,
|
||||
k_norm_weight,
|
||||
rms_norm_eps,
|
||||
cache_quant_type_str,
|
||||
use_neox_rotary_style,
|
||||
rope_3d,
|
||||
max_input_length,
|
||||
quant_max_bound,
|
||||
quant_min_bound,
|
||||
out_linear_in_scale,
|
||||
encoder_block_shape_q,
|
||||
decoder_block_shape_q,
|
||||
max_partition_size,
|
||||
encoder_max_partition_size,
|
||||
speculate_max_draft_token_num,
|
||||
causal,
|
||||
speculate_decoder);
|
||||
};
|
||||
|
||||
phi::dtype::float16 fp16_dtype;
|
||||
phi::dtype::bfloat16 bp16_dtype;
|
||||
|
||||
switch (qkv.dtype()) {
|
||||
case paddle::DataType::FLOAT16: {
|
||||
dispatch_by_template(fp16_dtype);
|
||||
break;
|
||||
}
|
||||
case paddle::DataType::BFLOAT16: {
|
||||
dispatch_by_template(bp16_dtype);
|
||||
break;
|
||||
}
|
||||
case paddle::DataType::INT32: {
|
||||
if (compute_dtype == "bf16") {
|
||||
dispatch_by_template(bp16_dtype);
|
||||
break;
|
||||
} else if (compute_dtype == "fp16") {
|
||||
dispatch_by_template(fp16_dtype);
|
||||
break;
|
||||
} else {
|
||||
PD_THROW("Only supported attr of compute_dtype in ['fp16', 'bf16'].");
|
||||
break;
|
||||
}
|
||||
}
|
||||
default: {
|
||||
PD_THROW(
|
||||
"NOT supported data type. "
|
||||
"Only float16 and bfloat16 are supported. ");
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
std::vector<std::vector<int64_t>> AppendAttentionInferShape(
|
||||
const std::vector<int64_t>& qkv_shape,
|
||||
const std::vector<int64_t>& key_cache_shape,
|
||||
@@ -797,11 +576,7 @@ std::vector<std::vector<int64_t>> AppendAttentionInferShape(
|
||||
const paddle::optional<std::vector<int64_t>>& cache_v_zp_shape,
|
||||
const paddle::optional<std::vector<int64_t>>& out_linear_shifts_shape,
|
||||
const paddle::optional<std::vector<int64_t>>& out_linear_smooths_shape,
|
||||
const paddle::optional<std::vector<int64_t>>& mask_offset_shape,
|
||||
const paddle::optional<std::vector<int64_t>>& kv_signal_data_shape,
|
||||
const paddle::optional<std::vector<int64_t>>& q_norm_weight_shape,
|
||||
const paddle::optional<std::vector<int64_t>>& k_norm_weight_shape,
|
||||
const float rms_norm_eps,
|
||||
const std::string& compute_dtype,
|
||||
const std::string& cache_quant_type_str,
|
||||
const bool use_neox_rotary_style,
|
||||
@@ -825,7 +600,7 @@ std::vector<std::vector<int64_t>> AppendAttentionInferShape(
|
||||
}
|
||||
const int total_num_head = qkv_shape[qkv_shape.size() - 1] / head_dim;
|
||||
const int num_heads = total_num_head - 2 * kv_num_heads;
|
||||
return {{token_num, num_heads * head_dim}};
|
||||
return {{token_num, num_heads * head_dim}, qkv_shape};
|
||||
}
|
||||
|
||||
std::vector<paddle::DataType> AppendAttentionInferDtype(
|
||||
@@ -861,11 +636,7 @@ std::vector<paddle::DataType> AppendAttentionInferDtype(
|
||||
const paddle::optional<paddle::DataType>& cache_v_zp_dtype,
|
||||
const paddle::optional<paddle::DataType>& out_linear_shifts_dtype,
|
||||
const paddle::optional<paddle::DataType>& out_linear_smooths_dtype,
|
||||
const paddle::optional<paddle::DataType>& mask_offset_dtype,
|
||||
const paddle::optional<paddle::DataType>& kv_signal_data_dtype,
|
||||
const paddle::optional<paddle::DataType>& q_norm_weight_dtype,
|
||||
const paddle::optional<paddle::DataType>& k_norm_weight_dtype,
|
||||
const float rms_norm_eps,
|
||||
const std::string& compute_dtype,
|
||||
const std::string& cache_quant_type_str,
|
||||
const bool use_neox_rotary_style,
|
||||
@@ -884,148 +655,32 @@ std::vector<paddle::DataType> AppendAttentionInferDtype(
|
||||
if (compute_dtype == "bf16") {
|
||||
if (out_linear_in_scale > 0.0) {
|
||||
if (fabs(quant_max_bound - 127.0f) < 0.000001) {
|
||||
return {paddle::DataType::INT8};
|
||||
return {paddle::DataType::INT8, paddle::DataType::BFLOAT16};
|
||||
} else if (fabs(quant_max_bound - 448.0f) < 0.000001) {
|
||||
return {paddle::DataType::FLOAT8_E4M3FN};
|
||||
return {paddle::DataType::FLOAT8_E4M3FN, paddle::DataType::BFLOAT16};
|
||||
}else{
|
||||
PD_THROW("Only supported attr of quant_max_bound in ['127.0', '448.0'].");
|
||||
}
|
||||
} else {
|
||||
return {paddle::DataType::BFLOAT16};
|
||||
return {paddle::DataType::BFLOAT16, paddle::DataType::BFLOAT16};
|
||||
}
|
||||
} else if (compute_dtype == "fp16") {
|
||||
if (out_linear_in_scale > 0.0) {
|
||||
if (fabs(quant_max_bound - 127.0f) < 0.000001) {
|
||||
return {paddle::DataType::INT8};
|
||||
return {paddle::DataType::INT8, paddle::DataType::FLOAT16};
|
||||
} else if (fabs(quant_max_bound - 448.0f) < 0.000001) {
|
||||
return {paddle::DataType::FLOAT8_E4M3FN};
|
||||
return {paddle::DataType::FLOAT8_E4M3FN, paddle::DataType::FLOAT16};
|
||||
}else{
|
||||
PD_THROW("Only supported attr of quant_max_bound in ['127.0', '448.0'].");
|
||||
}
|
||||
} else {
|
||||
return {paddle::DataType::FLOAT16};
|
||||
return {paddle::DataType::FLOAT16, paddle::DataType::FLOAT16};
|
||||
}
|
||||
} else {
|
||||
PD_THROW("Only supported attr of compute_dtype in ['fp16', 'bf16'].");
|
||||
}
|
||||
}
|
||||
|
||||
std::vector<std::vector<int64_t>> AppendAttentionWithOutputInferShape(
|
||||
const std::vector<int64_t>& qkv_shape,
|
||||
const std::vector<int64_t>& key_cache_shape,
|
||||
const std::vector<int64_t>& value_cache_shape,
|
||||
const std::vector<int64_t>& seq_lens_encoder_shape,
|
||||
const std::vector<int64_t>& seq_lens_decoder_shape,
|
||||
const std::vector<int64_t>& seq_lens_this_time_shape,
|
||||
const std::vector<int64_t>& batch_id_per_token_shape,
|
||||
const std::vector<int64_t>& cu_seqlens_q_shape,
|
||||
const std::vector<int64_t>& block_tables_shape,
|
||||
const std::vector<int64_t>& encoder_batch_ids_shape,
|
||||
const std::vector<int64_t>& encoder_tile_ids_per_batch_shape,
|
||||
const std::vector<int64_t>& encoder_num_blocks_shape,
|
||||
const std::vector<int64_t>& kv_batch_ids_shape,
|
||||
const std::vector<int64_t>& kv_tile_ids_per_batch_shape,
|
||||
const std::vector<int64_t>& kv_num_blocks_shape,
|
||||
const std::vector<int64_t>& decoder_batch_ids_shape,
|
||||
const std::vector<int64_t>& decoder_tile_ids_per_batch_shape,
|
||||
const std::vector<int64_t>& decoder_num_blocks_shape,
|
||||
const std::vector<int64_t>& set_max_lengths_shape,
|
||||
const std::vector<int64_t>& max_len_kv_shape,
|
||||
const std::vector<int64_t>& fmha_out_shape,
|
||||
const paddle::optional<std::vector<int64_t>>& rotary_embs_shape,
|
||||
const paddle::optional<std::vector<int64_t>>& attn_mask_shape,
|
||||
const paddle::optional<std::vector<int64_t>>& qkv_bias_shape,
|
||||
const paddle::optional<std::vector<int64_t>>& qkv_out_scales_shape,
|
||||
const paddle::optional<std::vector<int64_t>>& cache_k_quant_scales_shape,
|
||||
const paddle::optional<std::vector<int64_t>>& cache_v_quant_scales_shape,
|
||||
const paddle::optional<std::vector<int64_t>>& cache_k_dequant_scales_shape,
|
||||
const paddle::optional<std::vector<int64_t>>& cache_v_dequant_scales_shape,
|
||||
const paddle::optional<std::vector<int64_t>>& cache_k_zp_shape,
|
||||
const paddle::optional<std::vector<int64_t>>& cache_v_zp_shape,
|
||||
const paddle::optional<std::vector<int64_t>>& out_linear_shifts_shape,
|
||||
const paddle::optional<std::vector<int64_t>>& out_linear_smooths_shape,
|
||||
const paddle::optional<std::vector<int64_t>>& mask_offset_shape,
|
||||
const paddle::optional<std::vector<int64_t>>& kv_signal_data_shape,
|
||||
const paddle::optional<std::vector<int64_t>>& q_norm_weight_shape,
|
||||
const paddle::optional<std::vector<int64_t>>& k_norm_weight_shape,
|
||||
const float rms_norm_eps,
|
||||
const std::string& compute_dtype,
|
||||
const std::string& cache_quant_type_str,
|
||||
const bool use_neox_rotary_style,
|
||||
const bool rope_3d,
|
||||
const int max_input_length,
|
||||
const float quant_max_bound,
|
||||
const float quant_min_bound,
|
||||
const float out_linear_in_scale,
|
||||
const int encoder_block_shape_q,
|
||||
const int decoder_block_shape_q,
|
||||
const int max_partition_size,
|
||||
const int encoder_max_partition_size,
|
||||
const int speculate_max_draft_token_num,
|
||||
const bool causal,
|
||||
const bool speculate_decoder) {
|
||||
return {fmha_out_shape};
|
||||
}
|
||||
|
||||
std::vector<paddle::DataType> AppendAttentionWithOutputInferDtype(
|
||||
const paddle::DataType& qkv_dtype,
|
||||
const paddle::DataType& key_cache_dtype,
|
||||
const paddle::DataType& value_cache_dtype,
|
||||
const paddle::DataType& seq_lens_encoder_dtype,
|
||||
const paddle::DataType& seq_lens_decoder_dtype,
|
||||
const paddle::DataType& seq_lens_this_time_dtype,
|
||||
const paddle::DataType& batch_id_per_token_dtype,
|
||||
const paddle::DataType& cu_seqlens_q_dtype,
|
||||
const paddle::DataType& block_tables_dtype,
|
||||
const paddle::DataType& encoder_batch_ids_dtype,
|
||||
const paddle::DataType& encoder_tile_ids_per_batch_dtype,
|
||||
const paddle::DataType& encoder_num_blocks_dtype,
|
||||
const paddle::DataType& kv_batch_ids_dtype,
|
||||
const paddle::DataType& kv_tile_ids_per_batch_dtype,
|
||||
const paddle::DataType& kv_num_blocks_dtype,
|
||||
const paddle::DataType& decoder_batch_ids_dtype,
|
||||
const paddle::DataType& decoder_tile_ids_per_batch_dtype,
|
||||
const paddle::DataType& decoder_num_blocks_dtype,
|
||||
const paddle::DataType& set_max_lengths_dtype,
|
||||
const paddle::DataType& max_len_kv_dtype,
|
||||
const paddle::DataType& fmha_out_dtype,
|
||||
const paddle::optional<paddle::DataType>& rotary_embs_dtype,
|
||||
const paddle::optional<paddle::DataType>& attn_mask_dtype,
|
||||
const paddle::optional<paddle::DataType>& qkv_bias_dtype,
|
||||
const paddle::optional<paddle::DataType>& qkv_out_scales_dtype,
|
||||
const paddle::optional<paddle::DataType>& cache_k_quant_scales_dtype,
|
||||
const paddle::optional<paddle::DataType>& cache_v_quant_scales_dtype,
|
||||
const paddle::optional<paddle::DataType>& cache_k_dequant_scales_dtype,
|
||||
const paddle::optional<paddle::DataType>& cache_v_dequant_scales_dtype,
|
||||
const paddle::optional<paddle::DataType>& cache_k_zp_dtype,
|
||||
const paddle::optional<paddle::DataType>& cache_v_zp_dtype,
|
||||
const paddle::optional<paddle::DataType>& out_linear_shifts_dtype,
|
||||
const paddle::optional<paddle::DataType>& out_linear_smooths_dtype,
|
||||
const paddle::optional<paddle::DataType>& mask_offset_dtype,
|
||||
const paddle::optional<paddle::DataType>& kv_signal_data_dtype,
|
||||
const paddle::optional<paddle::DataType>& q_norm_weight_dtype,
|
||||
const paddle::optional<paddle::DataType>& k_norm_weight_dtype,
|
||||
const float rms_norm_eps,
|
||||
const std::string& compute_dtype,
|
||||
const std::string& cache_quant_type_str,
|
||||
const bool use_neox_rotary_style,
|
||||
const bool rope_3d,
|
||||
const int max_input_length,
|
||||
const float quant_max_bound,
|
||||
const float quant_min_bound,
|
||||
const float out_linear_in_scale,
|
||||
const int encoder_block_shape_q,
|
||||
const int decoder_block_shape_q,
|
||||
const int max_partition_size,
|
||||
const int encoder_max_partition_size,
|
||||
const int speculate_max_draft_token_num,
|
||||
const bool causal,
|
||||
const bool speculate_decoder) {
|
||||
return {fmha_out_dtype};
|
||||
}
|
||||
|
||||
|
||||
|
||||
PD_BUILD_STATIC_OP(append_attention)
|
||||
.Inputs({"qkv",
|
||||
"key_cache",
|
||||
@@ -1059,15 +714,11 @@ PD_BUILD_STATIC_OP(append_attention)
|
||||
paddle::Optional("cache_v_zp"),
|
||||
paddle::Optional("out_linear_shifts"),
|
||||
paddle::Optional("out_linear_smooths"),
|
||||
paddle::Optional("mask_offset"),
|
||||
paddle::Optional("kv_signal_data"),
|
||||
paddle::Optional("q_norm_weight"),
|
||||
paddle::Optional("k_norm_weight")})
|
||||
.Outputs({"fmha_out", "key_cache_out", "value_cache_out"})
|
||||
paddle::Optional("kv_signal_data")})
|
||||
.Outputs({"fmha_out", "qkv_out", "key_cache_out", "value_cache_out"})
|
||||
.SetInplaceMap({{"key_cache", "key_cache_out"},
|
||||
{"value_cache", "value_cache_out"}})
|
||||
.Attrs({"rms_norm_eps: float",
|
||||
"compute_type: std::string",
|
||||
.Attrs({"compute_type: std::string",
|
||||
"cache_quant_type: std::string",
|
||||
"use_neox_rotary_style: bool",
|
||||
"rope_3d: bool",
|
||||
@@ -1081,71 +732,7 @@ PD_BUILD_STATIC_OP(append_attention)
|
||||
"encoder_max_partition_size: int",
|
||||
"speculate_max_draft_token_num: int",
|
||||
"causal: bool",
|
||||
"speculate_decoder: bool",
|
||||
})
|
||||
"speculate_decoder: bool"})
|
||||
.SetKernelFn(PD_KERNEL(AppendAttention))
|
||||
.SetInferShapeFn(PD_INFER_SHAPE(AppendAttentionInferShape))
|
||||
.SetInferDtypeFn(PD_INFER_DTYPE(AppendAttentionInferDtype));
|
||||
|
||||
PD_BUILD_STATIC_OP(append_attention_with_output)
|
||||
.Inputs({"qkv",
|
||||
"key_cache",
|
||||
"value_cache",
|
||||
"seq_lens_encoder",
|
||||
"seq_lens_decoder",
|
||||
"seq_lens_this_time",
|
||||
"batch_id_per_token",
|
||||
"cu_seqlens_q",
|
||||
"block_tables",
|
||||
"encoder_batch_ids",
|
||||
"encoder_tile_ids_per_batch",
|
||||
"encoder_num_blocks",
|
||||
"kv_batch_ids",
|
||||
"kv_tile_ids_per_batch",
|
||||
"kv_num_blocks",
|
||||
"decoder_batch_ids",
|
||||
"decoder_tile_ids_per_batch",
|
||||
"decoder_num_blocks",
|
||||
"set_max_lengths",
|
||||
"max_len_kv",
|
||||
"fmha_out",
|
||||
paddle::Optional("rotary_embs"),
|
||||
paddle::Optional("attn_mask"),
|
||||
paddle::Optional("qkv_bias"),
|
||||
paddle::Optional("qkv_out_scales"),
|
||||
paddle::Optional("cache_k_quant_scales"),
|
||||
paddle::Optional("cache_v_quant_scales"),
|
||||
paddle::Optional("cache_k_dequant_scales"),
|
||||
paddle::Optional("cache_v_dequant_scales"),
|
||||
paddle::Optional("cache_k_zp"),
|
||||
paddle::Optional("cache_v_zp"),
|
||||
paddle::Optional("out_linear_shifts"),
|
||||
paddle::Optional("out_linear_smooths"),
|
||||
paddle::Optional("mask_offset"),
|
||||
paddle::Optional("kv_signal_data"),
|
||||
paddle::Optional("q_norm_weight"),
|
||||
paddle::Optional("k_norm_weight")})
|
||||
.Outputs({"fmha_out_out", "qkv_out", "key_cache_out", "value_cache_out"})
|
||||
.SetInplaceMap({{"fmha_out", "fmha_out_out"},
|
||||
{"key_cache", "key_cache_out"},
|
||||
{"value_cache", "value_cache_out"}})
|
||||
.Attrs({"rms_norm_eps: float",
|
||||
"compute_type: std::string",
|
||||
"cache_quant_type: std::string",
|
||||
"use_neox_rotary_style: bool",
|
||||
"rope_3d: bool",
|
||||
"max_input_length: int",
|
||||
"quant_max_bound: float",
|
||||
"quant_min_bound: float",
|
||||
"out_linear_in_scale: float",
|
||||
"encoder_block_shape_q: int",
|
||||
"decoder_block_shape_q: int",
|
||||
"max_partition_size: int",
|
||||
"encoder_max_partition_size: int",
|
||||
"speculate_max_draft_token_num: int",
|
||||
"causal: bool",
|
||||
"speculate_decoder: bool",
|
||||
})
|
||||
.SetKernelFn(PD_KERNEL(AppendAttentionWithOutput))
|
||||
.SetInferShapeFn(PD_INFER_SHAPE(AppendAttentionWithOutputInferShape))
|
||||
.SetInferDtypeFn(PD_INFER_DTYPE(AppendAttentionWithOutputInferDtype));
|
||||
|
@@ -43,7 +43,6 @@ __global__ void multi_query_append_attention_kernel(
|
||||
const int *__restrict__ tile_ids_per_batch,
|
||||
const int *__restrict__ cu_seqlens_q,
|
||||
const int *__restrict__ block_table, // [bsz, block_num_per_seq]
|
||||
const int *__restrict__ mask_offset,
|
||||
const int max_seq_len,
|
||||
const int max_dec_len,
|
||||
const int max_block_num_per_seq,
|
||||
@@ -52,7 +51,6 @@ __global__ void multi_query_append_attention_kernel(
|
||||
const float quant_min_bound,
|
||||
const float in_scale,
|
||||
const uint32_t chunk_size,
|
||||
const int num_blocks_x_cpu,
|
||||
T *__restrict__ tmp_workspace, // split kv [token_num, num_chunks,
|
||||
// num_heads, head_dim]
|
||||
float *__restrict__ tmp_m, // [token_num, num_chunks, num_heads]
|
||||
@@ -75,11 +73,6 @@ __global__ void multi_query_append_attention_kernel(
|
||||
|
||||
block_table_now = block_table + batch_id * max_block_num_per_seq;
|
||||
|
||||
//When cudagraph capture prefill, may launch more gridDim.x
|
||||
if(btid >= static_cast<uint32_t>(num_blocks_x_cpu)){
|
||||
return;
|
||||
}
|
||||
|
||||
const uint32_t q_len = seq_lens[batch_id];
|
||||
if (q_len <= 0) {
|
||||
return;
|
||||
@@ -148,7 +141,6 @@ __global__ void multi_query_append_attention_kernel(
|
||||
} else {
|
||||
o_base_ptr_int8 = out + o_offset;
|
||||
}
|
||||
const int *mask_offset_this_seq = mask_offset ? mask_offset + q_start_seq_id * 2 : nullptr;
|
||||
smem_t qo_smem(smem);
|
||||
|
||||
uint32_t q_smem_offset_r = smem_t::get_permuted_offset<num_vecs_per_head>(
|
||||
@@ -187,7 +179,7 @@ __global__ void multi_query_append_attention_kernel(
|
||||
kv_len - q_len +
|
||||
tile_id * num_rows_per_block / GROUP_SIZE,
|
||||
chunk_start)))
|
||||
: mask_offset ? 0 : chunk_len) /
|
||||
: chunk_len) /
|
||||
(num_frags_z * 16);
|
||||
uint32_t k_smem_offset_r = smem_t::get_permuted_offset<num_vecs_per_head>(
|
||||
8 * (tid / 16) + tid % 8, (tid % 16) / 8);
|
||||
@@ -253,16 +245,12 @@ __global__ void multi_query_append_attention_kernel(
|
||||
NUM_WARPS,
|
||||
num_frags_x,
|
||||
num_frags_y,
|
||||
num_frags_z>(nullptr,
|
||||
q_base_seq_id_this_block,
|
||||
num_frags_z>(q_base_seq_id_this_block,
|
||||
kv_idx_base,
|
||||
q_len,
|
||||
kv_len,
|
||||
chunk_end,
|
||||
-1,
|
||||
s_frag,
|
||||
mask_offset_this_seq);
|
||||
|
||||
s_frag);
|
||||
}
|
||||
|
||||
// update m,d
|
||||
@@ -418,8 +406,6 @@ __global__ void multi_query_append_attention_warp1_4_kernel(
|
||||
const int *__restrict__ tile_ids_per_batch,
|
||||
const int *__restrict__ cu_seqlens_q,
|
||||
const int *__restrict__ block_table, // [bsz, block_num_per_seq]
|
||||
const int *__restrict__ mask_offset,
|
||||
const bool *__restrict__ attn_mask, // [bsz, max_q, max_q] for tree-mask
|
||||
const int max_seq_len,
|
||||
const int max_dec_len,
|
||||
const int max_block_num_per_seq,
|
||||
@@ -428,14 +414,12 @@ __global__ void multi_query_append_attention_warp1_4_kernel(
|
||||
const float quant_min_bound,
|
||||
const float in_scale,
|
||||
const uint32_t chunk_size,
|
||||
const int num_blocks_x_cpu,
|
||||
T *__restrict__ tmp_workspace, // split kv [token_num, num_chunks,
|
||||
// num_heads, head_dim]
|
||||
float *__restrict__ tmp_m, // [token_num, num_chunks, num_heads]
|
||||
float *__restrict__ tmp_d, // [token_num, num_chunks, num_heads]
|
||||
OutT *__restrict__ out,
|
||||
const int speculate_max_draft_token_num = 5,
|
||||
const int32_t attn_mask_len = -1) {
|
||||
const int speculate_max_draft_token_num = 5) {
|
||||
constexpr uint32_t num_vecs_per_head = HEAD_DIM / num_elems_per_128b<T>();
|
||||
static_assert(NUM_WARP_Q == 1, "NUM_WARP_Q must be 1");
|
||||
static_assert(NUM_WARP_KV == 4, "NUM_WARP_KV must be 4");
|
||||
@@ -452,11 +436,6 @@ __global__ void multi_query_append_attention_warp1_4_kernel(
|
||||
const uint32_t num_rows_per_block = num_frags_x * 16;
|
||||
const int *block_table_now = block_table + batch_id * max_block_num_per_seq;
|
||||
|
||||
// When capturing prefill with CUDA Graph, more gridDim.x blocks than needed may be launched
|
||||
if(btid >= static_cast<uint32_t>(num_blocks_x_cpu)){
|
||||
return;
|
||||
}
|
||||
|
||||
const uint32_t q_len = seq_lens[batch_id];
|
||||
if (q_len <= 0) {
|
||||
return;
|
||||
@@ -523,7 +502,7 @@ __global__ void multi_query_append_attention_warp1_4_kernel(
|
||||
tid % 8 * num_elems_per_128b<T>();
|
||||
}
|
||||
}
|
||||
const int *mask_offset_this_seq = mask_offset ? mask_offset + q_start_seq_id * 2 : nullptr;
|
||||
|
||||
smem_t qo_smem(smem);
|
||||
|
||||
uint32_t q_smem_offset_r = smem_t::get_permuted_offset<num_vecs_per_head>(
|
||||
@@ -561,9 +540,10 @@ __global__ void multi_query_append_attention_warp1_4_kernel(
|
||||
const uint32_t mask_check_iteration =
|
||||
(CAUSAL ? (min(chunk_len,
|
||||
sub_if_greater_or_zero(
|
||||
kv_len - q_len,
|
||||
kv_len - q_len +
|
||||
tile_id * num_rows_per_block / GROUP_SIZE,
|
||||
chunk_start)))
|
||||
: mask_offset ? 0 : chunk_len) /
|
||||
: chunk_len) /
|
||||
(NUM_WARP_KV * num_frags_z * 16);
|
||||
|
||||
uint32_t k_smem_offset_r = smem_t::get_permuted_offset<num_vecs_per_head>(
|
||||
@@ -631,15 +611,12 @@ __global__ void multi_query_append_attention_warp1_4_kernel(
|
||||
NUM_WARPS,
|
||||
num_frags_x,
|
||||
num_frags_y,
|
||||
num_frags_z>(attn_mask ? attn_mask + batch_id * attn_mask_len *attn_mask_len : nullptr,
|
||||
q_base_seq_id_this_block,
|
||||
num_frags_z>(q_base_seq_id_this_block,
|
||||
kv_idx_base + wid * num_frags_z * 16,
|
||||
q_len,
|
||||
kv_len,
|
||||
chunk_end,
|
||||
attn_mask_len,
|
||||
s_frag,
|
||||
mask_offset_this_seq);
|
||||
s_frag);
|
||||
}
|
||||
|
||||
// update m,d
|
||||
@@ -905,7 +882,6 @@ void MultiQueryAppendAttention(
|
||||
tile_ids_per_batch.data<int>(),
|
||||
cu_seqlens_q.data<int>(),
|
||||
block_table.data<int>(),
|
||||
meta_data.mask_offset,
|
||||
max_seq_len,
|
||||
max_dec_len,
|
||||
max_block_num_per_seq,
|
||||
@@ -914,7 +890,6 @@ void MultiQueryAppendAttention(
|
||||
quant_min_bound,
|
||||
in_scale,
|
||||
chunk_size,
|
||||
num_blocks_x_cpu,
|
||||
nullptr,
|
||||
nullptr,
|
||||
nullptr,
|
||||
@@ -964,7 +939,6 @@ void MultiQueryAppendAttention(
|
||||
tile_ids_per_batch.data<int>(),
|
||||
cu_seqlens_q.data<int>(),
|
||||
block_table.data<int>(),
|
||||
meta_data.mask_offset,
|
||||
max_seq_len,
|
||||
max_dec_len,
|
||||
max_block_num_per_seq,
|
||||
@@ -973,7 +947,6 @@ void MultiQueryAppendAttention(
|
||||
quant_min_bound,
|
||||
in_scale,
|
||||
chunk_size,
|
||||
num_blocks_x_cpu,
|
||||
reinterpret_cast<NV_TYPE *>(tmp_workspace->ptr()),
|
||||
static_cast<float *>(tmp_m->ptr()),
|
||||
static_cast<float *>(tmp_d->ptr()),
|
||||
@@ -1088,18 +1061,12 @@ void MultiQueryAppendAttention(
|
||||
if (!is_decoder) {
|
||||
chunk_size = static_cast<uint32_t>(encoder_max_partition_size);
|
||||
}
|
||||
const int num_chunks = div_up(max_dec_len, chunk_size);
|
||||
|
||||
int32_t attn_mask_len;
|
||||
if (attn_mask) {
|
||||
attn_mask_len = attn_mask.get().shape()[1];
|
||||
} else {
|
||||
attn_mask_len = -1;
|
||||
}
|
||||
|
||||
const int num_chunks = div_up(max_seq_len, chunk_size);
|
||||
dim3 grids(num_blocks_x_cpu, num_chunks, kv_num_heads);
|
||||
dim3 blocks(32, num_warps);
|
||||
if (num_chunks <= 0) {
|
||||
|
||||
if (num_chunks <= 1) {
|
||||
auto nosplit_kv_kernel =
|
||||
multi_query_append_attention_warp1_4_kernel<NV_TYPE,
|
||||
false,
|
||||
@@ -1137,9 +1104,6 @@ void MultiQueryAppendAttention(
|
||||
tile_ids_per_batch.data<int>(),
|
||||
cu_seqlens_q.data<int>(),
|
||||
block_table.data<int>(),
|
||||
meta_data.mask_offset,
|
||||
attn_mask ? const_cast<bool *>(attn_mask.get().data<bool>())
|
||||
: nullptr,
|
||||
max_seq_len,
|
||||
max_dec_len,
|
||||
max_block_num_per_seq,
|
||||
@@ -1148,13 +1112,11 @@ void MultiQueryAppendAttention(
|
||||
quant_min_bound,
|
||||
in_scale,
|
||||
chunk_size,
|
||||
num_blocks_x_cpu,
|
||||
nullptr,
|
||||
nullptr,
|
||||
nullptr,
|
||||
reinterpret_cast<OUT_NV_TYPE *>(out->data<OutT>()),
|
||||
speculate_max_draft_token_num,
|
||||
attn_mask_len);
|
||||
speculate_max_draft_token_num);
|
||||
} else {
|
||||
phi::Allocator::AllocationPtr tmp_workspace, tmp_m, tmp_d;
|
||||
if (is_decoder) {
|
||||
@@ -1199,8 +1161,8 @@ void MultiQueryAppendAttention(
|
||||
reinterpret_cast<NV_TYPE *>(const_cast<T *>(cache_k.data<T>())),
|
||||
reinterpret_cast<NV_TYPE *>(const_cast<T *>(cache_v.data<T>())),
|
||||
shift_bias ? reinterpret_cast<NV_TYPE *>(
|
||||
const_cast<T *>(shift_bias.get().data<T>()))
|
||||
: nullptr,
|
||||
const_cast<T *>(shift_bias.get().data<T>()))
|
||||
: nullptr,
|
||||
smooth_weight ? reinterpret_cast<NV_TYPE *>(
|
||||
const_cast<T *>(smooth_weight.get().data<T>()))
|
||||
: nullptr,
|
||||
@@ -1210,9 +1172,6 @@ void MultiQueryAppendAttention(
|
||||
tile_ids_per_batch.data<int>(),
|
||||
cu_seqlens_q.data<int>(),
|
||||
block_table.data<int>(),
|
||||
meta_data.mask_offset,
|
||||
attn_mask ? const_cast<bool *>(attn_mask.get().data<bool>())
|
||||
: nullptr,
|
||||
max_seq_len,
|
||||
max_dec_len,
|
||||
max_block_num_per_seq,
|
||||
@@ -1221,13 +1180,11 @@ void MultiQueryAppendAttention(
|
||||
quant_min_bound,
|
||||
in_scale,
|
||||
chunk_size,
|
||||
num_blocks_x_cpu,
|
||||
reinterpret_cast<NV_TYPE *>(tmp_workspace->ptr()),
|
||||
static_cast<float *>(tmp_m->ptr()),
|
||||
static_cast<float *>(tmp_d->ptr()),
|
||||
reinterpret_cast<OUT_NV_TYPE *>(out->data<OutT>()),
|
||||
speculate_max_draft_token_num,
|
||||
attn_mask_len);
|
||||
speculate_max_draft_token_num);
|
||||
|
||||
// merge
|
||||
constexpr int vec_size = num_elems_per_128b<NV_TYPE>();
|
||||
@@ -1251,8 +1208,8 @@ void MultiQueryAppendAttention(
|
||||
seq_lens_encoder.data<int>(),
|
||||
cu_seqlens_q.data<int>(),
|
||||
shift_bias ? reinterpret_cast<NV_TYPE *>(
|
||||
const_cast<T *>(shift_bias.get().data<T>()))
|
||||
: nullptr,
|
||||
const_cast<T *>(shift_bias.get().data<T>()))
|
||||
: nullptr,
|
||||
smooth_weight ? reinterpret_cast<NV_TYPE *>(const_cast<T *>(
|
||||
smooth_weight.get().data<T>()))
|
||||
: nullptr,
|
||||
@@ -1269,14 +1226,14 @@ void MultiQueryAppendAttention(
|
||||
constexpr int blockx = HEAD_DIM / vec_size;
|
||||
constexpr int blocky = (128 + blockx - 1) / blockx;
|
||||
dim3 grids_merge(min(sm_count * 4, token_num),
|
||||
num_heads);
|
||||
num_heads);
|
||||
dim3 blocks_merge(blockx, blocky);
|
||||
merge_multi_chunks_v2_kernel<NV_TYPE,
|
||||
vec_size,
|
||||
blocky,
|
||||
HEAD_DIM,
|
||||
OUT_NV_TYPE,
|
||||
ENABLE_PREFILL>
|
||||
vec_size,
|
||||
blocky,
|
||||
HEAD_DIM,
|
||||
OUT_NV_TYPE,
|
||||
ENABLE_PREFILL>
|
||||
<<<grids_merge, blocks_merge, 0, stream>>>(
|
||||
reinterpret_cast<NV_TYPE *>(tmp_workspace->ptr()),
|
||||
static_cast<float *>(tmp_m->ptr()),
|
||||
@@ -1287,8 +1244,8 @@ void MultiQueryAppendAttention(
|
||||
batch_id_per_token.data<int>(),
|
||||
cu_seqlens_q.data<int>(),
|
||||
shift_bias ? reinterpret_cast<NV_TYPE *>(
|
||||
const_cast<T *>(shift_bias.get().data<T>()))
|
||||
: nullptr,
|
||||
const_cast<T *>(shift_bias.get().data<T>()))
|
||||
: nullptr,
|
||||
smooth_weight ? reinterpret_cast<NV_TYPE *>(const_cast<T *>(
|
||||
smooth_weight.get().data<T>()))
|
||||
: nullptr,
|
||||
|
@@ -48,7 +48,6 @@ __global__ void multi_query_append_attention_c4_kernel(
|
||||
const int *__restrict__ tile_ids_per_batch,
|
||||
const int *__restrict__ cu_seqlens_q,
|
||||
const int *__restrict__ block_table, // [bsz, block_num_per_seq]
|
||||
const int *__restrict__ mask_offset,
|
||||
const int max_seq_len,
|
||||
const int max_dec_len,
|
||||
const int max_block_num_per_seq,
|
||||
@@ -57,7 +56,6 @@ __global__ void multi_query_append_attention_c4_kernel(
|
||||
const float quant_min_bound,
|
||||
const float in_scale,
|
||||
const uint32_t chunk_size,
|
||||
const int num_blocks_x_cpu,
|
||||
T *__restrict__ tmp_workspace, // split kv [token_num, num_chunks,
|
||||
// num_heads, head_dim]
|
||||
float *__restrict__ tmp_m, // [token_num, num_chunks, num_heads]
|
||||
@@ -86,11 +84,6 @@ __global__ void multi_query_append_attention_c4_kernel(
|
||||
|
||||
block_table_now = block_table + batch_id * max_block_num_per_seq;
|
||||
|
||||
// When capturing prefill with CUDA Graph, more gridDim.x blocks than needed may be launched
|
||||
if(btid >= static_cast<uint32_t>(num_blocks_x_cpu)){
|
||||
return;
|
||||
}
|
||||
|
||||
const uint32_t q_len = seq_lens[batch_id];
|
||||
if (q_len <= 0) {
|
||||
return;
|
||||
@@ -179,7 +172,6 @@ __global__ void multi_query_append_attention_c4_kernel(
|
||||
} else {
|
||||
o_base_ptr_int8 = out + o_offset;
|
||||
}
|
||||
const int *mask_offset_this_seq = mask_offset ? mask_offset + q_start_seq_id * 2 : nullptr;
|
||||
smem_t qo_smem(smem);
|
||||
|
||||
uint32_t q_smem_offset_r = smem_t::get_permuted_offset<num_vecs_per_head>(
|
||||
@@ -256,7 +248,7 @@ __global__ void multi_query_append_attention_c4_kernel(
|
||||
kv_len - q_len +
|
||||
tile_id * num_rows_per_block / GROUP_SIZE,
|
||||
chunk_start)))
|
||||
: mask_offset ? 0 : chunk_len) /
|
||||
: chunk_len) /
|
||||
(num_frags_z * 16);
|
||||
|
||||
uint32_t k_smem_offset_r =
|
||||
@@ -341,15 +333,12 @@ __global__ void multi_query_append_attention_c4_kernel(
|
||||
NUM_WARPS,
|
||||
num_frags_x,
|
||||
num_frags_y,
|
||||
num_frags_z>(nullptr,
|
||||
q_base_seq_id_this_block,
|
||||
num_frags_z>(q_base_seq_id_this_block,
|
||||
kv_idx_base,
|
||||
q_len,
|
||||
kv_len,
|
||||
chunk_end,
|
||||
-1,
|
||||
s_frag,
|
||||
mask_offset_this_seq);
|
||||
s_frag);
|
||||
}
|
||||
|
||||
update_mdo_states<num_frags_x, num_frags_y, num_frags_z>(
|
||||
@@ -516,8 +505,6 @@ __global__ void multi_query_append_attention_c4_warp1_4_kernel(
|
||||
const int *__restrict__ tile_ids_per_batch,
|
||||
const int *__restrict__ cu_seqlens_q,
|
||||
const int *__restrict__ block_table, // [bsz, block_num_per_seq]
|
||||
const int *__restrict__ mask_offset,
|
||||
const bool *__restrict__ attn_mask, // [bsz, max_q, max_q] for tree-mask
|
||||
const int max_seq_len,
|
||||
const int max_dec_len,
|
||||
const int max_block_num_per_seq,
|
||||
@@ -526,14 +513,12 @@ __global__ void multi_query_append_attention_c4_warp1_4_kernel(
|
||||
const float quant_min_bound,
|
||||
const float in_scale,
|
||||
const uint32_t chunk_size,
|
||||
const int num_blocks_x_cpu,
|
||||
T *__restrict__ tmp_workspace, // split kv [token_num, num_chunks,
|
||||
// num_heads, head_dim]
|
||||
float *__restrict__ tmp_m, // [token_num, num_chunks, num_heads]
|
||||
float *__restrict__ tmp_d, // [token_num, num_chunks, num_heads]
|
||||
OutT *__restrict__ out,
|
||||
const int speculate_max_draft_token_num = 5,
|
||||
const int32_t attn_mask_len = -1) {
|
||||
const int speculate_max_draft_token_num = 5) {
|
||||
constexpr uint32_t num_vecs_per_head = HEAD_DIM / num_elems_per_128b<T>();
|
||||
constexpr uint32_t num_vecs_per_head_k =
|
||||
HEAD_DIM / 2 / num_elems_per_128b<CacheT>();
|
||||
@@ -556,11 +541,6 @@ __global__ void multi_query_append_attention_c4_warp1_4_kernel(
|
||||
const uint32_t num_rows_per_block = num_frags_x * 16;
|
||||
const int *block_table_now = block_table + batch_id * max_block_num_per_seq;
|
||||
|
||||
// When capturing prefill with CUDA Graph, more gridDim.x blocks than needed may be launched
|
||||
if(btid >= static_cast<uint32_t>(num_blocks_x_cpu)){
|
||||
return;
|
||||
}
|
||||
|
||||
const uint32_t q_len = seq_lens[batch_id];
|
||||
if (q_len <= 0) {
|
||||
return;
|
||||
@@ -647,7 +627,7 @@ __global__ void multi_query_append_attention_c4_warp1_4_kernel(
|
||||
tid % 8 * num_elems_per_128b<T>();
|
||||
}
|
||||
}
|
||||
const int *mask_offset_this_seq = mask_offset ? mask_offset + q_start_seq_id * 2 : nullptr;
|
||||
|
||||
smem_t qo_smem(smem);
|
||||
|
||||
uint32_t q_smem_offset_r = smem_t::get_permuted_offset<num_vecs_per_head>(
|
||||
@@ -723,9 +703,10 @@ __global__ void multi_query_append_attention_c4_warp1_4_kernel(
|
||||
const uint32_t mask_check_iteration =
|
||||
(CAUSAL ? (min(chunk_len,
|
||||
sub_if_greater_or_zero(
|
||||
kv_len - q_len,
|
||||
kv_len - q_len +
|
||||
tile_id * num_rows_per_block / GROUP_SIZE,
|
||||
chunk_start)))
|
||||
: mask_offset ? 0 : chunk_len) /
|
||||
: chunk_len) /
|
||||
(NUM_WARP_KV * num_frags_z * 16);
|
||||
|
||||
uint32_t k_smem_offset_r =
|
||||
@@ -807,15 +788,12 @@ __global__ void multi_query_append_attention_c4_warp1_4_kernel(
|
||||
NUM_WARPS,
|
||||
num_frags_x,
|
||||
num_frags_y,
|
||||
num_frags_z>(attn_mask ? attn_mask + batch_id * attn_mask_len *attn_mask_len : nullptr,
|
||||
q_base_seq_id_this_block,
|
||||
num_frags_z>(q_base_seq_id_this_block,
|
||||
kv_idx_base + wid * num_frags_z * 16,
|
||||
q_len,
|
||||
kv_len,
|
||||
chunk_end,
|
||||
attn_mask_len,
|
||||
s_frag,
|
||||
mask_offset_this_seq);
|
||||
s_frag);
|
||||
}
|
||||
|
||||
update_mdo_states<num_frags_x, num_frags_y, num_frags_z>(
|
||||
@@ -1110,7 +1088,6 @@ void MultiQueryAppendC4Attention(
|
||||
tile_ids_per_batch.data<int>(),
|
||||
cu_seqlens_q.data<int>(),
|
||||
block_table.data<int>(),
|
||||
meta_data.mask_offset,
|
||||
max_seq_len,
|
||||
max_dec_len,
|
||||
max_block_num_per_seq,
|
||||
@@ -1119,7 +1096,6 @@ void MultiQueryAppendC4Attention(
|
||||
quant_min_bound,
|
||||
in_scale,
|
||||
chunk_size,
|
||||
num_blocks_x_cpu,
|
||||
nullptr,
|
||||
nullptr,
|
||||
nullptr,
|
||||
@@ -1175,7 +1151,6 @@ void MultiQueryAppendC4Attention(
|
||||
tile_ids_per_batch.data<int>(),
|
||||
cu_seqlens_q.data<int>(),
|
||||
block_table.data<int>(),
|
||||
meta_data.mask_offset,
|
||||
max_seq_len,
|
||||
max_dec_len,
|
||||
max_block_num_per_seq,
|
||||
@@ -1184,7 +1159,6 @@ void MultiQueryAppendC4Attention(
|
||||
quant_min_bound,
|
||||
in_scale,
|
||||
chunk_size,
|
||||
num_blocks_x_cpu,
|
||||
reinterpret_cast<NV_TYPE *>(tmp_workspace->ptr()),
|
||||
static_cast<float *>(tmp_m->ptr()),
|
||||
static_cast<float *>(tmp_d->ptr()),
|
||||
@@ -1311,18 +1285,10 @@ void MultiQueryAppendC4Attention(
|
||||
if (!is_decoder) {
|
||||
chunk_size = static_cast<uint32_t>(encoder_max_partition_size);
|
||||
}
|
||||
|
||||
const int num_chunks = div_up(max_seq_len, chunk_size);
|
||||
int32_t attn_mask_len;
|
||||
if (attn_mask) {
|
||||
attn_mask_len = attn_mask.get().shape()[1];
|
||||
} else {
|
||||
attn_mask_len = -1;
|
||||
}
|
||||
|
||||
const int num_chunks = div_up(max_dec_len, chunk_size);
|
||||
dim3 grids(num_blocks_x_cpu, num_chunks, kv_num_heads);
|
||||
dim3 blocks(32, num_warps);
|
||||
if (num_chunks <= 0) {
|
||||
if (num_chunks <= 1) {
|
||||
auto nosplit_kv_kernel =
|
||||
multi_query_append_attention_c4_warp1_4_kernel<NV_TYPE,
|
||||
uint8_t,
|
||||
@@ -1368,9 +1334,6 @@ void MultiQueryAppendC4Attention(
|
||||
tile_ids_per_batch.data<int>(),
|
||||
cu_seqlens_q.data<int>(),
|
||||
block_table.data<int>(),
|
||||
meta_data.mask_offset,
|
||||
attn_mask ? const_cast<bool *>(attn_mask.get().data<bool>())
|
||||
: nullptr,
|
||||
max_seq_len,
|
||||
max_dec_len,
|
||||
max_block_num_per_seq,
|
||||
@@ -1379,13 +1342,11 @@ void MultiQueryAppendC4Attention(
|
||||
quant_min_bound,
|
||||
in_scale,
|
||||
chunk_size,
|
||||
num_blocks_x_cpu,
|
||||
nullptr,
|
||||
nullptr,
|
||||
nullptr,
|
||||
reinterpret_cast<OUT_NV_TYPE *>(out->data<OutT>()),
|
||||
speculate_max_draft_token_num,
|
||||
attn_mask_len);
|
||||
speculate_max_draft_token_num);
|
||||
} else {
|
||||
phi::Allocator::AllocationPtr tmp_workspace, tmp_m, tmp_d;
|
||||
if (is_decoder) {
|
||||
@@ -1431,15 +1392,15 @@ void MultiQueryAppendC4Attention(
|
||||
const_cast<uint8_t *>(cache_v.data<uint8_t>()),
|
||||
reinterpret_cast<NV_TYPE *>(const_cast<T *>(cache_k_scale.data<T>())),
|
||||
cache_k_zp ? reinterpret_cast<NV_TYPE *>(
|
||||
const_cast<T *>(cache_k_zp.get().data<T>()))
|
||||
: nullptr,
|
||||
const_cast<T *>(cache_k_zp.get().data<T>()))
|
||||
: nullptr,
|
||||
reinterpret_cast<NV_TYPE *>(const_cast<T *>(cache_v_scale.data<T>())),
|
||||
cache_v_zp ? reinterpret_cast<NV_TYPE *>(
|
||||
const_cast<T *>(cache_v_zp.get().data<T>()))
|
||||
: nullptr,
|
||||
const_cast<T *>(cache_v_zp.get().data<T>()))
|
||||
: nullptr,
|
||||
shift_bias ? reinterpret_cast<NV_TYPE *>(
|
||||
const_cast<T *>(shift_bias.get().data<T>()))
|
||||
: nullptr,
|
||||
const_cast<T *>(shift_bias.get().data<T>()))
|
||||
: nullptr,
|
||||
smooth_weight ? reinterpret_cast<NV_TYPE *>(
|
||||
const_cast<T *>(smooth_weight.get().data<T>()))
|
||||
: nullptr,
|
||||
@@ -1449,9 +1410,6 @@ void MultiQueryAppendC4Attention(
|
||||
tile_ids_per_batch.data<int>(),
|
||||
cu_seqlens_q.data<int>(),
|
||||
block_table.data<int>(),
|
||||
meta_data.mask_offset,
|
||||
attn_mask ? const_cast<bool *>(attn_mask.get().data<bool>())
|
||||
: nullptr,
|
||||
max_seq_len,
|
||||
max_dec_len,
|
||||
max_block_num_per_seq,
|
||||
@@ -1460,13 +1418,11 @@ void MultiQueryAppendC4Attention(
|
||||
quant_min_bound,
|
||||
in_scale,
|
||||
chunk_size,
|
||||
num_blocks_x_cpu,
|
||||
reinterpret_cast<NV_TYPE *>(tmp_workspace->ptr()),
|
||||
static_cast<float *>(tmp_m->ptr()),
|
||||
static_cast<float *>(tmp_d->ptr()),
|
||||
reinterpret_cast<OUT_NV_TYPE *>(out->data<OutT>()),
|
||||
speculate_max_draft_token_num,
|
||||
attn_mask_len);
|
||||
speculate_max_draft_token_num);
|
||||
// merge
|
||||
constexpr int vec_size = num_elems_per_128b<NV_TYPE>();
|
||||
if (is_decoder) {
|
||||
@@ -1489,8 +1445,8 @@ void MultiQueryAppendC4Attention(
|
||||
seq_lens_encoder.data<int>(),
|
||||
cu_seqlens_q.data<int>(),
|
||||
shift_bias ? reinterpret_cast<NV_TYPE *>(
|
||||
const_cast<T *>(shift_bias.get().data<T>()))
|
||||
: nullptr,
|
||||
const_cast<T *>(shift_bias.get().data<T>()))
|
||||
: nullptr,
|
||||
smooth_weight ? reinterpret_cast<NV_TYPE *>(const_cast<T *>(
|
||||
smooth_weight.get().data<T>()))
|
||||
: nullptr,
|
||||
@@ -1507,14 +1463,14 @@ void MultiQueryAppendC4Attention(
|
||||
constexpr int blockx = HEAD_DIM / vec_size;
|
||||
constexpr int blocky = (128 + blockx - 1) / blockx;
|
||||
dim3 grids_merge(min(sm_count * 4, token_num),
|
||||
num_heads);
|
||||
num_heads);
|
||||
dim3 blocks_merge(blockx, blocky);
|
||||
merge_multi_chunks_v2_kernel<NV_TYPE,
|
||||
vec_size,
|
||||
blocky,
|
||||
HEAD_DIM,
|
||||
OUT_NV_TYPE,
|
||||
ENABLE_PREFILL>
|
||||
vec_size,
|
||||
blocky,
|
||||
HEAD_DIM,
|
||||
OUT_NV_TYPE,
|
||||
ENABLE_PREFILL>
|
||||
<<<grids_merge, blocks_merge, 0, stream>>>(
|
||||
reinterpret_cast<NV_TYPE *>(tmp_workspace->ptr()),
|
||||
static_cast<float *>(tmp_m->ptr()),
|
||||
@@ -1525,8 +1481,8 @@ void MultiQueryAppendC4Attention(
|
||||
batch_id_per_token.data<int>(),
|
||||
cu_seqlens_q.data<int>(),
|
||||
shift_bias ? reinterpret_cast<NV_TYPE *>(
|
||||
const_cast<T *>(shift_bias.get().data<T>()))
|
||||
: nullptr,
|
||||
const_cast<T *>(shift_bias.get().data<T>()))
|
||||
: nullptr,
|
||||
smooth_weight ? reinterpret_cast<NV_TYPE *>(const_cast<T *>(
|
||||
smooth_weight.get().data<T>()))
|
||||
: nullptr,
|
||||
|
@@ -32,15 +32,14 @@ template <typename T,
|
||||
typename OutT = T,
|
||||
bool ENABLE_PREFILL = true,
|
||||
bool is_scale_channel_wise = false,
|
||||
bool IsFP8 = false,
|
||||
bool IsDynamicC8 = false>
|
||||
bool IsFP8=false>
|
||||
__global__ void multi_query_append_attention_c8_kernel(
|
||||
T *__restrict__ q, // [token_num, (num_heads + 2* kv_num_head) * head_dim]
|
||||
CacheT *__restrict__ cache_k, // [max_block_num, num_heads, block_size,
|
||||
// head_dim]
|
||||
CacheT *__restrict__ cache_v,
|
||||
const T *__restrict__ cache_k_scale, // [num_kv_heads] or [max_block_num, num_heads, block_size]
|
||||
const T *__restrict__ cache_v_scale, // [num_kv_heads] or [max_block_num, num_heads, block_size]
|
||||
const T *__restrict__ cache_k_scale, // [num_kv_heads]
|
||||
const T *__restrict__ cache_v_scale, // [num_kv_heads]
|
||||
const T *__restrict__ shift_bias, // [q_num_heads * HEAD_DIM]
|
||||
const T *__restrict__ smooth_weight, // [q_num_heads * HEAD_DIM]
|
||||
const int *__restrict__ seq_lens,
|
||||
@@ -49,7 +48,6 @@ __global__ void multi_query_append_attention_c8_kernel(
|
||||
const int *__restrict__ tile_ids_per_batch,
|
||||
const int *__restrict__ cu_seqlens_q,
|
||||
const int *__restrict__ block_table, // [bsz, block_num_per_seq]
|
||||
const int *__restrict__ mask_offset,
|
||||
const int max_seq_len,
|
||||
const int max_dec_len,
|
||||
const int max_block_num_per_seq,
|
||||
@@ -58,7 +56,6 @@ __global__ void multi_query_append_attention_c8_kernel(
|
||||
const float quant_min_bound,
|
||||
const float in_scale,
|
||||
const uint32_t chunk_size,
|
||||
const int num_blocks_x_cpu,
|
||||
T *__restrict__ tmp_workspace, // split kv [token_num, num_chunks,
|
||||
// num_heads, head_dim]
|
||||
float *__restrict__ tmp_m, // [token_num, num_chunks, num_heads]
|
||||
@@ -88,40 +85,33 @@ __global__ void multi_query_append_attention_c8_kernel(
|
||||
|
||||
block_table_now = block_table + batch_id * max_block_num_per_seq;
|
||||
|
||||
// When capturing prefill with CUDA Graph, more gridDim.x blocks than needed may be launched
|
||||
if(btid >= static_cast<uint32_t>(num_blocks_x_cpu)){
|
||||
return;
|
||||
}
|
||||
|
||||
const uint32_t q_len = seq_lens[batch_id];
|
||||
if (q_len <= 0) {
|
||||
return;
|
||||
}
|
||||
|
||||
T cache_k_scale_reg[IsDynamicC8 ? num_frags_z * 2 : num_frags_y * 4];
|
||||
T cache_v_scale_reg[IsDynamicC8 ? num_frags_z * 4 : num_frags_y * 2];
|
||||
if constexpr (!IsDynamicC8) {
|
||||
if constexpr (is_scale_channel_wise) {
|
||||
int scale_col_base = threadIdx.x % 4 * 2 + kv_head_idx * HEAD_DIM;
|
||||
const T *cache_k_scale_cur_head = cache_k_scale + scale_col_base;
|
||||
for (int i = 0; i < num_frags_y; ++i) {
|
||||
const int scale_idx = i * 16;
|
||||
cache_k_scale_reg[i * 4] = cache_k_scale_cur_head[scale_idx];
|
||||
cache_k_scale_reg[i * 4 + 1] = cache_k_scale_cur_head[scale_idx + 1];
|
||||
cache_k_scale_reg[i * 4 + 2] = cache_k_scale_cur_head[scale_idx + 8];
|
||||
cache_k_scale_reg[i * 4 + 3] = cache_k_scale_cur_head[scale_idx + 9];
|
||||
}
|
||||
scale_col_base = threadIdx.x / 4 + kv_head_idx * HEAD_DIM;
|
||||
const T *cache_v_scale_cur_head = cache_v_scale + scale_col_base;
|
||||
for (int i = 0; i < num_frags_y; ++i) {
|
||||
const int scale_idx = i * 16;
|
||||
cache_v_scale_reg[i * 2] = cache_v_scale_cur_head[scale_idx];
|
||||
cache_v_scale_reg[i * 2 + 1] = cache_v_scale_cur_head[scale_idx + 8];
|
||||
}
|
||||
} else {
|
||||
cache_k_scale_reg[0] = cache_k_scale[kv_head_idx];
|
||||
cache_v_scale_reg[0] = cache_v_scale[kv_head_idx];
|
||||
T cache_k_scale_reg[num_frags_y * 4];
|
||||
T cache_v_scale_reg[num_frags_y * 2];
|
||||
if (is_scale_channel_wise) {
|
||||
int scale_col_base = threadIdx.x % 4 * 2 + kv_head_idx * HEAD_DIM;
|
||||
const T *cache_k_scale_cur_head = cache_k_scale + scale_col_base;
|
||||
for (int i = 0; i < num_frags_y; ++i) {
|
||||
const int scale_idx = i * 16;
|
||||
cache_k_scale_reg[i * 4] = cache_k_scale_cur_head[scale_idx];
|
||||
cache_k_scale_reg[i * 4 + 1] = cache_k_scale_cur_head[scale_idx + 1];
|
||||
cache_k_scale_reg[i * 4 + 2] = cache_k_scale_cur_head[scale_idx + 8];
|
||||
cache_k_scale_reg[i * 4 + 3] = cache_k_scale_cur_head[scale_idx + 9];
|
||||
}
|
||||
scale_col_base = threadIdx.x / 4 + kv_head_idx * HEAD_DIM;
|
||||
const T *cache_v_scale_cur_head = cache_v_scale + scale_col_base;
|
||||
for (int i = 0; i < num_frags_y; ++i) {
|
||||
const int scale_idx = i * 16;
|
||||
cache_v_scale_reg[i * 2] = cache_v_scale_cur_head[scale_idx];
|
||||
cache_v_scale_reg[i * 2 + 1] = cache_v_scale_cur_head[scale_idx + 8];
|
||||
}
|
||||
} else {
|
||||
cache_k_scale_reg[0] = cache_k_scale[kv_head_idx];
|
||||
cache_v_scale_reg[0] = cache_v_scale[kv_head_idx];
|
||||
}
|
||||
|
||||
const uint32_t q_end =
|
||||
@@ -189,7 +179,6 @@ __global__ void multi_query_append_attention_c8_kernel(
|
||||
} else {
|
||||
o_base_ptr_int8 = out + o_offset;
|
||||
}
|
||||
const int *mask_offset_this_seq = mask_offset ? mask_offset + q_start_seq_id * 2 : nullptr;
|
||||
smem_t qo_smem(smem);
|
||||
|
||||
uint32_t q_smem_offset_r = smem_t::get_permuted_offset<num_vecs_per_head>(
|
||||
@@ -210,13 +199,6 @@ __global__ void multi_query_append_attention_c8_kernel(
|
||||
smem_t k_smem(smem + NUM_WARPS * num_frags_x * 16 * HEAD_DIM * sizeof(T)),
|
||||
v_smem(smem + NUM_WARPS * num_frags_x * 16 * HEAD_DIM * sizeof(T) +
|
||||
num_frags_z * 16 * HEAD_DIM * sizeof(CacheT));
|
||||
T* k_smem_scale = nullptr;
|
||||
T* v_smem_scale = nullptr;
|
||||
if constexpr (IsDynamicC8) {
|
||||
k_smem_scale = reinterpret_cast<T*>(smem + NUM_WARPS * num_frags_x * 16 * HEAD_DIM * sizeof(T) +
|
||||
num_frags_z * 16 * HEAD_DIM * sizeof(CacheT) * 2);
|
||||
v_smem_scale = k_smem_scale + num_frags_z * 16;
|
||||
}
|
||||
|
||||
|
||||
const uint32_t num_iterations = div_up(
|
||||
@@ -234,7 +216,7 @@ __global__ void multi_query_append_attention_c8_kernel(
|
||||
kv_len - q_len +
|
||||
tile_id * num_rows_per_block / GROUP_SIZE,
|
||||
chunk_start)))
|
||||
: mask_offset ? 0 : chunk_len) /
|
||||
: chunk_len) /
|
||||
(num_frags_z * 16);
|
||||
|
||||
uint32_t k_smem_offset_r =
|
||||
@@ -298,22 +280,10 @@ __global__ void multi_query_append_attention_c8_kernel(
|
||||
|
||||
#pragma unroll 1
|
||||
for (uint32_t iter = 0; iter < num_iterations; ++iter) {
|
||||
if constexpr (IsDynamicC8) {
|
||||
produce_k_dynamic_scale<BLOCK_SIZE, num_frags_z, NUM_WARP_Q, T>(
|
||||
k_smem_scale,
|
||||
cache_k_scale_reg,
|
||||
block_table_now,
|
||||
cache_k_scale,
|
||||
kv_idx_base,
|
||||
kv_num_heads,
|
||||
kv_head_idx,
|
||||
chunk_end
|
||||
);
|
||||
}
|
||||
wait_group<1>();
|
||||
__syncthreads();
|
||||
// s = qk
|
||||
compute_qk_c8<num_frags_x, num_frags_y, num_frags_z, T, CacheT, is_scale_channel_wise, IsFP8, IsDynamicC8>(
|
||||
compute_qk_c8<num_frags_x, num_frags_y, num_frags_z, T, CacheT, is_scale_channel_wise, IsFP8>(
|
||||
&qo_smem,
|
||||
&q_smem_offset_r,
|
||||
&k_smem,
|
||||
@@ -330,15 +300,12 @@ __global__ void multi_query_append_attention_c8_kernel(
|
||||
NUM_WARPS,
|
||||
num_frags_x,
|
||||
num_frags_y,
|
||||
num_frags_z>(nullptr,
|
||||
q_base_seq_id_this_block,
|
||||
num_frags_z>(q_base_seq_id_this_block,
|
||||
kv_idx_base,
|
||||
q_len,
|
||||
kv_len,
|
||||
chunk_end,
|
||||
-1,
|
||||
s_frag,
|
||||
mask_offset_this_seq);
|
||||
s_frag);
|
||||
}
|
||||
|
||||
// update m,d
|
||||
@@ -346,7 +313,6 @@ __global__ void multi_query_append_attention_c8_kernel(
|
||||
s_frag, o_frag, m_frag, d_frag);
|
||||
__syncthreads();
|
||||
|
||||
const int ori_kv_idx_base = kv_idx_base;
|
||||
kv_idx_base += num_frags_z * 16;
|
||||
produce_k_blockwise_c8<SharedMemFillMode::kNoFill,
|
||||
NUM_WARPS,
|
||||
@@ -365,18 +331,6 @@ __global__ void multi_query_append_attention_c8_kernel(
|
||||
chunk_end,
|
||||
const_k_offset);
|
||||
commit_group();
|
||||
if constexpr (IsDynamicC8) {
|
||||
produce_v_dynamic_scale<BLOCK_SIZE, num_frags_z, NUM_WARP_Q, T>(
|
||||
v_smem_scale,
|
||||
cache_v_scale_reg,
|
||||
block_table_now,
|
||||
cache_v_scale,
|
||||
ori_kv_idx_base,
|
||||
kv_num_heads,
|
||||
kv_head_idx,
|
||||
chunk_end
|
||||
);
|
||||
}
|
||||
wait_group<1>();
|
||||
__syncthreads();
|
||||
|
||||
@@ -387,9 +341,7 @@ __global__ void multi_query_append_attention_c8_kernel(
|
||||
BLOCK_SIZE,
|
||||
T,
|
||||
CacheT,
|
||||
is_scale_channel_wise,
|
||||
IsFP8,
|
||||
IsDynamicC8>(
|
||||
is_scale_channel_wise, IsFP8>(
|
||||
&v_smem, &v_smem_offset_r, s_frag, o_frag, d_frag, cache_v_scale_reg);
|
||||
__syncthreads();
|
||||
|
||||
@@ -506,15 +458,14 @@ template <typename T,
|
||||
typename OutT = T,
|
||||
bool ENABLE_PREFILL = true,
|
||||
bool is_scale_channel_wise=false,
|
||||
bool IsFP8 = false,
|
||||
bool IsDynamicC8 = false>
|
||||
bool IsFP8=false>
|
||||
__global__ void multi_query_append_attention_c8_warp1_4_kernel(
|
||||
T *__restrict__ q, // [token_num, (num_heads + 2* kv_num_head) * head_dim]
|
||||
CacheT *__restrict__ cache_k, // [max_block_num, num_heads, block_size,
|
||||
// head_dim]
|
||||
CacheT *__restrict__ cache_v,
|
||||
const T *__restrict__ cache_k_scale, // [num_kv_heads] or [max_block_num, num_heads, block_size]
|
||||
const T *__restrict__ cache_v_scale, // [num_kv_heads] or [max_block_num, num_heads, block_size]
|
||||
const T *__restrict__ cache_k_scale, // [num_kv_heads, head_dim]
|
||||
const T *__restrict__ cache_v_scale, // [num_kv_heads, head_dim]
|
||||
const T *__restrict__ shift_bias, // [q_num_heads * HEAD_DIM]
|
||||
const T *__restrict__ smooth_weight, // [q_num_heads * HEAD_DIM]
|
||||
const int *__restrict__ seq_lens,
|
||||
@@ -523,8 +474,6 @@ __global__ void multi_query_append_attention_c8_warp1_4_kernel(
|
||||
const int *__restrict__ tile_ids_per_batch,
|
||||
const int *__restrict__ cu_seqlens_q,
|
||||
const int *__restrict__ block_table, // [bsz, block_num_per_seq]
|
||||
const int *__restrict__ mask_offset,
|
||||
const bool *__restrict__ attn_mask, // [bsz, max_q, max_q] for tree-mask
|
||||
const int max_seq_len,
|
||||
const int max_dec_len,
|
||||
const int max_block_num_per_seq,
|
||||
@@ -533,14 +482,12 @@ __global__ void multi_query_append_attention_c8_warp1_4_kernel(
|
||||
const float quant_min_bound,
|
||||
const float in_scale,
|
||||
const uint32_t chunk_size,
|
||||
const int num_blocks_x_cpu,
|
||||
T *__restrict__ tmp_workspace, // split kv [token_num, num_chunks,
|
||||
// num_heads, head_dim]
|
||||
float *__restrict__ tmp_m, // [token_num, num_chunks, num_heads]
|
||||
float *__restrict__ tmp_d, // [token_num, num_chunks, num_heads]
|
||||
OutT *__restrict__ out,
|
||||
const int speculate_max_draft_token_num = 5,
|
||||
const int32_t attn_mask_len = -1) {
|
||||
const int speculate_max_draft_token_num = 5) {
|
||||
constexpr uint32_t num_vecs_per_head = HEAD_DIM / num_elems_per_128b<T>();
|
||||
constexpr uint32_t num_vecs_per_head_k =
|
||||
HEAD_DIM / num_elems_per_128b<CacheT>();
|
||||
@@ -563,39 +510,32 @@ __global__ void multi_query_append_attention_c8_warp1_4_kernel(
|
||||
const uint32_t num_rows_per_block = num_frags_x * 16;
|
||||
const int *block_table_now = block_table + batch_id * max_block_num_per_seq;
|
||||
|
||||
// When capturing prefill with CUDA Graph, more gridDim.x blocks than needed may be launched
|
||||
if(btid >= static_cast<uint32_t>(num_blocks_x_cpu)){
|
||||
return;
|
||||
}
|
||||
|
||||
const uint32_t q_len = seq_lens[batch_id];
|
||||
if (q_len <= 0) {
|
||||
return;
|
||||
}
|
||||
T cache_k_scale_reg[IsDynamicC8 ? num_frags_z * 2 : num_frags_y * 4];
|
||||
T cache_v_scale_reg[IsDynamicC8 ? num_frags_z * 4 : num_frags_y * 2];
|
||||
if constexpr (!IsDynamicC8) {
|
||||
if constexpr (is_scale_channel_wise) {
|
||||
int scale_col_base = threadIdx.x % 4 * 2 + kv_head_idx * HEAD_DIM;
|
||||
const T *cache_k_scale_cur_head = cache_k_scale + scale_col_base;
|
||||
for (int i = 0; i < num_frags_y; ++i) {
|
||||
const int scale_idx = i * 16;
|
||||
cache_k_scale_reg[i * 4] = cache_k_scale_cur_head[scale_idx];
|
||||
cache_k_scale_reg[i * 4 + 1] = cache_k_scale_cur_head[scale_idx + 1];
|
||||
cache_k_scale_reg[i * 4 + 2] = cache_k_scale_cur_head[scale_idx + 8];
|
||||
cache_k_scale_reg[i * 4 + 3] = cache_k_scale_cur_head[scale_idx + 9];
|
||||
}
|
||||
scale_col_base = threadIdx.x / 4 + kv_head_idx * HEAD_DIM;
|
||||
const T *cache_v_scale_cur_head = cache_v_scale + scale_col_base;
|
||||
for (int i = 0; i < num_frags_y; ++i) {
|
||||
const int scale_idx = i * 16;
|
||||
cache_v_scale_reg[i * 2] = cache_v_scale_cur_head[scale_idx];
|
||||
cache_v_scale_reg[i * 2 + 1] = cache_v_scale_cur_head[scale_idx + 8];
|
||||
}
|
||||
} else {
|
||||
cache_k_scale_reg[0] = cache_k_scale[kv_head_idx];
|
||||
cache_v_scale_reg[0] = cache_v_scale[kv_head_idx];
|
||||
T cache_k_scale_reg[num_frags_y * 4];
|
||||
T cache_v_scale_reg[num_frags_y * 2];
|
||||
if (is_scale_channel_wise) {
|
||||
int scale_col_base = threadIdx.x % 4 * 2 + kv_head_idx * HEAD_DIM;
|
||||
const T *cache_k_scale_cur_head = cache_k_scale + scale_col_base;
|
||||
for (int i = 0; i < num_frags_y; ++i) {
|
||||
const int scale_idx = i * 16;
|
||||
cache_k_scale_reg[i * 4] = cache_k_scale_cur_head[scale_idx];
|
||||
cache_k_scale_reg[i * 4 + 1] = cache_k_scale_cur_head[scale_idx + 1];
|
||||
cache_k_scale_reg[i * 4 + 2] = cache_k_scale_cur_head[scale_idx + 8];
|
||||
cache_k_scale_reg[i * 4 + 3] = cache_k_scale_cur_head[scale_idx + 9];
|
||||
}
|
||||
scale_col_base = threadIdx.x / 4 + kv_head_idx * HEAD_DIM;
|
||||
const T *cache_v_scale_cur_head = cache_v_scale + scale_col_base;
|
||||
for (int i = 0; i < num_frags_y; ++i) {
|
||||
const int scale_idx = i * 16;
|
||||
cache_v_scale_reg[i * 2] = cache_v_scale_cur_head[scale_idx];
|
||||
cache_v_scale_reg[i * 2 + 1] = cache_v_scale_cur_head[scale_idx + 8];
|
||||
}
|
||||
} else {
|
||||
cache_k_scale_reg[0] = cache_k_scale[kv_head_idx];
|
||||
cache_v_scale_reg[0] = cache_v_scale[kv_head_idx];
|
||||
}
|
||||
const uint32_t q_end =
|
||||
min(q_len, div_up((tile_id + 1) * num_rows_per_block, GROUP_SIZE));
|
||||
@@ -661,7 +601,7 @@ __global__ void multi_query_append_attention_c8_warp1_4_kernel(
|
||||
tid % 8 * num_elems_per_128b<T>();
|
||||
}
|
||||
}
|
||||
const int *mask_offset_this_seq = mask_offset ? mask_offset + q_start_seq_id * 2 : nullptr;
|
||||
|
||||
smem_t qo_smem(smem);
|
||||
|
||||
uint32_t q_smem_offset_r = smem_t::get_permuted_offset<num_vecs_per_head>(
|
||||
@@ -686,13 +626,6 @@ __global__ void multi_query_append_attention_c8_warp1_4_kernel(
|
||||
smem_t k_smem(smem + num_frags_x * 16 * HEAD_DIM * sizeof(T)),
|
||||
v_smem(smem + num_frags_x * 16 * HEAD_DIM * sizeof(T) +
|
||||
NUM_WARP_KV * num_frags_z * 16 * HEAD_DIM * sizeof(CacheT));
|
||||
T* k_smem_scale = nullptr;
|
||||
T* v_smem_scale = nullptr;
|
||||
if constexpr (IsDynamicC8) {
|
||||
k_smem_scale = reinterpret_cast<T*>(smem + num_frags_x * 16 * HEAD_DIM * sizeof(T) +
|
||||
NUM_WARP_KV * num_frags_z * 16 * HEAD_DIM * sizeof(CacheT) * 2);
|
||||
v_smem_scale = k_smem_scale + NUM_WARP_KV * num_frags_z * 16;
|
||||
}
|
||||
|
||||
const uint32_t num_iterations = div_up(
|
||||
CAUSAL
|
||||
@@ -709,7 +642,7 @@ __global__ void multi_query_append_attention_c8_warp1_4_kernel(
|
||||
kv_len - q_len +
|
||||
tile_id * num_rows_per_block / GROUP_SIZE,
|
||||
chunk_start)))
|
||||
: mask_offset ? 0 : chunk_len) /
|
||||
: chunk_len) /
|
||||
(NUM_WARP_KV * num_frags_z * 16);
|
||||
|
||||
uint32_t k_smem_offset_r =
|
||||
@@ -775,23 +708,11 @@ __global__ void multi_query_append_attention_c8_warp1_4_kernel(
|
||||
commit_group();
|
||||
#pragma unroll 1
|
||||
for (uint32_t iter = 0; iter < num_iterations; ++iter) {
|
||||
if constexpr (IsDynamicC8) {
|
||||
produce_k_dynamic_scale<BLOCK_SIZE, num_frags_z, NUM_WARP_Q, T>(
|
||||
k_smem_scale,
|
||||
cache_k_scale_reg,
|
||||
block_table_now,
|
||||
cache_k_scale,
|
||||
kv_idx_base,
|
||||
kv_num_heads,
|
||||
kv_head_idx,
|
||||
chunk_end
|
||||
);
|
||||
}
|
||||
wait_group<1>();
|
||||
__syncthreads();
|
||||
|
||||
// s = qk
|
||||
compute_qk_c8<num_frags_x, num_frags_y, num_frags_z, T, CacheT, is_scale_channel_wise, IsFP8, IsDynamicC8>(
|
||||
compute_qk_c8<num_frags_x, num_frags_y, num_frags_z, T, CacheT, is_scale_channel_wise, IsFP8>(
|
||||
&qo_smem,
|
||||
&q_smem_offset_r,
|
||||
&k_smem,
|
||||
@@ -807,16 +728,12 @@ __global__ void multi_query_append_attention_c8_warp1_4_kernel(
|
||||
NUM_WARPS,
|
||||
num_frags_x,
|
||||
num_frags_y,
|
||||
num_frags_z>(attn_mask ? attn_mask + batch_id * attn_mask_len *attn_mask_len : nullptr,
|
||||
q_base_seq_id_this_block,
|
||||
num_frags_z>(q_base_seq_id_this_block,
|
||||
kv_idx_base + wid * num_frags_z * 16,
|
||||
q_len,
|
||||
kv_len,
|
||||
chunk_end,
|
||||
attn_mask_len,
|
||||
s_frag,
|
||||
mask_offset_this_seq);
|
||||
|
||||
s_frag);
|
||||
}
|
||||
|
||||
// update m,d
|
||||
@@ -824,7 +741,6 @@ __global__ void multi_query_append_attention_c8_warp1_4_kernel(
|
||||
s_frag, o_frag, m_frag, d_frag);
|
||||
__syncthreads();
|
||||
|
||||
const uint32_t ori_kv_idx_base = kv_idx_base;
|
||||
kv_idx_base += NUM_WARP_KV * num_frags_z * 16;
|
||||
produce_k_blockwise_c8<SharedMemFillMode::kNoFill,
|
||||
NUM_WARPS,
|
||||
@@ -843,18 +759,6 @@ __global__ void multi_query_append_attention_c8_warp1_4_kernel(
|
||||
chunk_end,
|
||||
const_k_offset);
|
||||
commit_group();
|
||||
if constexpr (IsDynamicC8) {
|
||||
produce_v_dynamic_scale<BLOCK_SIZE, num_frags_z, NUM_WARP_Q, T>(
|
||||
v_smem_scale,
|
||||
cache_v_scale_reg,
|
||||
block_table_now,
|
||||
cache_v_scale,
|
||||
ori_kv_idx_base,
|
||||
kv_num_heads,
|
||||
kv_head_idx,
|
||||
chunk_end
|
||||
);
|
||||
}
|
||||
wait_group<1>();
|
||||
__syncthreads();
|
||||
|
||||
@@ -865,9 +769,7 @@ __global__ void multi_query_append_attention_c8_warp1_4_kernel(
|
||||
BLOCK_SIZE,
|
||||
T,
|
||||
CacheT,
|
||||
is_scale_channel_wise,
|
||||
IsFP8,
|
||||
IsDynamicC8>(
|
||||
is_scale_channel_wise, IsFP8>(
|
||||
&v_smem, &v_smem_offset_r, s_frag, o_frag, d_frag, cache_v_scale_reg);
|
||||
__syncthreads();
|
||||
|
||||
@@ -981,8 +883,7 @@ template <typename T,
|
||||
uint32_t NUM_WARP_Q,
|
||||
typename OutT = T,
|
||||
bool ENABLE_PREFILL = true,
|
||||
bool IsFP8 = false,
|
||||
bool IsDynamicC8 = false>
|
||||
bool IsFP8=false>
|
||||
void MultiQueryAppendC8Attention(
|
||||
const AppendAttnMetaData &meta_data,
|
||||
const paddle::Tensor &qkv,
|
||||
@@ -1040,8 +941,7 @@ void MultiQueryAppendC8Attention(
|
||||
constexpr uint32_t num_frags_z = BLOCK_SIZE / 16;
|
||||
constexpr uint32_t smem_size =
|
||||
num_warps * num_frags_x * 16 * HEAD_DIM * sizeof(T) +
|
||||
num_frags_z * 16 * HEAD_DIM * sizeof(uint8_t) * 2 +
|
||||
num_frags_z * 16 * sizeof(T) * 2;
|
||||
num_frags_z * 16 * HEAD_DIM * sizeof(uint8_t) * 2;
|
||||
auto split_kv_kernel =
|
||||
multi_query_append_attention_c8_kernel<NV_TYPE,
|
||||
uint8_t,
|
||||
@@ -1058,9 +958,7 @@ void MultiQueryAppendC8Attention(
|
||||
num_frags_y,
|
||||
OUT_NV_TYPE,
|
||||
ENABLE_PREFILL,
|
||||
false,
|
||||
IsFP8,
|
||||
IsDynamicC8>;
|
||||
false, IsFP8>;
|
||||
if (is_scale_channel_wise) {
|
||||
split_kv_kernel =
|
||||
multi_query_append_attention_c8_kernel<NV_TYPE,
|
||||
@@ -1078,9 +976,7 @@ void MultiQueryAppendC8Attention(
|
||||
num_frags_y,
|
||||
OUT_NV_TYPE,
|
||||
ENABLE_PREFILL,
|
||||
true,
|
||||
IsFP8,
|
||||
IsDynamicC8>;
|
||||
true, IsFP8>;
|
||||
}
|
||||
if (smem_size >= 48 * 1024) {
|
||||
cudaFuncSetAttribute(split_kv_kernel,
|
||||
@@ -1114,9 +1010,7 @@ void MultiQueryAppendC8Attention(
|
||||
num_frags_y,
|
||||
OUT_NV_TYPE,
|
||||
ENABLE_PREFILL,
|
||||
false,
|
||||
IsFP8,
|
||||
IsDynamicC8>;
|
||||
false, IsFP8>;
|
||||
if (is_scale_channel_wise) {
|
||||
nosplit_kv_kernel =
|
||||
multi_query_append_attention_c8_kernel<NV_TYPE,
|
||||
@@ -1134,9 +1028,7 @@ void MultiQueryAppendC8Attention(
|
||||
num_frags_y,
|
||||
OUT_NV_TYPE,
|
||||
ENABLE_PREFILL,
|
||||
true,
|
||||
IsFP8,
|
||||
IsDynamicC8>;
|
||||
true, IsFP8>;
|
||||
}
|
||||
if (smem_size >= 48 * 1024) {
|
||||
cudaFuncSetAttribute(nosplit_kv_kernel,
|
||||
@@ -1162,7 +1054,6 @@ void MultiQueryAppendC8Attention(
|
||||
tile_ids_per_batch.data<int>(),
|
||||
cu_seqlens_q.data<int>(),
|
||||
block_table.data<int>(),
|
||||
meta_data.mask_offset,
|
||||
max_seq_len,
|
||||
max_dec_len,
|
||||
max_block_num_per_seq,
|
||||
@@ -1171,7 +1062,6 @@ void MultiQueryAppendC8Attention(
|
||||
quant_min_bound,
|
||||
in_scale,
|
||||
chunk_size,
|
||||
num_blocks_x_cpu,
|
||||
nullptr,
|
||||
nullptr,
|
||||
nullptr,
|
||||
@@ -1221,7 +1111,6 @@ void MultiQueryAppendC8Attention(
|
||||
tile_ids_per_batch.data<int>(),
|
||||
cu_seqlens_q.data<int>(),
|
||||
block_table.data<int>(),
|
||||
meta_data.mask_offset,
|
||||
max_seq_len,
|
||||
max_dec_len,
|
||||
max_block_num_per_seq,
|
||||
@@ -1230,7 +1119,6 @@ void MultiQueryAppendC8Attention(
|
||||
quant_min_bound,
|
||||
in_scale,
|
||||
chunk_size,
|
||||
num_blocks_x_cpu,
|
||||
reinterpret_cast<NV_TYPE *>(tmp_workspace->ptr()),
|
||||
static_cast<float *>(tmp_m->ptr()),
|
||||
static_cast<float *>(tmp_d->ptr()),
|
||||
@@ -1316,8 +1204,7 @@ void MultiQueryAppendC8Attention(
|
||||
constexpr uint32_t num_frags_z = BLOCK_SIZE / 16 / NUM_WARP_KV * 2;
|
||||
constexpr uint32_t smem_size =
|
||||
num_frags_x * 16 * HEAD_DIM * sizeof(T) +
|
||||
NUM_WARP_KV * num_frags_z * 16 * HEAD_DIM * sizeof(uint8_t) * 2 +
|
||||
NUM_WARP_KV * num_frags_z * 16 * sizeof(T) * 2;
|
||||
NUM_WARP_KV * num_frags_z * 16 * HEAD_DIM * sizeof(uint8_t) * 2;
|
||||
auto split_kv_kernel =
|
||||
multi_query_append_attention_c8_warp1_4_kernel<NV_TYPE,
|
||||
uint8_t,
|
||||
@@ -1334,9 +1221,7 @@ void MultiQueryAppendC8Attention(
|
||||
num_frags_y,
|
||||
OUT_NV_TYPE,
|
||||
ENABLE_PREFILL,
|
||||
false,
|
||||
IsFP8,
|
||||
IsDynamicC8>;
|
||||
false, IsFP8>;
|
||||
if (is_scale_channel_wise) {
|
||||
split_kv_kernel =
|
||||
multi_query_append_attention_c8_warp1_4_kernel<NV_TYPE,
|
||||
@@ -1354,9 +1239,7 @@ void MultiQueryAppendC8Attention(
|
||||
num_frags_y,
|
||||
OUT_NV_TYPE,
|
||||
ENABLE_PREFILL,
|
||||
true,
|
||||
IsFP8,
|
||||
IsDynamicC8>;
|
||||
true, IsFP8>;
|
||||
}
|
||||
if (smem_size >= 48 * 1024) {
|
||||
cudaFuncSetAttribute(split_kv_kernel,
|
||||
@@ -1371,17 +1254,10 @@ void MultiQueryAppendC8Attention(
|
||||
chunk_size = static_cast<uint32_t>(encoder_max_partition_size);
|
||||
}
|
||||
|
||||
const int num_chunks = div_up(max_seq_len, chunk_size);
|
||||
int32_t attn_mask_len;
|
||||
if (attn_mask) {
|
||||
attn_mask_len = attn_mask.get().shape()[1];
|
||||
} else {
|
||||
attn_mask_len = -1;
|
||||
}
|
||||
|
||||
const int num_chunks = div_up(max_dec_len, chunk_size);
|
||||
dim3 grids(num_blocks_x_cpu, num_chunks, kv_num_heads);
|
||||
dim3 blocks(32, num_warps);
|
||||
if (num_chunks <= 0) {
|
||||
if (num_chunks <= 1) {
|
||||
auto nosplit_kv_kernel =
|
||||
multi_query_append_attention_c8_warp1_4_kernel<NV_TYPE,
|
||||
uint8_t,
|
||||
@@ -1398,9 +1274,7 @@ void MultiQueryAppendC8Attention(
|
||||
num_frags_y,
|
||||
OUT_NV_TYPE,
|
||||
ENABLE_PREFILL,
|
||||
false,
|
||||
IsFP8,
|
||||
IsDynamicC8>;
|
||||
false, IsFP8>;
|
||||
if (is_scale_channel_wise) {
|
||||
nosplit_kv_kernel =
|
||||
multi_query_append_attention_c8_warp1_4_kernel<NV_TYPE,
|
||||
@@ -1418,9 +1292,7 @@ void MultiQueryAppendC8Attention(
|
||||
num_frags_y,
|
||||
OUT_NV_TYPE,
|
||||
ENABLE_PREFILL,
|
||||
true,
|
||||
IsFP8,
|
||||
IsDynamicC8>;
|
||||
true, IsFP8>;
|
||||
}
|
||||
if (smem_size >= 48 * 1024) {
|
||||
cudaFuncSetAttribute(nosplit_kv_kernel,
|
||||
@@ -1446,9 +1318,6 @@ void MultiQueryAppendC8Attention(
|
||||
tile_ids_per_batch.data<int>(),
|
||||
cu_seqlens_q.data<int>(),
|
||||
block_table.data<int>(),
|
||||
meta_data.mask_offset,
|
||||
attn_mask ? const_cast<bool *>(attn_mask.get().data<bool>())
|
||||
: nullptr,
|
||||
max_seq_len,
|
||||
max_dec_len,
|
||||
max_block_num_per_seq,
|
||||
@@ -1457,13 +1326,11 @@ void MultiQueryAppendC8Attention(
|
||||
quant_min_bound,
|
||||
in_scale,
|
||||
chunk_size,
|
||||
num_blocks_x_cpu,
|
||||
nullptr,
|
||||
nullptr,
|
||||
nullptr,
|
||||
reinterpret_cast<OUT_NV_TYPE *>(out->data<OutT>()),
|
||||
speculate_max_draft_token_num,
|
||||
attn_mask_len);
|
||||
speculate_max_draft_token_num);
|
||||
} else {
|
||||
phi::Allocator::AllocationPtr tmp_workspace, tmp_m, tmp_d;
|
||||
if (is_decoder) {
|
||||
@@ -1510,8 +1377,8 @@ void MultiQueryAppendC8Attention(
|
||||
reinterpret_cast<NV_TYPE *>(const_cast<T *>(cache_k_scale.data<T>())),
|
||||
reinterpret_cast<NV_TYPE *>(const_cast<T *>(cache_v_scale.data<T>())),
|
||||
shift_bias ? reinterpret_cast<NV_TYPE *>(
|
||||
const_cast<T *>(shift_bias.get().data<T>()))
|
||||
: nullptr,
|
||||
const_cast<T *>(shift_bias.get().data<T>()))
|
||||
: nullptr,
|
||||
smooth_weight ? reinterpret_cast<NV_TYPE *>(
|
||||
const_cast<T *>(smooth_weight.get().data<T>()))
|
||||
: nullptr,
|
||||
@@ -1521,9 +1388,6 @@ void MultiQueryAppendC8Attention(
|
||||
tile_ids_per_batch.data<int>(),
|
||||
cu_seqlens_q.data<int>(),
|
||||
block_table.data<int>(),
|
||||
meta_data.mask_offset,
|
||||
attn_mask ? const_cast<bool *>(attn_mask.get().data<bool>())
|
||||
: nullptr,
|
||||
max_seq_len,
|
||||
max_dec_len,
|
||||
max_block_num_per_seq,
|
||||
@@ -1532,13 +1396,11 @@ void MultiQueryAppendC8Attention(
|
||||
quant_min_bound,
|
||||
in_scale,
|
||||
chunk_size,
|
||||
num_blocks_x_cpu,
|
||||
reinterpret_cast<NV_TYPE *>(tmp_workspace->ptr()),
|
||||
static_cast<float *>(tmp_m->ptr()),
|
||||
static_cast<float *>(tmp_d->ptr()),
|
||||
reinterpret_cast<OUT_NV_TYPE *>(out->data<OutT>()),
|
||||
speculate_max_draft_token_num,
|
||||
attn_mask_len);
|
||||
speculate_max_draft_token_num);
|
||||
// merge
|
||||
constexpr int vec_size = num_elems_per_128b<NV_TYPE>();
|
||||
if (is_decoder) {
|
||||
@@ -1556,8 +1418,8 @@ void MultiQueryAppendC8Attention(
|
||||
seq_lens_encoder.data<int>(),
|
||||
cu_seqlens_q.data<int>(),
|
||||
shift_bias ? reinterpret_cast<NV_TYPE *>(
|
||||
const_cast<T *>(shift_bias.get().data<T>()))
|
||||
: nullptr,
|
||||
const_cast<T *>(shift_bias.get().data<T>()))
|
||||
: nullptr,
|
||||
smooth_weight ? reinterpret_cast<NV_TYPE *>(const_cast<T *>(
|
||||
smooth_weight.get().data<T>()))
|
||||
: nullptr,
|
||||
@@ -1574,14 +1436,14 @@ void MultiQueryAppendC8Attention(
|
||||
constexpr int blockx = HEAD_DIM / vec_size;
|
||||
constexpr int blocky = (128 + blockx - 1) / blockx;
|
||||
dim3 grids_merge(min(sm_count * 4, token_num),
|
||||
num_heads);
|
||||
num_heads);
|
||||
dim3 blocks_merge(blockx, blocky);
|
||||
merge_multi_chunks_v2_kernel<NV_TYPE,
|
||||
vec_size,
|
||||
blocky,
|
||||
HEAD_DIM,
|
||||
OUT_NV_TYPE,
|
||||
ENABLE_PREFILL>
|
||||
vec_size,
|
||||
blocky,
|
||||
HEAD_DIM,
|
||||
OUT_NV_TYPE,
|
||||
ENABLE_PREFILL>
|
||||
<<<grids_merge, blocks_merge, 0, stream>>>(
|
||||
reinterpret_cast<NV_TYPE *>(tmp_workspace->ptr()),
|
||||
static_cast<float *>(tmp_m->ptr()),
|
||||
@@ -1592,8 +1454,8 @@ void MultiQueryAppendC8Attention(
|
||||
batch_id_per_token.data<int>(),
|
||||
cu_seqlens_q.data<int>(),
|
||||
shift_bias ? reinterpret_cast<NV_TYPE *>(
|
||||
const_cast<T *>(shift_bias.get().data<T>()))
|
||||
: nullptr,
|
||||
const_cast<T *>(shift_bias.get().data<T>()))
|
||||
: nullptr,
|
||||
smooth_weight ? reinterpret_cast<NV_TYPE *>(const_cast<T *>(
|
||||
smooth_weight.get().data<T>()))
|
||||
: nullptr,
|
||||
@@ -1655,7 +1517,6 @@ void CascadeAppendAttentionC8Kernel(
|
||||
const bool causal,
|
||||
const bool is_decoder,
|
||||
const bool enable_prefill,
|
||||
const std::string& cache_quant_type_str,
|
||||
cudaStream_t& stream,
|
||||
paddle::Tensor* out) {
|
||||
const auto token_num = meta_data.token_nums;
|
||||
@@ -1664,7 +1525,6 @@ void CascadeAppendAttentionC8Kernel(
|
||||
const auto num_heads = meta_data.q_num_heads;
|
||||
const auto group_size = meta_data.q_num_heads / meta_data.kv_num_heads;
|
||||
const auto head_dim = meta_data.head_dims;
|
||||
bool is_dynamic_cfp8 = cache_quant_type_str == "block_wise_fp8";
|
||||
|
||||
DISPATCH_CAUSAL(
|
||||
causal,
|
||||
@@ -1683,46 +1543,43 @@ void CascadeAppendAttentionC8Kernel(
|
||||
BLOCK_SIZE,
|
||||
{DISPATCH_BLOCKSHAPE_Q(
|
||||
block_shape_q, BLOCK_SHAPE_Q, NUM_WARP_Q, {
|
||||
DISPATCH_DyCfp8(is_dynamic_cfp8, IsDynamicC8, {
|
||||
MultiQueryAppendC8Attention<T,
|
||||
GROUP_SIZE,
|
||||
HEAD_DIM,
|
||||
BLOCK_SIZE,
|
||||
CAUSAL,
|
||||
BLOCK_SHAPE_Q,
|
||||
NUM_WARP_Q,
|
||||
OutT,
|
||||
ENABLE_PREFILL,
|
||||
IsFP8,
|
||||
IsDynamicC8>(
|
||||
meta_data,
|
||||
qkv,
|
||||
cache_k,
|
||||
cache_v,
|
||||
attn_mask,
|
||||
cache_k_scale.get(),
|
||||
cache_v_scale.get(),
|
||||
shift_bias,
|
||||
smooth_weight,
|
||||
seq_lens_q,
|
||||
seq_lens_kv,
|
||||
seq_lens_encoder,
|
||||
batch_id_per_token,
|
||||
cu_seqlens_q,
|
||||
block_table,
|
||||
batch_ids,
|
||||
tile_ids_per_batch,
|
||||
num_blocks,
|
||||
max_seq_len,
|
||||
max_dec_len,
|
||||
quant_max_bound,
|
||||
quant_min_bound,
|
||||
in_scale,
|
||||
max_partition_size,
|
||||
encoder_max_partition_size,
|
||||
speculate_max_draft_token_num,
|
||||
is_decoder,
|
||||
stream,
|
||||
out);
|
||||
})})})})})})})
|
||||
MultiQueryAppendC8Attention<T,
|
||||
GROUP_SIZE,
|
||||
HEAD_DIM,
|
||||
BLOCK_SIZE,
|
||||
CAUSAL,
|
||||
BLOCK_SHAPE_Q,
|
||||
NUM_WARP_Q,
|
||||
OutT,
|
||||
ENABLE_PREFILL, IsFP8>(
|
||||
meta_data,
|
||||
qkv,
|
||||
cache_k,
|
||||
cache_v,
|
||||
attn_mask,
|
||||
cache_k_scale.get(),
|
||||
cache_v_scale.get(),
|
||||
shift_bias,
|
||||
smooth_weight,
|
||||
seq_lens_q,
|
||||
seq_lens_kv,
|
||||
seq_lens_encoder,
|
||||
batch_id_per_token,
|
||||
cu_seqlens_q,
|
||||
block_table,
|
||||
batch_ids,
|
||||
tile_ids_per_batch,
|
||||
num_blocks,
|
||||
max_seq_len,
|
||||
max_dec_len,
|
||||
quant_max_bound,
|
||||
quant_min_bound,
|
||||
in_scale,
|
||||
max_partition_size,
|
||||
encoder_max_partition_size,
|
||||
speculate_max_draft_token_num,
|
||||
is_decoder,
|
||||
stream,
|
||||
out);
|
||||
})})})})})})
|
||||
}
|
||||
|
@@ -384,113 +384,6 @@ __device__ __forceinline__ void produce_v_blockwise_c8(
    }
  }

template<uint32_t block_size,
         uint32_t num_frags_z,
         uint32_t NUM_WARP_Q,
         typename T>
__device__ __forceinline__ void produce_k_dynamic_scale(
    T* k_smem_scale,
    T* cache_k_reg,
    const int* block_table_now,
    const T* cache_k_scale,
    const uint32_t kv_idx,
    const uint32_t kv_num_heads,
    const uint32_t kv_head_idx,
    const uint32_t chunk_end
) {
  const uint32_t tx = threadIdx.x, ty = threadIdx.y;
  if constexpr (NUM_WARP_Q == 4) {
    // 4 warps shared block_size
    const uint32_t tid = ty * 32 + tx;
    int block_id = __ldg(&block_table_now[kv_idx / block_size]);
    if (block_id < 0) block_id = 0;
    const T* cache_k_scale_now = cache_k_scale + block_id * kv_num_heads * block_size + kv_head_idx * block_size;
    if (tid < block_size) {
      k_smem_scale[tid] = cache_k_scale_now[tid];
    }
    __syncthreads();
    const uint32_t row_id = tx / 4;
    for (uint32_t fz = 0; fz < num_frags_z; fz++) {
      cache_k_reg[fz * 2] = k_smem_scale[fz * 16 + row_id];
      cache_k_reg[fz * 2 + 1] = k_smem_scale[fz * 16 + row_id + 8];
    }
  } else {
    // 1 warp 32 tokens
    const uint32_t kv_idx_now = kv_idx + block_size * ty / 2;
    int block_id = __ldg(&block_table_now[kv_idx_now / block_size]);
    if (block_id < 0) block_id = 0;
    const T* cache_k_scale_now = cache_k_scale + block_id * kv_num_heads * block_size + kv_head_idx * block_size;
    const int kv_idx_this_thread = kv_idx + ty * 32 + tx;
    if (kv_idx_this_thread < chunk_end) {
      k_smem_scale[ty * 32 + tx] = cache_k_scale_now[(ty % 2) * 32 + tx];
    } else {
      k_smem_scale[ty * 32 + tx] = 0;
    }
    __syncwarp();
    const uint32_t row_id = tx / 4;
    for (uint32_t fz = 0; fz < num_frags_z; fz++) {
      cache_k_reg[fz * 2] = k_smem_scale[ty * 32 + fz * 16 + row_id];
      cache_k_reg[fz * 2 + 1] = k_smem_scale[ty * 32 + fz * 16 + row_id + 8];
    }
  }
}

template<uint32_t block_size,
         uint32_t num_frags_z,
         uint32_t NUM_WARP_Q,
         typename T>
__device__ __forceinline__ void produce_v_dynamic_scale(
    T* v_smem_scale,
    T* cache_v_reg,
    const int* block_table_now,
    const T* cache_v_scale,
    const uint32_t kv_idx,
    const uint32_t kv_num_heads,
    const uint32_t kv_head_idx,
    const uint32_t chunk_end
) {
  const uint32_t tx = threadIdx.x, ty = threadIdx.y;

  if constexpr (NUM_WARP_Q == 4) {
    // 4 warps shared block_size
    const uint32_t tid = ty * 32 + tx;
    int block_id = __ldg(&block_table_now[kv_idx / block_size]);
    if (block_id < 0) block_id = 0;
    const T* cache_v_scale_now = cache_v_scale + block_id * kv_num_heads * block_size + kv_head_idx * block_size;
    if (tid < block_size) {
      v_smem_scale[tid] = cache_v_scale_now[tid];
    }
    __syncthreads();
    const uint32_t row_id = tx % 4 * 2;
    for (uint32_t fz = 0; fz < num_frags_z; fz++) {
      cache_v_reg[fz * 4] = v_smem_scale[fz * 16 + row_id];
      cache_v_reg[fz * 4 + 1] = v_smem_scale[fz * 16 + row_id + 1];
      cache_v_reg[fz * 4 + 2] = v_smem_scale[fz * 16 + row_id + 8];
      cache_v_reg[fz * 4 + 3] = v_smem_scale[fz * 16 + row_id + 9];
    }
  } else {
    // 1 warp 32 tokens
    const uint32_t kv_idx_now = kv_idx + block_size * ty / 2;
    int block_id = __ldg(&block_table_now[kv_idx_now / block_size]);
    if (block_id < 0) block_id = 0;
    const T* cache_v_scale_now = cache_v_scale + block_id * kv_num_heads * block_size + kv_head_idx * block_size;
    const int kv_idx_this_thread = kv_idx + ty * 32 + tx;
    if (kv_idx_this_thread < chunk_end) {
      v_smem_scale[ty * 32 + tx] = cache_v_scale_now[(ty % 2) * 32 + tx];
    } else {
      v_smem_scale[ty * 32 + tx] = 0;
    }
    __syncwarp();
    const uint32_t row_id = tx % 4 * 2;
    for (uint32_t fz = 0; fz < num_frags_z; fz++) {
      cache_v_reg[fz * 4] = v_smem_scale[ty * 32 + fz * 16 + row_id];
      cache_v_reg[fz * 4 + 1] = v_smem_scale[ty * 32 + fz * 16 + row_id + 1];
      cache_v_reg[fz * 4 + 2] = v_smem_scale[ty * 32 + fz * 16 + row_id + 8];
      cache_v_reg[fz * 4 + 3] = v_smem_scale[ty * 32 + fz * 16 + row_id + 9];
    }
  }
}

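For reference, a minimal host-side sketch (an illustration, not code from the diff) of the per-block scale lookup that produce_k_dynamic_scale and produce_v_dynamic_scale perform above; the [max_block_num, kv_num_heads, block_size] layout is inferred from their pointer arithmetic.

// Hedged sketch: scalar view of the dynamic-scale addressing used above.
// cache_scale is assumed to be laid out as [max_block_num, kv_num_heads, block_size].
#include <cstdint>

inline const float* block_scale_ptr(const float* cache_scale,
                                    const int* block_table_now,
                                    uint32_t kv_idx,
                                    uint32_t block_size,
                                    uint32_t kv_num_heads,
                                    uint32_t kv_head_idx) {
  int block_id = block_table_now[kv_idx / block_size];
  if (block_id < 0) block_id = 0;  // padded table entries fall back to block 0, as in the kernel
  return cache_scale + block_id * kv_num_heads * block_size + kv_head_idx * block_size;
}
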
template <SharedMemFillMode fill_mode,
uint32_t num_warps,
uint32_t block_size,
@@ -923,8 +816,7 @@ template <uint32_t num_frags_x,
typename T,
typename CacheT,
bool is_scale_channel_wise = false,
bool IsFP8 = false,
bool IsDynamicC8 = false>
bool IsFP8=false>
__device__ __forceinline__ void compute_qk_c8(smem_t* q_smem,
uint32_t* q_smem_offset_r,
smem_t* k_smem,
@@ -968,27 +860,20 @@ __device__ __forceinline__ void compute_qk_c8(smem_t* q_smem,
convert_c8<T,IsFP8>(b_frag_dq_T, b_frag[fy * 2]);
convert_c8<T,IsFP8>(b_frag_dq_T + 4, b_frag[fy * 2 + 1]);
// scale zp
if constexpr (!IsDynamicC8) {
if constexpr (is_scale_channel_wise) {
const int scale_col = (ky * 2 + fy) * 4;
b_frag_dq_T[0] *= cache_k_scale[scale_col];
b_frag_dq_T[1] *= cache_k_scale[scale_col + 1];
b_frag_dq_T[2] *= cache_k_scale[scale_col + 2];
b_frag_dq_T[3] *= cache_k_scale[scale_col + 3];
b_frag_dq_T[4] *= cache_k_scale[scale_col];
b_frag_dq_T[5] *= cache_k_scale[scale_col + 1];
b_frag_dq_T[6] *= cache_k_scale[scale_col + 2];
b_frag_dq_T[7] *= cache_k_scale[scale_col + 3];
} else {
#pragma unroll
for (uint32_t b_i = 0; b_i < 8; ++b_i) {
b_frag_dq_T[b_i] *= cache_k_scale[0];
}
}
if constexpr (is_scale_channel_wise) {
const int scale_col = (ky * 2 + fy) * 4;
b_frag_dq_T[0] *= cache_k_scale[scale_col];
b_frag_dq_T[1] *= cache_k_scale[scale_col + 1];
b_frag_dq_T[2] *= cache_k_scale[scale_col + 2];
b_frag_dq_T[3] *= cache_k_scale[scale_col + 3];
b_frag_dq_T[4] *= cache_k_scale[scale_col];
b_frag_dq_T[5] *= cache_k_scale[scale_col + 1];
b_frag_dq_T[6] *= cache_k_scale[scale_col + 2];
b_frag_dq_T[7] *= cache_k_scale[scale_col + 3];
} else {
#pragma unroll
for (uint32_t b_i = 0; b_i < 8; ++b_i) {
b_frag_dq_T[b_i] *= cache_k_scale[fz * 2 + b_i / 4];
b_frag_dq_T[b_i] *= cache_k_scale[0];
}
}
#pragma unroll
@@ -1020,15 +905,12 @@ template <typename T,
uint32_t num_frags_y,
uint32_t num_frags_z,
bool IS_SYSTEM = false>
__device__ __forceinline__ void mask_s(const bool* attn_mask,
const uint32_t qo_idx_base,
__device__ __forceinline__ void mask_s(const uint32_t qo_idx_base,
const uint32_t kv_idx_base,
const uint32_t qo_len,
const uint32_t kv_len,
const uint32_t chunk_end,
const int32_t attn_mask_len,
float (*s_frag)[num_frags_z][8],
const int *mask_offset = nullptr) {
float (*s_frag)[num_frags_z][8]) {
const uint32_t tx = threadIdx.x;
#pragma unroll
for (uint32_t fx = 0; fx < num_frags_x; ++fx) {
@@ -1042,21 +924,10 @@ __device__ __forceinline__ void mask_s(const bool* attn_mask,
group_size,
kv_idx = kv_idx_base + fz * 16 + 2 * (tx % 4) +
8 * (reg_id / 4) + reg_id % 2;
bool out_of_boundary;
if (mask_offset) {
out_of_boundary = q_idx < qo_len ? (kv_idx >= mask_offset[q_idx * 2 + 1] || kv_idx < mask_offset[q_idx * 2]) : true;
} else {
out_of_boundary =
(causal
? (kv_idx > kv_len + q_idx - qo_len || (kv_idx >= chunk_end))
: kv_idx >= chunk_end);
if (attn_mask != nullptr && kv_idx > kv_len - qo_len && kv_idx < chunk_end && attn_mask_len > 0 && q_idx < static_cast<uint32_t>(attn_mask_len)) {
const int32_t mask_idx = q_idx * attn_mask_len + kv_idx - kv_len + qo_len;
bool mask = attn_mask[mask_idx];
out_of_boundary |= mask;
}
}

const bool out_of_boundary =
(causal
? (kv_idx > kv_len + q_idx - qo_len || (kv_idx >= chunk_end))
: kv_idx >= chunk_end);
if constexpr (std::is_same<T, half>::value) {
s_frag[fx][fz][reg_id] =
out_of_boundary ? -5e4f : s_frag[fx][fz][reg_id];
@@ -1064,7 +935,6 @@ __device__ __forceinline__ void mask_s(const bool* attn_mask,
s_frag[fx][fz][reg_id] =
out_of_boundary ? -3.0e+30f : s_frag[fx][fz][reg_id];
}
// printf("tid: %d. qk[%u,%u] = %f, mask: %d \n ", threadIdx.x, kv_idx, q_idx, static_cast<float>(s_frag[fx][fz][reg_id]), int(out_of_boundary));
} else {
const uint32_t q_idx = qo_idx_base,
kv_idx = kv_idx_base + fz * 16 + 2 * (tx % 4) +
@@ -1208,9 +1078,7 @@ template <uint32_t num_frags_x,
uint32_t block_size,
typename T,
typename CacheT,
bool is_scale_channel_wise = false,
bool IsFP8 = false,
bool IsDynamicC8 = false>
bool is_scale_channel_wise = false, bool IsFP8=false>
__device__ __forceinline__ void compute_sfm_v_c8(
smem_t* v_smem,
uint32_t* v_smem_offset_r,
@@ -1252,28 +1120,16 @@ __device__ __forceinline__ void compute_sfm_v_c8(
convert_c8<T,IsFP8>(b_frag_dq_T, b_frag[fz * 2]);
convert_c8<T,IsFP8>(b_frag_dq_T + 4, b_frag[fz * 2 + 1]);
// scale zp
if constexpr (!IsDynamicC8) {
if constexpr (is_scale_channel_wise) {
if constexpr (is_scale_channel_wise) {
#pragma unroll
for (uint32_t b_i = 0; b_i < 8; ++b_i) {
b_frag_dq_T[b_i] *= cache_v_scale[b_i / 4 + fy * 2];
}
} else {
#pragma unroll
for (uint32_t b_i = 0; b_i < 8; ++b_i) {
b_frag_dq_T[b_i] *= cache_v_scale[0];
}
for (uint32_t b_i = 0; b_i < 8; ++b_i) {
b_frag_dq_T[b_i] *= cache_v_scale[b_i / 4 + fy * 2];
}
} else {
const int scale_col = (kz * 2 + fz) * 4;
b_frag_dq_T[0] *= cache_v_scale[scale_col];
b_frag_dq_T[1] *= cache_v_scale[scale_col + 1];
b_frag_dq_T[2] *= cache_v_scale[scale_col + 2];
b_frag_dq_T[3] *= cache_v_scale[scale_col + 3];
b_frag_dq_T[4] *= cache_v_scale[scale_col];
b_frag_dq_T[5] *= cache_v_scale[scale_col + 1];
b_frag_dq_T[6] *= cache_v_scale[scale_col + 2];
b_frag_dq_T[7] *= cache_v_scale[scale_col + 3];
#pragma unroll
for (uint32_t b_i = 0; b_i < 8; ++b_i) {
b_frag_dq_T[b_i] *= cache_v_scale[0];
}
}
#pragma unroll
for (uint32_t fx = 0; fx < num_frags_x; ++fx) { // m: num_frags_x * 16
@@ -1300,9 +1156,7 @@ template <uint32_t num_frags_x,
uint32_t block_size,
typename T,
typename CacheT,
bool is_scale_channel_wise = false,
bool IsFP8 = false,
bool IsDynamicC8 = false>
bool is_scale_channel_wise = false, bool IsFP8=false>
__device__ __forceinline__ void compute_sfm_v_c8_iter_sq_bvec(
smem_t* v_smem,
uint32_t* v_smem_offset_r,
@@ -1346,28 +1200,16 @@ __device__ __forceinline__ void compute_sfm_v_c8_iter_sq_bvec(
convert_c8<T,IsFP8>(b_frag_dq_T, b_frag[fz * 2]);
convert_c8<T,IsFP8>(b_frag_dq_T + 4, b_frag[fz * 2 + 1]);
// scale zp
if constexpr (!IsDynamicC8) {
if constexpr (is_scale_channel_wise) {
if constexpr (is_scale_channel_wise) {
#pragma unroll
for (uint32_t b_i = 0; b_i < 8; ++b_i) {
b_frag_dq_T[b_i] *= cache_v_scale[b_i / 4 + fy * 2];
}
} else {
#pragma unroll
for (uint32_t b_i = 0; b_i < 8; ++b_i) {
b_frag_dq_T[b_i] *= cache_v_scale[0];
}
for (uint32_t b_i = 0; b_i < 8; ++b_i) {
b_frag_dq_T[b_i] *= cache_v_scale[b_i / 4 + fy * 2];
}
} else {
const int scale_col = (kz * 2 + fz) * 4;
b_frag_dq_T[0] *= cache_v_scale[scale_col];
b_frag_dq_T[1] *= cache_v_scale[scale_col + 1];
b_frag_dq_T[2] *= cache_v_scale[scale_col + 2];
b_frag_dq_T[3] *= cache_v_scale[scale_col + 3];
b_frag_dq_T[4] *= cache_v_scale[scale_col];
b_frag_dq_T[5] *= cache_v_scale[scale_col + 1];
b_frag_dq_T[6] *= cache_v_scale[scale_col + 2];
b_frag_dq_T[7] *= cache_v_scale[scale_col + 3];
#pragma unroll
for (uint32_t b_i = 0; b_i < 8; ++b_i) {
b_frag_dq_T[b_i] *= cache_v_scale[0];
}
}
#pragma unroll
for (uint32_t fx = 0; fx < num_frags_x; ++fx) { // m: num_frags_x * 16

@@ -103,7 +103,6 @@ void CascadeAppendAttentionC8Kernel(
const bool causal,
const bool is_decoder,
const bool enable_prefill,
const std::string& cache_quant_type_str,
cudaStream_t& stream,
paddle::Tensor* out);

@@ -265,10 +264,9 @@ void CascadeAppendAttentionKernel(
causal,
is_decoder,
enable_prefill,
cache_quant_type_str,
stream,
out);
} else if (cache_quant_type_str == "cache_fp8" or cache_quant_type_str == "block_wise_fp8") {
} else if (cache_quant_type_str == "cache_fp8") {
CascadeAppendAttentionC8Kernel<T, OutT, true>(meta_data,
qkv,
cache_k,
@@ -301,7 +299,6 @@ void CascadeAppendAttentionKernel(
causal,
is_decoder,
enable_prefill,
cache_quant_type_str,
stream,
out);
} else if (cache_quant_type_str == "cache_int4_zp") {

File diff suppressed because it is too large
@@ -15,73 +15,13 @@
#include "decoder_write_cache_with_rope_kernel.h"
#include "utils.cuh"

template <typename T, typename QKV_TYPE>
void append_decode_cache_rope_qk_norm(const QKV_TYPE* qkv,
T* key_cache,
T* value_cache,
T* qkv_out,
const int* block_tables,
const int* cu_seqlens_q,
const int* seq_lens,
const int* seq_lens_encoder,
const float* cos_emb,
const float* sin_emb,
const float* qkv_out_scales,
const T* qkv_biases,
const int max_seq_len,
const int max_blocks_per_seq,
const int num_heads,
const int kv_num_heads,
const int dim_head,
const int block_size,
const int bsz,
const cudaStream_t& stream,
const bool use_neox_style,
const bool rope_3d,
const float* q_norm_weight,
const float* k_norm_weight,
const float rms_norm_eps) {
const uint32_t elem_nums =
use_neox_style ? bsz * (num_heads + 2 * kv_num_heads) * dim_head / 2
: bsz * (num_heads + 2 * kv_num_heads) * dim_head;
constexpr int HEAD_DIM = 128;

constexpr int PackSize = HEAD_DIM / kWarpSize;
const int pack_num = elem_nums / PackSize;
const int blocksize = 128;
int grid_size = 1;
GetNumBlocks<128>(pack_num, &grid_size);
dim3 block_dim(kWarpSize, blocksize / kWarpSize, 1);
append_decode_cache_T_rope_qk_norm_kernel<T, PackSize>
<<<grid_size, block_dim, 0, stream>>>(reinterpret_cast<const T*>(qkv),
key_cache,
value_cache,
qkv_out,
block_tables,
cu_seqlens_q,
seq_lens,
seq_lens_encoder,
cos_emb,
sin_emb,
max_seq_len,
max_blocks_per_seq,
num_heads,
dim_head,
block_size,
elem_nums,
kv_num_heads,
rope_3d,
q_norm_weight,
k_norm_weight,
rms_norm_eps);
}

template <typename T, typename QKV_TYPE>
|
||||
void append_decode_cache_rope(const QKV_TYPE* qkv,
|
||||
T* key_cache,
|
||||
T* value_cache,
|
||||
T* qkv_out,
|
||||
const int* block_tables,
|
||||
const int* batch_id_per_token,
|
||||
const int* cu_seqlens_q,
|
||||
const int* seq_lens,
|
||||
const int* seq_lens_encoder,
|
||||
@@ -94,7 +34,6 @@ void append_decode_cache_rope(const QKV_TYPE* qkv,
|
||||
const int num_heads,
|
||||
const int kv_num_heads,
|
||||
const int dim_head,
|
||||
const int rotary_dim,
|
||||
const int block_size,
|
||||
const int bsz,
|
||||
const cudaStream_t& stream,
|
||||
@@ -118,6 +57,7 @@ void append_decode_cache_rope(const QKV_TYPE* qkv,
|
||||
value_cache,
|
||||
qkv_out,
|
||||
block_tables,
|
||||
batch_id_per_token,
|
||||
cu_seqlens_q,
|
||||
seq_lens,
|
||||
seq_lens_encoder,
|
||||
@@ -131,37 +71,15 @@ void append_decode_cache_rope(const QKV_TYPE* qkv,
|
||||
dim_head,
|
||||
block_size,
|
||||
elem_nums,
|
||||
kv_num_heads,
|
||||
rope_3d);
|
||||
kv_num_heads);
|
||||
} else {
|
||||
if (rotary_dim < dim_head){
|
||||
append_decode_cache_T_neox_partial_rope_kernel<T, PackSize>
|
||||
<<<grid_size, blocksize, 0, stream>>>(reinterpret_cast<const T*>(qkv),
|
||||
key_cache,
|
||||
value_cache,
|
||||
qkv_out,
|
||||
block_tables,
|
||||
cu_seqlens_q,
|
||||
seq_lens,
|
||||
seq_lens_encoder,
|
||||
cos_emb,
|
||||
sin_emb,
|
||||
max_seq_len,
|
||||
max_blocks_per_seq,
|
||||
num_heads,
|
||||
dim_head,
|
||||
rotary_dim,
|
||||
block_size,
|
||||
elem_nums,
|
||||
kv_num_heads,
|
||||
rope_3d);
|
||||
}else{
|
||||
append_decode_cache_T_neox_rope_kernel<T, PackSize>
|
||||
append_decode_cache_T_neox_rope_kernel<T, PackSize>
|
||||
<<<grid_size, blocksize, 0, stream>>>(reinterpret_cast<const T*>(qkv),
|
||||
key_cache,
|
||||
value_cache,
|
||||
qkv_out,
|
||||
block_tables,
|
||||
batch_id_per_token,
|
||||
cu_seqlens_q,
|
||||
seq_lens,
|
||||
seq_lens_encoder,
|
||||
@@ -173,9 +91,7 @@ void append_decode_cache_rope(const QKV_TYPE* qkv,
|
||||
dim_head,
|
||||
block_size,
|
||||
elem_nums,
|
||||
kv_num_heads,
|
||||
rope_3d);
|
||||
}
|
||||
kv_num_heads);
|
||||
}
|
||||
} else {
|
||||
if (qkv_out_scales) {
|
||||
@@ -186,6 +102,7 @@ void append_decode_cache_rope(const QKV_TYPE* qkv,
|
||||
value_cache,
|
||||
qkv_out,
|
||||
block_tables,
|
||||
batch_id_per_token,
|
||||
cu_seqlens_q,
|
||||
seq_lens,
|
||||
seq_lens_encoder,
|
||||
@@ -208,6 +125,7 @@ void append_decode_cache_rope(const QKV_TYPE* qkv,
|
||||
value_cache,
|
||||
qkv_out,
|
||||
block_tables,
|
||||
batch_id_per_token,
|
||||
cu_seqlens_q,
|
||||
seq_lens,
|
||||
seq_lens_encoder,
|
||||
@@ -231,6 +149,7 @@ void append_decode_cache_int8_rope(const QKV_TYPE* qkv,
|
||||
uint8_t* value_cache,
|
||||
T* qkv_out,
|
||||
const int* block_tables,
|
||||
const int* batch_id_per_token,
|
||||
const int* cu_seqlens_q,
|
||||
const int* seq_lens,
|
||||
const int* seq_lens_encoder,
|
||||
@@ -263,6 +182,7 @@ void append_decode_cache_int8_rope(const QKV_TYPE* qkv,
|
||||
value_cache,
|
||||
qkv_out,
|
||||
block_tables,
|
||||
batch_id_per_token,
|
||||
cu_seqlens_q,
|
||||
seq_lens,
|
||||
seq_lens_encoder,
|
||||
@@ -278,8 +198,7 @@ void append_decode_cache_int8_rope(const QKV_TYPE* qkv,
|
||||
block_size,
|
||||
127.0f,
|
||||
-127.0f,
|
||||
kv_num_heads,
|
||||
rope_3d);
|
||||
kv_num_heads);
|
||||
} else {
|
||||
append_decode_cache_int8_neox_rope_kernel<T, 4>
|
||||
<<<grids, num_warps * 32, 0, stream>>>(
|
||||
@@ -288,6 +207,7 @@ void append_decode_cache_int8_rope(const QKV_TYPE* qkv,
|
||||
value_cache,
|
||||
qkv_out,
|
||||
block_tables,
|
||||
batch_id_per_token,
|
||||
cu_seqlens_q,
|
||||
seq_lens,
|
||||
seq_lens_encoder,
|
||||
@@ -301,8 +221,7 @@ void append_decode_cache_int8_rope(const QKV_TYPE* qkv,
|
||||
block_size,
|
||||
127.0f,
|
||||
-127.0f,
|
||||
kv_num_heads,
|
||||
rope_3d);
|
||||
kv_num_heads);
|
||||
}
|
||||
} else {
|
||||
if (qkv_out_scales) {
|
||||
@@ -313,6 +232,7 @@ void append_decode_cache_int8_rope(const QKV_TYPE* qkv,
|
||||
value_cache,
|
||||
qkv_out,
|
||||
block_tables,
|
||||
batch_id_per_token,
|
||||
cu_seqlens_q,
|
||||
seq_lens,
|
||||
seq_lens_encoder,
|
||||
@@ -328,8 +248,7 @@ void append_decode_cache_int8_rope(const QKV_TYPE* qkv,
|
||||
block_size,
|
||||
127.0f,
|
||||
-127.0f,
|
||||
kv_num_heads,
|
||||
rope_3d);
|
||||
kv_num_heads);
|
||||
} else {
|
||||
append_decode_cache_int8_rope_kernel<T, 4, 0, 128, is_scale_channel_wise, IsFP8>
|
||||
<<<grids, num_warps * 32, 0, stream>>>(
|
||||
@@ -338,6 +257,7 @@ void append_decode_cache_int8_rope(const QKV_TYPE* qkv,
|
||||
value_cache,
|
||||
qkv_out,
|
||||
block_tables,
|
||||
batch_id_per_token,
|
||||
cu_seqlens_q,
|
||||
seq_lens,
|
||||
seq_lens_encoder,
|
||||
@@ -351,8 +271,7 @@ void append_decode_cache_int8_rope(const QKV_TYPE* qkv,
|
||||
block_size,
|
||||
127.0f,
|
||||
-127.0f,
|
||||
kv_num_heads,
|
||||
rope_3d);
|
||||
kv_num_heads);
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -363,6 +282,7 @@ void append_decode_cache_int4_rope(const QKV_TYPE* qkv,
|
||||
uint8_t* value_cache,
|
||||
T* qkv_out,
|
||||
const int* block_tables,
|
||||
const int* batch_id_per_token,
|
||||
const int* cu_seqlens_q,
|
||||
const int* seq_lens,
|
||||
const int* seq_lens_encoder,
|
||||
@@ -397,6 +317,7 @@ void append_decode_cache_int4_rope(const QKV_TYPE* qkv,
|
||||
value_cache,
|
||||
qkv_out,
|
||||
block_tables,
|
||||
batch_id_per_token,
|
||||
cu_seqlens_q,
|
||||
seq_lens,
|
||||
seq_lens_encoder,
|
||||
@@ -414,8 +335,7 @@ void append_decode_cache_int4_rope(const QKV_TYPE* qkv,
|
||||
block_size,
|
||||
7.0f,
|
||||
-8.0f,
|
||||
kv_num_heads,
|
||||
rope_3d);
|
||||
kv_num_heads);
|
||||
} else {
|
||||
append_decode_cache_int4_neox_rope_kernel<T, 4>
|
||||
<<<grids, num_warps * 32, 0, stream>>>(
|
||||
@@ -424,6 +344,7 @@ void append_decode_cache_int4_rope(const QKV_TYPE* qkv,
|
||||
value_cache,
|
||||
qkv_out,
|
||||
block_tables,
|
||||
batch_id_per_token,
|
||||
cu_seqlens_q,
|
||||
seq_lens,
|
||||
seq_lens_encoder,
|
||||
@@ -439,8 +360,7 @@ void append_decode_cache_int4_rope(const QKV_TYPE* qkv,
|
||||
block_size,
|
||||
7.0f,
|
||||
-8.0f,
|
||||
kv_num_heads,
|
||||
rope_3d);
|
||||
kv_num_heads);
|
||||
}
|
||||
} else {
|
||||
if (qkv_out_scales) {
|
||||
@@ -451,6 +371,7 @@ void append_decode_cache_int4_rope(const QKV_TYPE* qkv,
|
||||
value_cache,
|
||||
qkv_out,
|
||||
block_tables,
|
||||
batch_id_per_token,
|
||||
cu_seqlens_q,
|
||||
seq_lens,
|
||||
seq_lens_encoder,
|
||||
@@ -468,8 +389,7 @@ void append_decode_cache_int4_rope(const QKV_TYPE* qkv,
|
||||
block_size,
|
||||
7.0f,
|
||||
-8.0f,
|
||||
kv_num_heads,
|
||||
rope_3d);
|
||||
kv_num_heads);
|
||||
} else {
|
||||
append_decode_cache_int4_rope_kernel<T, 4>
|
||||
<<<grids, num_warps * 32, 0, stream>>>(
|
||||
@@ -478,6 +398,7 @@ void append_decode_cache_int4_rope(const QKV_TYPE* qkv,
|
||||
value_cache,
|
||||
qkv_out,
|
||||
block_tables,
|
||||
batch_id_per_token,
|
||||
cu_seqlens_q,
|
||||
seq_lens,
|
||||
seq_lens_encoder,
|
||||
@@ -493,8 +414,7 @@ void append_decode_cache_int4_rope(const QKV_TYPE* qkv,
|
||||
block_size,
|
||||
7.0f,
|
||||
-8.0f,
|
||||
kv_num_heads,
|
||||
rope_3d);
|
||||
kv_num_heads);
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -504,6 +424,7 @@ void DecoderWriteCacheWithRoPEKernel(
|
||||
const paddle::Tensor& qkv,
|
||||
const paddle::Tensor& seq_lens,
|
||||
const paddle::Tensor& seq_lens_encoder,
|
||||
const paddle::Tensor& batch_id_per_token,
|
||||
const paddle::Tensor& cu_seqlens_q,
|
||||
const paddle::Tensor& block_tables,
|
||||
const paddle::optional<paddle::Tensor>& rotary_embs,
|
||||
@@ -520,10 +441,7 @@ void DecoderWriteCacheWithRoPEKernel(
|
||||
cudaStream_t& stream,
|
||||
paddle::Tensor* qkv_out,
|
||||
paddle::Tensor* key_cache_out,
|
||||
paddle::Tensor* value_cache_out,
|
||||
const paddle::optional<paddle::Tensor>& q_norm_weight,
|
||||
const paddle::optional<paddle::Tensor>& k_norm_weight,
|
||||
const float rms_norm_eps) {
|
||||
paddle::Tensor* value_cache_out) {
|
||||
typedef cascade_attn_type_traits<T> traits_;
|
||||
typedef cascade_attn_type_traits<QKV_TYPE> qkt_nv_type_;
|
||||
typedef typename traits_::type DataType_;
|
||||
@@ -540,30 +458,85 @@ void DecoderWriteCacheWithRoPEKernel(
|
||||
const float* cos_emb =
|
||||
rotary_embs ? rotary_embs.get().data<float>() : nullptr;
|
||||
const float* sin_emb;
|
||||
int rotary_dim = dim_head;
|
||||
if (rotary_embs) {
|
||||
sin_emb =
|
||||
use_neox_rotary_style
|
||||
? rotary_embs.get().data<float>() + max_seq_len * dim_head
|
||||
: rotary_embs.get().data<float>() + max_seq_len * dim_head / 2;
|
||||
rotary_dim = rotary_embs.get().dims()[rotary_embs.get().dims().size()-1] * 2;
|
||||
if(rotary_dim < dim_head){
|
||||
if (!use_neox_rotary_style || qkv_out_scales || q_norm_weight || k_norm_weight|| cache_quant_type_str != "none"){
|
||||
PADDLE_THROW(phi::errors::Fatal(
|
||||
"partial_rotary_factor < 1.0 only supports neox_rotary_style=True, qkv_out_scales is None, q_norm_weight/k_norm_weight) is None, and cache_quant_type_str is 'none'."));
|
||||
}
|
||||
sin_emb = rotary_embs.get().data<float>() + max_seq_len * rotary_dim / 2;
|
||||
}
|
||||
}
|
||||
|
||||
if (q_norm_weight && k_norm_weight) {
|
||||
if (cache_quant_type_str == "none") {
|
||||
append_decode_cache_rope_qk_norm(
|
||||
if (cache_quant_type_str == "none") {
|
||||
append_decode_cache_rope(
|
||||
reinterpret_cast<const QKV_TYPE*>(qkv_ptr),
|
||||
reinterpret_cast<DataType_*>(key_cache_out->data<T>()),
|
||||
reinterpret_cast<DataType_*>(value_cache_out->data<T>()),
|
||||
reinterpret_cast<DataType_*>(qkv_out->data<T>()),
|
||||
block_tables.data<int>(),
|
||||
batch_id_per_token.data<int>(),
|
||||
cu_seqlens_q.data<int>(),
|
||||
seq_lens.data<int>(),
|
||||
seq_lens_encoder.data<int>(),
|
||||
cos_emb,
|
||||
sin_emb,
|
||||
qkv_out_scales ? qkv_out_scales.get().data<float>() : nullptr,
|
||||
qkv_biases ? reinterpret_cast<DataType_*>(
|
||||
const_cast<T*>(qkv_biases.get().data<T>()))
|
||||
: nullptr,
|
||||
max_seq_len,
|
||||
max_blocks_per_seq,
|
||||
num_heads,
|
||||
kv_num_heads,
|
||||
dim_head,
|
||||
block_size,
|
||||
bsz,
|
||||
stream,
|
||||
use_neox_rotary_style,
|
||||
rope_3d);
|
||||
} else if (cache_quant_type_str == "cache_int8") {
|
||||
bool is_scale_channel_wise = false;
|
||||
if (cache_k_scale && cache_k_scale.get().dims()[0] == dim_head * kv_num_heads) {
|
||||
is_scale_channel_wise = true;
|
||||
}
|
||||
if (is_scale_channel_wise) {
|
||||
append_decode_cache_int8_rope<DataType_, QKV_TYPE, true>(
|
||||
reinterpret_cast<const QKV_TYPE*>(qkv_ptr),
|
||||
key_cache_out->data<uint8_t>(),
|
||||
value_cache_out->data<uint8_t>(),
|
||||
reinterpret_cast<DataType_*>(qkv_out->data<T>()),
|
||||
block_tables.data<int>(),
|
||||
batch_id_per_token.data<int>(),
|
||||
cu_seqlens_q.data<int>(),
|
||||
seq_lens.data<int>(),
|
||||
seq_lens_encoder.data<int>(),
|
||||
cos_emb,
|
||||
sin_emb,
|
||||
qkv_out_scales ? qkv_out_scales.get().data<float>() : nullptr,
|
||||
qkv_biases ? reinterpret_cast<DataType_*>(
|
||||
const_cast<T*>(qkv_biases.get().data<T>()))
|
||||
: nullptr,
|
||||
cache_k_scale ? reinterpret_cast<DataType_*>(
|
||||
const_cast<T*>(cache_k_scale.get().data<T>()))
|
||||
: nullptr,
|
||||
cache_v_scale ? reinterpret_cast<DataType_*>(
|
||||
const_cast<T*>(cache_v_scale.get().data<T>()))
|
||||
: nullptr,
|
||||
max_seq_len,
|
||||
max_blocks_per_seq,
|
||||
num_heads,
|
||||
kv_num_heads,
|
||||
dim_head,
|
||||
block_size,
|
||||
bsz,
|
||||
stream,
|
||||
use_neox_rotary_style,
|
||||
rope_3d);
|
||||
} else {
|
||||
append_decode_cache_int8_rope<DataType_, QKV_TYPE, false>(
|
||||
reinterpret_cast<const QKV_TYPE*>(qkv_ptr),
|
||||
reinterpret_cast<DataType_*>(key_cache_out->data<T>()),
|
||||
reinterpret_cast<DataType_*>(value_cache_out->data<T>()),
|
||||
key_cache_out->data<uint8_t>(),
|
||||
value_cache_out->data<uint8_t>(),
|
||||
reinterpret_cast<DataType_*>(qkv_out->data<T>()),
|
||||
block_tables.data<int>(),
|
||||
batch_id_per_token.data<int>(),
|
||||
cu_seqlens_q.data<int>(),
|
||||
seq_lens.data<int>(),
|
||||
seq_lens_encoder.data<int>(),
|
||||
@@ -573,6 +546,12 @@ void DecoderWriteCacheWithRoPEKernel(
|
||||
qkv_biases ? reinterpret_cast<DataType_*>(
|
||||
const_cast<T*>(qkv_biases.get().data<T>()))
|
||||
: nullptr,
|
||||
cache_k_scale ? reinterpret_cast<DataType_*>(
|
||||
const_cast<T*>(cache_k_scale.get().data<T>()))
|
||||
: nullptr,
|
||||
cache_v_scale ? reinterpret_cast<DataType_*>(
|
||||
const_cast<T*>(cache_v_scale.get().data<T>()))
|
||||
: nullptr,
|
||||
max_seq_len,
|
||||
max_blocks_per_seq,
|
||||
num_heads,
|
||||
@@ -582,247 +561,84 @@ void DecoderWriteCacheWithRoPEKernel(
|
||||
bsz,
|
||||
stream,
|
||||
use_neox_rotary_style,
|
||||
rope_3d,
|
||||
q_norm_weight ? q_norm_weight.get().data<float>() : nullptr,
|
||||
k_norm_weight ? k_norm_weight.get().data<float>() : nullptr,
|
||||
rms_norm_eps);
|
||||
} else if (cache_quant_type_str == "block_wise_fp8") {
|
||||
constexpr int num_warps = 4;
|
||||
const int all_warps =
|
||||
((num_heads + 2 * kv_num_heads) + num_warps - 1) / num_warps * num_warps;
|
||||
dim3 grids(bsz, all_warps / num_warps);
|
||||
append_decode_cache_int8_rope_qk_norm_kernel<DataType_, 4, 0, 128, false, true>
|
||||
<<<grids, num_warps * 32, 0, stream>>>(
|
||||
reinterpret_cast<const DataType_*>(qkv_ptr),
|
||||
key_cache_out->data<uint8_t>(),
|
||||
value_cache_out->data<uint8_t>(),
|
||||
reinterpret_cast<DataType_*>(qkv_out->data<T>()),
|
||||
block_tables.data<int>(),
|
||||
cu_seqlens_q.data<int>(),
|
||||
seq_lens.data<int>(),
|
||||
seq_lens_encoder.data<int>(),
|
||||
cos_emb,
|
||||
sin_emb,
|
||||
const_cast<DataType_*>(reinterpret_cast<const DataType_*>(cache_k_scale.get().data<T>())),
|
||||
const_cast<DataType_*>(reinterpret_cast<const DataType_*>((cache_v_scale.get().data<T>()))),
|
||||
q_norm_weight.get().data<float>(),
|
||||
k_norm_weight.get().data<float>(),
|
||||
max_seq_len,
|
||||
max_blocks_per_seq,
|
||||
num_heads,
|
||||
block_size,
|
||||
127.0f,
|
||||
-127.0f,
|
||||
kv_num_heads,
|
||||
rope_3d,
|
||||
rms_norm_eps);
|
||||
} else {
|
||||
PD_THROW(
|
||||
"append_decode_cache_rope_qk_norm just supports cache_quant_type none/block_wise_fp8");
|
||||
rope_3d);
|
||||
}
|
||||
} else if (cache_quant_type_str == "cache_fp8") {
|
||||
append_decode_cache_int8_rope<DataType_, QKV_TYPE, false, true>(
|
||||
reinterpret_cast<const QKV_TYPE*>(qkv_ptr),
|
||||
key_cache_out->data<uint8_t>(),
|
||||
value_cache_out->data<uint8_t>(),
|
||||
reinterpret_cast<DataType_*>(qkv_out->data<T>()),
|
||||
block_tables.data<int>(),
|
||||
batch_id_per_token.data<int>(),
|
||||
cu_seqlens_q.data<int>(),
|
||||
seq_lens.data<int>(),
|
||||
seq_lens_encoder.data<int>(),
|
||||
cos_emb,
|
||||
sin_emb,
|
||||
qkv_out_scales ? qkv_out_scales.get().data<float>() : nullptr,
|
||||
qkv_biases ? reinterpret_cast<DataType_*>(
|
||||
const_cast<T*>(qkv_biases.get().data<T>()))
|
||||
: nullptr,
|
||||
cache_k_scale ? reinterpret_cast<DataType_*>(
|
||||
const_cast<T*>(cache_k_scale.get().data<T>()))
|
||||
: nullptr,
|
||||
cache_v_scale ? reinterpret_cast<DataType_*>(
|
||||
const_cast<T*>(cache_v_scale.get().data<T>()))
|
||||
: nullptr,
|
||||
max_seq_len,
|
||||
max_blocks_per_seq,
|
||||
num_heads,
|
||||
kv_num_heads,
|
||||
dim_head,
|
||||
block_size,
|
||||
bsz,
|
||||
stream,
|
||||
use_neox_rotary_style,
|
||||
rope_3d);
|
||||
} else if (cache_quant_type_str == "cache_int4_zp") {
|
||||
append_decode_cache_int4_rope(
|
||||
reinterpret_cast<const QKV_TYPE*>(qkv_ptr),
|
||||
key_cache_out->data<uint8_t>(),
|
||||
value_cache_out->data<uint8_t>(),
|
||||
reinterpret_cast<DataType_*>(const_cast<T*>(qkv_out->data<T>())),
|
||||
block_tables.data<int>(),
|
||||
batch_id_per_token.data<int>(),
|
||||
cu_seqlens_q.data<int>(),
|
||||
seq_lens.data<int>(),
|
||||
seq_lens_encoder.data<int>(),
|
||||
cos_emb,
|
||||
sin_emb,
|
||||
qkv_out_scales ? qkv_out_scales.get().data<float>() : nullptr,
|
||||
qkv_biases ? reinterpret_cast<DataType_*>(
|
||||
const_cast<T*>(qkv_biases.get().data<T>()))
|
||||
: nullptr,
|
||||
cache_k_scale ? reinterpret_cast<DataType_*>(
|
||||
const_cast<T*>(cache_k_scale.get().data<T>()))
|
||||
: nullptr,
|
||||
cache_v_scale ? reinterpret_cast<DataType_*>(
|
||||
const_cast<T*>(cache_v_scale.get().data<T>()))
|
||||
: nullptr,
|
||||
cache_k_zp ? reinterpret_cast<DataType_*>(
|
||||
const_cast<T*>(cache_k_zp.get().data<T>()))
|
||||
: nullptr,
|
||||
cache_v_zp ? reinterpret_cast<DataType_*>(
|
||||
const_cast<T*>(cache_v_zp.get().data<T>()))
|
||||
: nullptr,
|
||||
max_seq_len,
|
||||
max_blocks_per_seq,
|
||||
num_heads,
|
||||
kv_num_heads,
|
||||
dim_head,
|
||||
block_size,
|
||||
bsz,
|
||||
stream,
|
||||
use_neox_rotary_style,
|
||||
rope_3d);
|
||||
} else {
|
||||
if (cache_quant_type_str == "none") {
|
||||
append_decode_cache_rope(
|
||||
reinterpret_cast<const QKV_TYPE*>(qkv_ptr),
|
||||
reinterpret_cast<DataType_*>(key_cache_out->data<T>()),
|
||||
reinterpret_cast<DataType_*>(value_cache_out->data<T>()),
|
||||
reinterpret_cast<DataType_*>(qkv_out->data<T>()),
|
||||
block_tables.data<int>(),
|
||||
cu_seqlens_q.data<int>(),
|
||||
seq_lens.data<int>(),
|
||||
seq_lens_encoder.data<int>(),
|
||||
cos_emb,
|
||||
sin_emb,
|
||||
qkv_out_scales ? qkv_out_scales.get().data<float>() : nullptr,
|
||||
qkv_biases ? reinterpret_cast<DataType_*>(
|
||||
const_cast<T*>(qkv_biases.get().data<T>()))
|
||||
: nullptr,
|
||||
max_seq_len,
|
||||
max_blocks_per_seq,
|
||||
num_heads,
|
||||
kv_num_heads,
|
||||
dim_head,
|
||||
rotary_dim,
|
||||
block_size,
|
||||
bsz,
|
||||
stream,
|
||||
use_neox_rotary_style,
|
||||
rope_3d);
|
||||
} else if (cache_quant_type_str == "cache_int8") {
|
||||
bool is_scale_channel_wise = false;
|
||||
if (cache_k_scale && cache_k_scale.get().dims()[0] == dim_head * kv_num_heads) {
|
||||
is_scale_channel_wise = true;
|
||||
}
|
||||
if (is_scale_channel_wise) {
|
||||
append_decode_cache_int8_rope<DataType_, QKV_TYPE, true>(
|
||||
reinterpret_cast<const QKV_TYPE*>(qkv_ptr),
|
||||
key_cache_out->data<uint8_t>(),
|
||||
value_cache_out->data<uint8_t>(),
|
||||
reinterpret_cast<DataType_*>(qkv_out->data<T>()),
|
||||
block_tables.data<int>(),
|
||||
cu_seqlens_q.data<int>(),
|
||||
seq_lens.data<int>(),
|
||||
seq_lens_encoder.data<int>(),
|
||||
cos_emb,
|
||||
sin_emb,
|
||||
qkv_out_scales ? qkv_out_scales.get().data<float>() : nullptr,
|
||||
qkv_biases ? reinterpret_cast<DataType_*>(
|
||||
const_cast<T*>(qkv_biases.get().data<T>()))
|
||||
: nullptr,
|
||||
cache_k_scale ? reinterpret_cast<DataType_*>(
|
||||
const_cast<T*>(cache_k_scale.get().data<T>()))
|
||||
: nullptr,
|
||||
cache_v_scale ? reinterpret_cast<DataType_*>(
|
||||
const_cast<T*>(cache_v_scale.get().data<T>()))
|
||||
: nullptr,
|
||||
max_seq_len,
|
||||
max_blocks_per_seq,
|
||||
num_heads,
|
||||
kv_num_heads,
|
||||
dim_head,
|
||||
block_size,
|
||||
bsz,
|
||||
stream,
|
||||
use_neox_rotary_style,
|
||||
rope_3d);
|
||||
} else {
|
||||
append_decode_cache_int8_rope<DataType_, QKV_TYPE, false>(
|
||||
reinterpret_cast<const QKV_TYPE*>(qkv_ptr),
|
||||
key_cache_out->data<uint8_t>(),
|
||||
value_cache_out->data<uint8_t>(),
|
||||
reinterpret_cast<DataType_*>(qkv_out->data<T>()),
|
||||
block_tables.data<int>(),
|
||||
cu_seqlens_q.data<int>(),
|
||||
seq_lens.data<int>(),
|
||||
seq_lens_encoder.data<int>(),
|
||||
cos_emb,
|
||||
sin_emb,
|
||||
qkv_out_scales ? qkv_out_scales.get().data<float>() : nullptr,
|
||||
qkv_biases ? reinterpret_cast<DataType_*>(
|
||||
const_cast<T*>(qkv_biases.get().data<T>()))
|
||||
: nullptr,
|
||||
cache_k_scale ? reinterpret_cast<DataType_*>(
|
||||
const_cast<T*>(cache_k_scale.get().data<T>()))
|
||||
: nullptr,
|
||||
cache_v_scale ? reinterpret_cast<DataType_*>(
|
||||
const_cast<T*>(cache_v_scale.get().data<T>()))
|
||||
: nullptr,
|
||||
max_seq_len,
|
||||
max_blocks_per_seq,
|
||||
num_heads,
|
||||
kv_num_heads,
|
||||
dim_head,
|
||||
block_size,
|
||||
bsz,
|
||||
stream,
|
||||
use_neox_rotary_style,
|
||||
rope_3d);
|
||||
}
|
||||
} else if (cache_quant_type_str == "cache_fp8") {
|
||||
append_decode_cache_int8_rope<DataType_, QKV_TYPE, false, true>(
|
||||
reinterpret_cast<const QKV_TYPE*>(qkv_ptr),
|
||||
key_cache_out->data<uint8_t>(),
|
||||
value_cache_out->data<uint8_t>(),
|
||||
reinterpret_cast<DataType_*>(qkv_out->data<T>()),
|
||||
block_tables.data<int>(),
|
||||
cu_seqlens_q.data<int>(),
|
||||
seq_lens.data<int>(),
|
||||
seq_lens_encoder.data<int>(),
|
||||
cos_emb,
|
||||
sin_emb,
|
||||
qkv_out_scales ? qkv_out_scales.get().data<float>() : nullptr,
|
||||
qkv_biases ? reinterpret_cast<DataType_*>(
|
||||
const_cast<T*>(qkv_biases.get().data<T>()))
|
||||
: nullptr,
|
||||
cache_k_scale ? reinterpret_cast<DataType_*>(
|
||||
const_cast<T*>(cache_k_scale.get().data<T>()))
|
||||
: nullptr,
|
||||
cache_v_scale ? reinterpret_cast<DataType_*>(
|
||||
const_cast<T*>(cache_v_scale.get().data<T>()))
|
||||
: nullptr,
|
||||
max_seq_len,
|
||||
max_blocks_per_seq,
|
||||
num_heads,
|
||||
kv_num_heads,
|
||||
dim_head,
|
||||
block_size,
|
||||
bsz,
|
||||
stream,
|
||||
use_neox_rotary_style,
|
||||
rope_3d);
|
||||
} else if (cache_quant_type_str == "block_wise_fp8") {
|
||||
constexpr int num_warps = 4;
|
||||
const int all_warps =
|
||||
((num_heads + 2 * kv_num_heads) + num_warps - 1) / num_warps * num_warps;
|
||||
dim3 grids(bsz, all_warps / num_warps);
|
||||
append_decode_cache_int8_rope_qk_norm_kernel<DataType_, 4, 0, 128, false, true>
|
||||
<<<grids, num_warps * 32, 0, stream>>>(
|
||||
reinterpret_cast<const DataType_*>(qkv_ptr),
|
||||
key_cache_out->data<uint8_t>(),
|
||||
value_cache_out->data<uint8_t>(),
|
||||
reinterpret_cast<DataType_*>(qkv_out->data<T>()),
|
||||
block_tables.data<int>(),
|
||||
cu_seqlens_q.data<int>(),
|
||||
seq_lens.data<int>(),
|
||||
seq_lens_encoder.data<int>(),
|
||||
cos_emb,
|
||||
sin_emb,
|
||||
const_cast<DataType_*>(reinterpret_cast<const DataType_*>(cache_k_scale.get().data<T>())),
|
||||
const_cast<DataType_*>(reinterpret_cast<const DataType_*>((cache_v_scale.get().data<T>()))),
|
||||
nullptr,
|
||||
nullptr,
|
||||
max_seq_len,
|
||||
max_blocks_per_seq,
|
||||
num_heads,
|
||||
block_size,
|
||||
127.0f,
|
||||
-127.0f,
|
||||
kv_num_heads,
|
||||
rope_3d,
|
||||
rms_norm_eps);
|
||||
} else if (cache_quant_type_str == "cache_int4_zp") {
|
||||
append_decode_cache_int4_rope(
|
||||
reinterpret_cast<const QKV_TYPE*>(qkv_ptr),
|
||||
key_cache_out->data<uint8_t>(),
|
||||
value_cache_out->data<uint8_t>(),
|
||||
reinterpret_cast<DataType_*>(const_cast<T*>(qkv_out->data<T>())),
|
||||
block_tables.data<int>(),
|
||||
cu_seqlens_q.data<int>(),
|
||||
seq_lens.data<int>(),
|
||||
seq_lens_encoder.data<int>(),
|
||||
cos_emb,
|
||||
sin_emb,
|
||||
qkv_out_scales ? qkv_out_scales.get().data<float>() : nullptr,
|
||||
qkv_biases ? reinterpret_cast<DataType_*>(
|
||||
const_cast<T*>(qkv_biases.get().data<T>()))
|
||||
: nullptr,
|
||||
cache_k_scale ? reinterpret_cast<DataType_*>(
|
||||
const_cast<T*>(cache_k_scale.get().data<T>()))
|
||||
: nullptr,
|
||||
cache_v_scale ? reinterpret_cast<DataType_*>(
|
||||
const_cast<T*>(cache_v_scale.get().data<T>()))
|
||||
: nullptr,
|
||||
cache_k_zp ? reinterpret_cast<DataType_*>(
|
||||
const_cast<T*>(cache_k_zp.get().data<T>()))
|
||||
: nullptr,
|
||||
cache_v_zp ? reinterpret_cast<DataType_*>(
|
||||
const_cast<T*>(cache_v_zp.get().data<T>()))
|
||||
: nullptr,
|
||||
max_seq_len,
|
||||
max_blocks_per_seq,
|
||||
num_heads,
|
||||
kv_num_heads,
|
||||
dim_head,
|
||||
block_size,
|
||||
bsz,
|
||||
stream,
|
||||
use_neox_rotary_style,
|
||||
rope_3d);
|
||||
} else {
|
||||
PD_THROW(
|
||||
"cache_quant_type_str should be one of [none, cache_int8, cache_fp8 "
|
||||
"cache_int4_zp]");
|
||||
}
|
||||
PD_THROW(
|
||||
"cache_quant_type_str should be one of [none, cache_int8, cache_fp8 "
|
||||
"cache_int4_zp]");
|
||||
}
|
||||
}
|
||||
|
||||
@@ -834,6 +650,7 @@ template void DecoderWriteCacheWithRoPEKernel<paddle::bfloat16, int>(
|
||||
// kv_num_heads, head_dim] if GQA)
|
||||
const paddle::Tensor& seq_lens,
|
||||
const paddle::Tensor& seq_lens_encoder,
|
||||
const paddle::Tensor& batch_id_per_token,
|
||||
const paddle::Tensor& cu_seqlens_q,
|
||||
const paddle::Tensor& block_tables,
|
||||
const paddle::optional<paddle::Tensor>& rotary_embs,
|
||||
@@ -850,10 +667,7 @@ template void DecoderWriteCacheWithRoPEKernel<paddle::bfloat16, int>(
|
||||
cudaStream_t& stream,
|
||||
paddle::Tensor* qkv_out,
|
||||
paddle::Tensor* key_cache_out,
|
||||
paddle::Tensor* value_cache_out,
|
||||
const paddle::optional<paddle::Tensor>& q_norm_weight,
|
||||
const paddle::optional<paddle::Tensor>& k_norm_weight,
|
||||
const float rms_norm_eps);
|
||||
paddle::Tensor* value_cache_out);
|
||||
|
||||
template void
|
||||
DecoderWriteCacheWithRoPEKernel<paddle::bfloat16, paddle::bfloat16>(
|
||||
@@ -863,6 +677,7 @@ DecoderWriteCacheWithRoPEKernel<paddle::bfloat16, paddle::bfloat16>(
|
||||
// kv_num_heads, head_dim] if GQA)
|
||||
const paddle::Tensor& seq_lens,
|
||||
const paddle::Tensor& seq_lens_encoder,
|
||||
const paddle::Tensor& batch_id_per_token,
|
||||
const paddle::Tensor& cu_seqlens_q,
|
||||
const paddle::Tensor& block_tables,
|
||||
const paddle::optional<paddle::Tensor>& rotary_embs,
|
||||
@@ -879,10 +694,7 @@ DecoderWriteCacheWithRoPEKernel<paddle::bfloat16, paddle::bfloat16>(
|
||||
cudaStream_t& stream,
|
||||
paddle::Tensor* qkv_out,
|
||||
paddle::Tensor* key_cache_out,
|
||||
paddle::Tensor* value_cache_out,
|
||||
const paddle::optional<paddle::Tensor>& q_norm_weight,
|
||||
const paddle::optional<paddle::Tensor>& k_norm_weight,
|
||||
const float rms_norm_eps);
|
||||
paddle::Tensor* value_cache_out);
|
||||
|
||||
template void DecoderWriteCacheWithRoPEKernel<paddle::float16, int>(
|
||||
const AppendAttnMetaData& meta_data,
|
||||
@@ -891,6 +703,7 @@ template void DecoderWriteCacheWithRoPEKernel<paddle::float16, int>(
|
||||
// kv_num_heads, head_dim] if GQA)
|
||||
const paddle::Tensor& seq_lens,
|
||||
const paddle::Tensor& seq_lens_encoder,
|
||||
const paddle::Tensor& batch_id_per_token,
|
||||
const paddle::Tensor& cu_seqlens_q,
|
||||
const paddle::Tensor& block_tables,
|
||||
const paddle::optional<paddle::Tensor>& rotary_embs,
|
||||
@@ -907,10 +720,7 @@ template void DecoderWriteCacheWithRoPEKernel<paddle::float16, int>(
|
||||
cudaStream_t& stream,
|
||||
paddle::Tensor* qkv_out,
|
||||
paddle::Tensor* key_cache_out,
|
||||
paddle::Tensor* value_cache_out,
|
||||
const paddle::optional<paddle::Tensor>& q_norm_weight,
|
||||
const paddle::optional<paddle::Tensor>& k_norm_weight,
|
||||
const float rms_norm_eps);
|
||||
paddle::Tensor* value_cache_out);
|
||||
|
||||
template void DecoderWriteCacheWithRoPEKernel<paddle::float16, paddle::float16>(
|
||||
const AppendAttnMetaData& meta_data,
|
||||
@@ -919,6 +729,7 @@ template void DecoderWriteCacheWithRoPEKernel<paddle::float16, paddle::float16>(
|
||||
// kv_num_heads, head_dim] if GQA)
|
||||
const paddle::Tensor& seq_lens,
|
||||
const paddle::Tensor& seq_lens_encoder,
|
||||
const paddle::Tensor& batch_id_per_token,
|
||||
const paddle::Tensor& cu_seqlens_q,
|
||||
const paddle::Tensor& block_tables,
|
||||
const paddle::optional<paddle::Tensor>& rotary_embs,
|
||||
@@ -935,7 +746,4 @@ template void DecoderWriteCacheWithRoPEKernel<paddle::float16, paddle::float16>(
|
||||
cudaStream_t& stream,
|
||||
paddle::Tensor* qkv_out,
|
||||
paddle::Tensor* key_cache_out,
|
||||
paddle::Tensor* value_cache_out,
|
||||
const paddle::optional<paddle::Tensor>& q_norm_weight,
|
||||
const paddle::optional<paddle::Tensor>& k_norm_weight,
|
||||
const float rms_norm_eps);
|
||||
paddle::Tensor* value_cache_out);
|
||||
|
@@ -23,6 +23,7 @@ void DecoderWriteCacheWithRoPEKernel(
|
||||
// kv_num_heads, head_dim] if GQA)
|
||||
const paddle::Tensor& seq_lens,
|
||||
const paddle::Tensor& seq_lens_encoder,
|
||||
const paddle::Tensor& batch_id_per_token,
|
||||
const paddle::Tensor& cu_seqlens_q,
|
||||
const paddle::Tensor& block_tables,
|
||||
const paddle::optional<paddle::Tensor>& rotary_embs,
|
||||
@@ -39,6 +40,4 @@ void DecoderWriteCacheWithRoPEKernel(
|
||||
cudaStream_t& stream,
|
||||
paddle::Tensor* qkv_out,
|
||||
paddle::Tensor* key_cache_out,
|
||||
paddle::Tensor* value_cache_out,
|
||||
const paddle::optional<paddle::Tensor>& q_norm_weight,
|
||||
const paddle::optional<paddle::Tensor>& k_norm_weight, const float rms_norm_eps);
|
||||
paddle::Tensor* value_cache_out);
|
||||
|
File diff suppressed because it is too large
@@ -46,32 +46,38 @@ void EncoderWriteCacheWithRopeKernel(
cudaStream_t& stream,
paddle::Tensor* qkv_out,
paddle::Tensor* key_cache_out,
paddle::Tensor* value_cache_out,
const paddle::optional<paddle::Tensor>& q_norm_weight,
const paddle::optional<paddle::Tensor>& k_norm_weight,
const float rms_norm_eps) {
paddle::Tensor* value_cache_out) {
auto token_num = meta_data.token_nums;
auto num_heads = meta_data.q_num_heads;
auto kv_num_heads = meta_data.kv_num_heads;
auto head_dim = meta_data.head_dims;
bool is_scale_channel_wise = false;
int rotary_dim = head_dim;
if (cache_k_scale && cache_k_scale.get().dims()[0] == head_dim * kv_num_heads) {
is_scale_channel_wise = true;
}
if (rotary_embs){
rotary_dim = rotary_embs.get().dims()[rotary_embs.get().dims().size()-1] * 2;
if(rotary_dim < head_dim){
if (!use_neox_style || q_norm_weight || k_norm_weight || num_heads == kv_num_heads || is_scale_channel_wise){
PADDLE_THROW(phi::errors::Fatal(
"partial_rotary_factor < 1.0 only supports use_neox_rotary_style=True, q_norm_weight/k_norm_weight) is None, GQA and is_scale_channel_wise=false."));
}
}
}

if (q_norm_weight && k_norm_weight) {
|
||||
if (num_heads != kv_num_heads && !is_scale_channel_wise && !use_neox_style) {
|
||||
gqa_rotary_qk_norm_variable(
|
||||
if (num_heads == kv_num_heads) {
|
||||
rotary_qk_variable(
|
||||
qkv_out->data<T>(),
|
||||
qkv.data<QKV_TYPE>(),
|
||||
qkv_out_scales ? qkv_out_scales.get().data<float>() : nullptr,
|
||||
qkv_biases ? qkv_biases.get().data<T>() : nullptr,
|
||||
rotary_embs.get().data<float>(),
|
||||
batch_id_per_token.data<int>(),
|
||||
cu_seqlens_q.data<int>(),
|
||||
seq_lens_encoder.data<int>(),
|
||||
seq_lens_decoder.data<int>(),
|
||||
token_num,
|
||||
num_heads,
|
||||
max_seq_len,
|
||||
rotary_embs.get().dims()[2],
|
||||
head_dim,
|
||||
stream,
|
||||
use_neox_style,
|
||||
rope_3d);
|
||||
} else {
|
||||
if (!is_scale_channel_wise) {
|
||||
gqa_rotary_qk_variable(
|
||||
qkv_out->data<T>(),
|
||||
qkv.data<QKV_TYPE>(),
|
||||
qkv_out_scales ? qkv_out_scales.get().data<float>() : nullptr,
|
||||
@@ -89,81 +95,31 @@ void EncoderWriteCacheWithRopeKernel(
|
||||
head_dim,
|
||||
stream,
|
||||
use_neox_style,
|
||||
rope_3d,
|
||||
q_norm_weight ? q_norm_weight.get().data<float>() : nullptr,
|
||||
k_norm_weight ? k_norm_weight.get().data<float>() : nullptr,
|
||||
rms_norm_eps);
|
||||
rope_3d);
|
||||
} else {
|
||||
PD_THROW(
|
||||
"gqa_rotary_qk_norm_variable only support gqa mode. channel wise scale and neox style are not supported");
|
||||
gqa_rotary_qk_quant_variable(
|
||||
qkv_out->data<T>(),
|
||||
qkv.data<QKV_TYPE>(),
|
||||
qkv_out_scales ? qkv_out_scales.get().data<float>() : nullptr,
|
||||
qkv_biases ? qkv_biases.get().data<T>() : nullptr,
|
||||
cache_k_scale ? cache_k_scale.get().data<T>() : nullptr,
|
||||
cache_v_scale ? cache_v_scale.get().data<T>() : nullptr,
|
||||
rotary_embs.get().data<float>(),
|
||||
batch_id_per_token.data<int>(),
|
||||
cu_seqlens_q.data<int>(),
|
||||
seq_lens_encoder.data<int>(),
|
||||
seq_lens_decoder.data<int>(),
|
||||
token_num,
|
||||
num_heads,
|
||||
kv_num_heads,
|
||||
max_seq_len,
|
||||
rotary_embs.get().dims()[2],
|
||||
head_dim,
|
||||
stream,
|
||||
use_neox_style,
|
||||
rope_3d);
|
||||
}
|
||||
} else {
|
||||
if (num_heads == kv_num_heads) {
|
||||
rotary_qk_variable(
|
||||
qkv_out->data<T>(),
|
||||
qkv.data<QKV_TYPE>(),
|
||||
qkv_out_scales ? qkv_out_scales.get().data<float>() : nullptr,
|
||||
qkv_biases ? qkv_biases.get().data<T>() : nullptr,
|
||||
rotary_embs.get().data<float>(),
|
||||
batch_id_per_token.data<int>(),
|
||||
cu_seqlens_q.data<int>(),
|
||||
seq_lens_encoder.data<int>(),
|
||||
seq_lens_decoder.data<int>(),
|
||||
token_num,
|
||||
num_heads,
|
||||
max_seq_len,
|
||||
rotary_embs.get().dims()[2],
|
||||
head_dim,
|
||||
stream,
|
||||
use_neox_style,
|
||||
rope_3d);
|
||||
} else {
|
||||
if (!is_scale_channel_wise) {
|
||||
gqa_rotary_qk_variable(
|
||||
qkv_out->data<T>(),
|
||||
qkv.data<QKV_TYPE>(),
|
||||
qkv_out_scales ? qkv_out_scales.get().data<float>() : nullptr,
|
||||
qkv_biases ? qkv_biases.get().data<T>() : nullptr,
|
||||
rotary_embs.get().data<float>(),
|
||||
batch_id_per_token.data<int>(),
|
||||
cu_seqlens_q.data<int>(),
|
||||
seq_lens_encoder.data<int>(),
|
||||
seq_lens_decoder.data<int>(),
|
||||
token_num,
|
||||
num_heads,
|
||||
kv_num_heads,
|
||||
max_seq_len,
|
||||
rope_3d ? rotary_embs.get().dims()[3] : rotary_embs.get().dims()[2],
|
||||
head_dim,
|
||||
rotary_dim,
|
||||
stream,
|
||||
use_neox_style,
|
||||
rope_3d);
|
||||
} else {
|
||||
gqa_rotary_qk_quant_variable(
|
||||
qkv_out->data<T>(),
|
||||
qkv.data<QKV_TYPE>(),
|
||||
qkv_out_scales ? qkv_out_scales.get().data<float>() : nullptr,
|
||||
qkv_biases ? qkv_biases.get().data<T>() : nullptr,
|
||||
cache_k_scale ? cache_k_scale.get().data<T>() : nullptr,
|
||||
cache_v_scale ? cache_v_scale.get().data<T>() : nullptr,
|
||||
rotary_embs.get().data<float>(),
|
||||
batch_id_per_token.data<int>(),
|
||||
cu_seqlens_q.data<int>(),
|
||||
seq_lens_encoder.data<int>(),
|
||||
seq_lens_decoder.data<int>(),
|
||||
token_num,
|
||||
num_heads,
|
||||
kv_num_heads,
|
||||
max_seq_len,
|
||||
rotary_embs.get().dims()[2],
|
||||
head_dim,
|
||||
stream,
|
||||
use_neox_style,
|
||||
rope_3d);
|
||||
}
|
||||
|
||||
}
|
||||
}
|
||||
const uint32_t block_size = meta_data.block_size;
|
||||
if (cache_quant_type_str == "none") {
|
||||
@@ -178,7 +134,7 @@ void EncoderWriteCacheWithRopeKernel(
|
||||
stream,
|
||||
key_cache_out,
|
||||
value_cache_out);
|
||||
} else if (cache_quant_type_str == "cache_int8" or cache_quant_type_str == "cache_fp8" or cache_quant_type_str == "block_wise_fp8") {
|
||||
} else if (cache_quant_type_str == "cache_int8" or cache_quant_type_str == "cache_fp8") {
|
||||
DISPATCH_HEAD_DIM(
|
||||
head_dim, HEAD_DIM, {DISPATCH_BLOCK_SIZE(block_size, BLOCK_SIZE, {
|
||||
CascadeAppendWriteCacheKVC8QKV<T, HEAD_DIM, BLOCK_SIZE>(
|
||||
@@ -198,7 +154,7 @@ void EncoderWriteCacheWithRopeKernel(
|
||||
num_blocks,
|
||||
max_seq_len,
|
||||
is_scale_channel_wise,
|
||||
cache_quant_type_str,
|
||||
cache_quant_type_str == "cache_fp8",
|
||||
stream,
|
||||
key_cache_out,
|
||||
value_cache_out);
|
||||
|
@@ -11,11 +11,10 @@
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
#include "cute/tensor.hpp"
|
||||
|
||||
#include "helper.h"
|
||||
#include "paddle/extension.h"
|
||||
#include "paddle/phi/core/memory/memcpy.h"
|
||||
#include "utils.cuh"
|
||||
|
||||
template <int THREADBLOCK_SIZE>
|
||||
__global__ void
|
||||
@@ -117,93 +116,6 @@ void GetMaxLen(const paddle::Tensor &seq_lens_tensor,
|
||||
max_len_tensor.data<int>(), batch_size);
|
||||
}
|
||||
|
||||
template <uint32_t config_size>
|
||||
__global__ void search_chunk_size_for_mla(
|
||||
const int *__restrict__ seq_lens_q,
|
||||
const int *__restrict__ seq_lens_encoder,
|
||||
const int *__restrict__ seq_lens_decoder,
|
||||
int *__restrict__ num_blocks_x,
|
||||
int *__restrict__ res_chunk_size,
|
||||
const int bsz,
|
||||
const int set_chunk_size,
|
||||
const int block_size,
|
||||
const int sm_cout) {
|
||||
const uint32_t conf_id = threadIdx.x;
|
||||
int gridx = 0;
|
||||
if (set_chunk_size > 0 && conf_id == 0) {
|
||||
for (uint32_t bid = 0; bid < bsz; bid++) {
|
||||
int seq_len = seq_lens_q[bid];
|
||||
int seq_len_encoder = seq_lens_encoder[bid];
|
||||
int seq_len_decoder = seq_lens_decoder[bid] + seq_len;
|
||||
if (seq_len == 0 || seq_len_encoder > 0) continue;
|
||||
|
||||
int loop_times;
|
||||
loop_times = cute::ceil_div(seq_len_decoder, set_chunk_size);
|
||||
gridx += loop_times;
|
||||
}
|
||||
*num_blocks_x = gridx;
|
||||
*res_chunk_size = set_chunk_size;
|
||||
} else if (conf_id < config_size) {
|
||||
__shared__ int gridx_shared[config_size];
|
||||
// chunk_size is a multiple of 64
|
||||
const int chunk_size = block_size << conf_id;
|
||||
for (uint32_t bid = 0; bid < bsz; bid++) {
|
||||
int seq_len = seq_lens_q[bid];
|
||||
int seq_len_encoder = seq_lens_encoder[bid];
|
||||
int seq_len_decoder = seq_lens_decoder[bid] + seq_len;
|
||||
if (seq_len == 0 || seq_len_encoder > 0) continue;
|
||||
|
||||
int loop_times;
|
||||
loop_times = cute::ceil_div(seq_len_decoder, chunk_size);
|
||||
gridx += loop_times;
|
||||
}
|
||||
gridx_shared[conf_id] = gridx;
|
||||
__syncthreads();
|
||||
if (threadIdx.x == 0) {
|
||||
uint32_t res_id = 0;
|
||||
uint32_t max_last_wave_block = 0;
|
||||
for (uint32_t i = 1; i < config_size; i++) {
|
||||
uint32_t last_wave_block = gridx_shared[i] % sm_cout;
|
||||
if (last_wave_block >= max_last_wave_block) {
|
||||
res_id = i;
|
||||
max_last_wave_block = last_wave_block;
|
||||
}
|
||||
}
|
||||
*num_blocks_x = gridx_shared[res_id];
|
||||
*res_chunk_size = block_size << res_id;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
__global__ void split_block_for_mla(const int *__restrict__ seq_lens_q,
|
||||
const int *__restrict__ seq_lens_encoder,
|
||||
const int *__restrict__ seq_lens_decoder,
|
||||
int *__restrict__ batch_ids,
|
||||
int *__restrict__ tile_ids_per_batch,
|
||||
const int bsz,
|
||||
const int chunk_size) {
|
||||
if (threadIdx.x == 0) {
|
||||
int index = 0;
|
||||
for (uint32_t bid = 0; bid < bsz; bid++) {
|
||||
int seq_len = seq_lens_q[bid];
|
||||
int seq_len_encoder = seq_lens_encoder[bid];
|
||||
int seq_len_decoder = seq_lens_decoder[bid] + seq_len;
|
||||
|
||||
if (seq_len == 0) continue;
|
||||
|
||||
int loop_times;
|
||||
loop_times = cute::ceil_div(seq_len_decoder, chunk_size);
|
||||
if (seq_len_encoder > 0) {
|
||||
loop_times = 0;
|
||||
}
|
||||
for (uint32_t tile_id = 0; tile_id < loop_times; tile_id++) {
|
||||
batch_ids[index] = bid;
|
||||
tile_ids_per_batch[index++] = tile_id;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
__global__ void split_q_block(const int *__restrict__ seq_lens_q,
|
||||
const int *__restrict__ seq_lens_encoder,
|
||||
int *__restrict__ batch_ids,
|
||||
@@ -279,38 +191,26 @@ get_max_len_kv_ernel(int *max_seq_lens_out, const int *seq_lens_this_time,
|
||||
}
|
||||
}
|
||||
|
||||
void GetBlockShapeAndSplitKVBlock(
|
||||
std::vector<paddle::Tensor> GetBlockShapeAndSplitKVBlock(
|
||||
const paddle::Tensor &seq_lens_encoder,
|
||||
const paddle::Tensor &seq_lens_decoder,
|
||||
const paddle::Tensor &seq_lens_this_time,
|
||||
paddle::Tensor &decoder_batch_ids, // Inplace
|
||||
paddle::Tensor &decoder_tile_ids_per_batch, // Inplace
|
||||
paddle::Tensor &decoder_num_blocks_cpu, // Inplace, Pinned Memory
|
||||
paddle::Tensor &decoder_num_blocks_device, // Inplace
|
||||
paddle::Tensor &decoder_chunk_size_device, // Inplace
|
||||
paddle::Tensor &max_len_tensor_cpu, // Inplace, CPU
|
||||
paddle::Tensor &encoder_batch_ids, // Inplace
|
||||
paddle::Tensor &encoder_tile_ids_per_batch, // Inplace
|
||||
paddle::Tensor &encoder_num_blocks_x_cpu, // Inplace, CPU
|
||||
paddle::Tensor &kv_batch_ids, // Inplace
|
||||
paddle::Tensor &kv_tile_ids_per_batch, // Inplace
|
||||
paddle::Tensor &kv_num_blocks_x_cpu, // Inplace, CPU
|
||||
paddle::Tensor &max_len_kv_cpu, // Inplace, CPU
|
||||
const int encoder_block_shape_q,
|
||||
const int decoder_block_shape_q,
|
||||
const int group_size,
|
||||
const int block_size,
|
||||
const int decoder_step_token_num)
|
||||
{
|
||||
const int encoder_block_shape_q, const int decoder_block_shape_q,
|
||||
const int group_size, const int block_size,
|
||||
const int decoder_step_token_num) {
|
||||
auto stream = seq_lens_encoder.stream();
|
||||
int bsz = seq_lens_this_time.shape()[0];
|
||||
|
||||
paddle::Tensor max_len_tensor_gpu = GetEmptyTensor({max_len_tensor_cpu.shape()[0]}, paddle::DataType::INT32, seq_lens_this_time.place());
|
||||
auto max_len_tensor =
|
||||
GetEmptyTensor({8}, paddle::DataType::INT32, seq_lens_encoder.place());
|
||||
GetMaxLen(seq_lens_decoder, seq_lens_this_time, seq_lens_encoder,
|
||||
max_len_tensor_gpu, bsz);
|
||||
max_len_tensor_cpu.copy_(max_len_tensor_gpu, max_len_tensor_cpu.place(), false);
|
||||
max_len_tensor, bsz);
|
||||
|
||||
auto max_len_cpu_ptr = max_len_tensor_cpu.data<int>();
|
||||
// max_len_this_time, max_enc_len_this_time, max_dec_len_this_time,
|
||||
// max_enc_dec_len_this_time, max_just_dec_len_this_time,
|
||||
// max_just_dec_merged_len_this_time, max_system_len,
|
||||
// max_just_dec_len_without_system
|
||||
auto max_len_cpu = max_len_tensor.copy_to(paddle::CPUPlace(), false);
|
||||
auto max_len_cpu_ptr = max_len_cpu.data<int>();
|
||||
int max_len_this_time = max_len_cpu_ptr[0];
|
||||
int max_enc_len_this_time = max_len_cpu_ptr[1];
|
||||
int max_dec_len_this_time = max_len_cpu_ptr[2];
|
||||
@@ -320,120 +220,34 @@ void GetBlockShapeAndSplitKVBlock(
|
||||
int max_system_len = max_len_cpu_ptr[6];
|
||||
int max_just_dec_len_without_system = max_len_cpu_ptr[7];
|
||||
|
||||
paddle::Tensor encoder_batch_ids;
|
||||
paddle::Tensor encoder_tile_ids_per_batch;
|
||||
paddle::Tensor encoder_num_blocks_x_cpu; /*cpu*/
|
||||
paddle::Tensor kv_batch_ids;
|
||||
paddle::Tensor kv_tile_ids_per_batch;
|
||||
paddle::Tensor kv_num_blocks_x_cpu; /*cpu*/
|
||||
paddle::Tensor decoder_batch_ids;
|
||||
paddle::Tensor decoder_tile_ids_per_batch;
|
||||
paddle::Tensor decoder_num_blocks_x_cpu; /*cpu*/
|
||||
paddle::Tensor max_len_kv_cpu; /*cpu*/
|
||||
|
||||
auto max_len_kv =
|
||||
GetEmptyTensor({1}, paddle::DataType::INT32, seq_lens_decoder.place());
|
||||
get_max_len_kv_ernel<128><<<1, 128, 0, stream>>>(
|
||||
max_len_kv.data<int>(), seq_lens_this_time.data<int>(),
|
||||
seq_lens_decoder.data<int>(), bsz);
|
||||
|
||||
max_len_kv_cpu = max_len_kv.copy_to(paddle::CPUPlace(), false);
|
||||
|
||||
max_len_kv_cpu.copy_(max_len_kv, max_len_kv_cpu.place(), false);
|
||||
|
||||
// decoder
|
||||
if (max_dec_len_this_time > 0) {
|
||||
const bool mla_use_tensorcore = GetMlaUseTensorcore();
|
||||
if (mla_use_tensorcore && group_size <= 64) {
|
||||
const int set_chunk_size = get_mla_dec_chunk_size(bsz);
|
||||
|
||||
PADDLE_ENFORCE_GPU_SUCCESS(cudaMemsetAsync(
|
||||
decoder_chunk_size_device.data<int>(), 64, sizeof(int32_t), stream));
|
||||
|
||||
PADDLE_ENFORCE_GPU_SUCCESS(cudaMemsetAsync(
|
||||
decoder_num_blocks_device.data<int>(), 0, sizeof(int32_t), stream));
|
||||
|
||||
int device;
|
||||
cudaGetDevice(&device);
|
||||
int sm_cout;
|
||||
cudaDeviceGetAttribute(&sm_cout, cudaDevAttrMultiProcessorCount, device);
|
||||
constexpr int config_size =
|
||||
12; // search space for chunk size:[64, 128, 256, ... 131072]
|
||||
|
||||
search_chunk_size_for_mla<config_size>
|
||||
<<<1, 32, 0, stream>>>(seq_lens_this_time.data<int>(),
|
||||
seq_lens_encoder.data<int>(),
|
||||
seq_lens_decoder.data<int>(),
|
||||
decoder_num_blocks_device.data<int>(),
|
||||
decoder_chunk_size_device.data<int>(),
|
||||
bsz,
|
||||
set_chunk_size,
|
||||
block_size,
|
||||
sm_cout);
|
||||
|
||||
decoder_num_blocks_cpu.copy_(
|
||||
decoder_num_blocks_device, decoder_num_blocks_cpu.place(), false);
|
||||
auto decoder_chunk_size_cpu =
|
||||
decoder_chunk_size_device.copy_to(paddle::CPUPlace(), false);
|
||||
const int chunk_size = decoder_chunk_size_cpu.data<int>()[0];
|
||||
|
||||
// NOTE: (changwenbin) When using auto_chunk,
|
||||
// decode_max_tile_size must take into account the maximum case, where * 1024 can cover 128K.
|
||||
// const uint32_t decoder_batch_shape = seq_lens_decoder.dims()[0] * 1024;
|
||||
|
||||
const uint32_t decoder_max_tile_size_per_bs_q =
|
||||
div_up((decoder_step_token_num * group_size), decoder_block_shape_q);
|
||||
const uint32_t decoder_batch_shape =
|
||||
bsz * 1024 * decoder_max_tile_size_per_bs_q;
|
||||
|
||||
PADDLE_ENFORCE_GPU_SUCCESS(
|
||||
cudaMemsetAsync(decoder_batch_ids.data<int>(),
|
||||
0,
|
||||
decoder_batch_shape * sizeof(int32_t),
|
||||
stream));
|
||||
PADDLE_ENFORCE_GPU_SUCCESS(
|
||||
cudaMemsetAsync(decoder_tile_ids_per_batch.data<int>(),
|
||||
0,
|
||||
decoder_batch_shape * sizeof(int32_t),
|
||||
stream));
|
||||
|
||||
|
||||
split_block_for_mla<<<1, 32, 0, stream>>>(
|
||||
seq_lens_this_time.data<int>(),
|
||||
seq_lens_encoder.data<int>(),
|
||||
seq_lens_decoder.data<int>(),
|
||||
decoder_batch_ids.data<int>(),
|
||||
decoder_tile_ids_per_batch.data<int>(),
|
||||
bsz,
|
||||
chunk_size);
|
||||
|
||||
} else {
|
||||
// Note:(changwenbin)In order to adapt to cudagraph, the maximum value should be taken here
|
||||
const uint32_t decoder_max_tile_size_per_bs_q = div_up((decoder_step_token_num * group_size), decoder_block_shape_q);
|
||||
const uint32_t decoder_batch_shape = bsz * 1024 * decoder_max_tile_size_per_bs_q;
|
||||
|
||||
PADDLE_ENFORCE_GPU_SUCCESS(cudaMemsetAsync(decoder_batch_ids.data<int>(), 0, decoder_batch_shape * sizeof(int32_t), stream));
|
||||
PADDLE_ENFORCE_GPU_SUCCESS(cudaMemsetAsync(decoder_tile_ids_per_batch.data<int>(), 0, decoder_batch_shape * sizeof(int32_t), stream));
|
||||
PADDLE_ENFORCE_GPU_SUCCESS(cudaMemsetAsync(decoder_num_blocks_device.data<int>(), 0, sizeof(int32_t), stream));
|
||||
|
||||
split_q_block<<<1, 32, 0, stream>>>(
|
||||
seq_lens_this_time.data<int>(),
|
||||
seq_lens_encoder.data<int>(),
|
||||
decoder_batch_ids.data<int>(),
|
||||
decoder_tile_ids_per_batch.data<int>(),
|
||||
decoder_num_blocks_device.data<int>(),
|
||||
bsz,
|
||||
decoder_block_shape_q,
|
||||
group_size);
|
||||
|
||||
decoder_num_blocks_cpu.copy_(
|
||||
decoder_num_blocks_device, decoder_num_blocks_cpu.place(), false);
|
||||
PADDLE_ENFORCE_GPU_SUCCESS(cudaMemsetAsync(
|
||||
decoder_chunk_size_device.data<int>(), 64, sizeof(int32_t), stream));
|
||||
}
|
||||
} else {
|
||||
PADDLE_ENFORCE_GPU_SUCCESS(cudaMemsetAsync(
|
||||
decoder_chunk_size_device.data<int>(), 64, sizeof(int32_t), stream));
|
||||
PADDLE_ENFORCE_GPU_SUCCESS(cudaMemsetAsync(
|
||||
decoder_num_blocks_device.data<int>(), 0, sizeof(int32_t), stream));
|
||||
decoder_num_blocks_cpu.copy_(
|
||||
decoder_num_blocks_device, decoder_num_blocks_cpu.place(), false);
|
||||
}
|
||||
|
||||
// encoder
if (max_enc_len_this_time > 0) {
  const uint32_t max_tile_size_per_bs_kv = div_up(max_enc_dec_len_this_time, block_size);
  const uint32_t kv_batch_shape = bsz * max_tile_size_per_bs_kv;
  PADDLE_ENFORCE_GPU_SUCCESS(cudaMemsetAsync(kv_batch_ids.data<int>(), 0, kv_batch_shape * sizeof(int32_t), stream));
  PADDLE_ENFORCE_GPU_SUCCESS(cudaMemsetAsync(kv_tile_ids_per_batch.data<int>(), 0, kv_batch_shape * sizeof(int32_t), stream));
  const uint32_t max_tile_size_per_bs_kv =
      div_up(max_enc_dec_len_this_time, block_size);
  kv_batch_ids =
      GetEmptyTensor({bsz * max_tile_size_per_bs_kv}, paddle::DataType::INT32,
                     seq_lens_encoder.place());
  kv_tile_ids_per_batch =
      GetEmptyTensor({bsz * max_tile_size_per_bs_kv}, paddle::DataType::INT32,
                     seq_lens_encoder.place());
  auto kv_num_blocks_x =
      GetEmptyTensor({1}, paddle::DataType::INT32, seq_lens_encoder.place());

@@ -444,12 +258,16 @@ void GetBlockShapeAndSplitKVBlock(
      kv_tile_ids_per_batch.data<int>(), kv_num_blocks_x.data<int>(), bsz,
      block_size, block_size);

  kv_num_blocks_x_cpu.copy_(kv_num_blocks_x, kv_num_blocks_x_cpu.place(), false);
  // Clear buffer
  const uint32_t encoder_max_tile_size_per_bs_q = div_up((max_enc_dec_len_this_time * group_size), encoder_block_shape_q);
  const uint32_t encoder_batch_shape = bsz * encoder_max_tile_size_per_bs_q;
  PADDLE_ENFORCE_GPU_SUCCESS(cudaMemsetAsync(encoder_batch_ids.data<int>(), 0, encoder_batch_shape * sizeof(int32_t), stream));
  PADDLE_ENFORCE_GPU_SUCCESS(cudaMemsetAsync(encoder_tile_ids_per_batch.data<int>(), 0, encoder_batch_shape * sizeof(int32_t), stream));
  kv_num_blocks_x_cpu = kv_num_blocks_x.copy_to(paddle::CPUPlace(), false);

  const uint32_t encoder_max_tile_size_per_bs_q =
      div_up((max_enc_dec_len_this_time * group_size), encoder_block_shape_q);
  encoder_batch_ids =
      GetEmptyTensor({bsz * encoder_max_tile_size_per_bs_q},
                     paddle::DataType::INT32, seq_lens_encoder.place());
  encoder_tile_ids_per_batch =
      GetEmptyTensor({bsz * encoder_max_tile_size_per_bs_q},
                     paddle::DataType::INT32, seq_lens_encoder.place());
  auto encoder_num_blocks_x =
      GetEmptyTensor({1}, paddle::DataType::INT32, seq_lens_encoder.place());
  split_q_block<<<1, 32, 0, stream>>>(seq_lens_encoder.data<int>(), nullptr,
@@ -457,38 +275,108 @@ void GetBlockShapeAndSplitKVBlock(
                                      encoder_tile_ids_per_batch.data<int>(),
                                      encoder_num_blocks_x.data<int>(), bsz,
                                      encoder_block_shape_q, group_size);
  encoder_num_blocks_x_cpu.copy_(encoder_num_blocks_x, encoder_num_blocks_x_cpu.place(), false);
  encoder_num_blocks_x_cpu =
      encoder_num_blocks_x.copy_to(paddle::CPUPlace(), false);
} else {
  encoder_batch_ids =
      GetEmptyTensor({0}, paddle::DataType::INT32, seq_lens_encoder.place());
  encoder_tile_ids_per_batch =
      GetEmptyTensor({0}, paddle::DataType::INT32, seq_lens_encoder.place());
  encoder_num_blocks_x_cpu =
      GetEmptyTensor({0}, paddle::DataType::INT32, paddle::CPUPlace());
  kv_batch_ids =
      GetEmptyTensor({0}, paddle::DataType::INT32, seq_lens_encoder.place());
  kv_tile_ids_per_batch =
      GetEmptyTensor({0}, paddle::DataType::INT32, seq_lens_encoder.place());
  kv_num_blocks_x_cpu =
      GetEmptyTensor({0}, paddle::DataType::INT32, seq_lens_encoder.place());
}
if (max_just_dec_len_this_time > 0) {
  const uint32_t decoder_max_tile_size_per_bs_q =
      div_up((decoder_step_token_num * group_size), decoder_block_shape_q);

  decoder_batch_ids =
      GetEmptyTensor({bsz * decoder_max_tile_size_per_bs_q},
                     paddle::DataType::INT32, seq_lens_encoder.place());
  decoder_tile_ids_per_batch =
      GetEmptyTensor({bsz * decoder_max_tile_size_per_bs_q},
                     paddle::DataType::INT32, seq_lens_encoder.place());
  auto decoder_num_blocks_x =
      GetEmptyTensor({1}, paddle::DataType::INT32, seq_lens_encoder.place());
  split_q_block<<<1, 32, 0, stream>>>(
      seq_lens_this_time.data<int>(), seq_lens_encoder.data<int>(),
      decoder_batch_ids.data<int>(), decoder_tile_ids_per_batch.data<int>(),
      decoder_num_blocks_x.data<int>(), bsz, decoder_block_shape_q,
      group_size);
  decoder_num_blocks_x_cpu =
      decoder_num_blocks_x.copy_to(paddle::CPUPlace(), false);
} else {
  decoder_batch_ids =
      GetEmptyTensor({0}, paddle::DataType::INT32, seq_lens_encoder.place());
  decoder_tile_ids_per_batch =
      GetEmptyTensor({0}, paddle::DataType::INT32, seq_lens_encoder.place());
  decoder_num_blocks_x_cpu =
      GetEmptyTensor({0}, paddle::DataType::INT32, paddle::CPUPlace());
}

return {encoder_batch_ids,
        encoder_tile_ids_per_batch,
        encoder_num_blocks_x_cpu, /*cpu*/
        kv_batch_ids,
        kv_tile_ids_per_batch,
        kv_num_blocks_x_cpu, /*cpu*/
        decoder_batch_ids,
        decoder_tile_ids_per_batch,
        decoder_num_blocks_x_cpu, /*cpu*/
        max_len_kv_cpu /*cpu*/,
        max_len_cpu};
}
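The buffer sizing in the function above is plain integer tiling arithmetic. A minimal standalone sketch of it follows, with made-up values for bsz, group_size, decoder_step_token_num and decoder_block_shape_q (they are illustrative, not taken from any real FastDeploy config), showing how decoder_batch_shape is derived for the auto-chunk (MLA) path:

#include <cstdint>
#include <iostream>

// Same rounding-up division used by the tiling code above.
static uint32_t div_up(uint32_t a, uint32_t b) { return (a + b - 1) / b; }

int main() {
  // Illustrative values only; the real ones come from the attention config.
  const uint32_t bsz = 8;
  const uint32_t group_size = 8;              // q heads per kv head
  const uint32_t decoder_step_token_num = 4;  // tokens appended per decode step
  const uint32_t decoder_block_shape_q = 16;

  // Tiles needed per batch entry along the q dimension.
  const uint32_t tiles_per_bs_q =
      div_up(decoder_step_token_num * group_size, decoder_block_shape_q);

  // Worst-case element count for decoder_batch_ids / decoder_tile_ids_per_batch,
  // matching the "* 1024" head-room the chunked path reserves above.
  const uint32_t decoder_batch_shape = bsz * 1024 * tiles_per_bs_q;

  std::cout << "tiles_per_bs_q=" << tiles_per_bs_q
            << " decoder_batch_shape=" << decoder_batch_shape << "\n";
  return 0;
}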

std::vector<paddle::DataType> GetBlockShapeAndSplitKVBlockInferDtype(
    const paddle::DataType &seq_lens_encoder_dtype,
    const paddle::DataType &seq_lens_decoder_dtype,
    const paddle::DataType &seq_lens_this_time_dtype) {
  return {
      paddle::DataType::INT32, paddle::DataType::INT32, paddle::DataType::INT32,
      paddle::DataType::INT32, paddle::DataType::INT32, paddle::DataType::INT32,
      paddle::DataType::INT32, paddle::DataType::INT32, paddle::DataType::INT32,
      paddle::DataType::INT32, paddle::DataType::INT32};
}

std::vector<std::vector<int64_t>> GetBlockShapeAndSplitKVBlockInferShape(
    const std::vector<int64_t> &seq_lens_encoder_shape,
    const std::vector<int64_t> &seq_lens_decoder_shape,
    const std::vector<int64_t> &seq_lens_this_time_shape) {
  std::vector<int64_t> dynamic_shape = {-1};

  return {dynamic_shape,
          dynamic_shape,
          {1},
          dynamic_shape,
          dynamic_shape,
          {1},
          dynamic_shape,
          dynamic_shape,
          {1},
          {1},
          {8}};
}

PD_BUILD_STATIC_OP(get_block_shape_and_split_kv_block)
    .Inputs({
        "seq_lens_encoder",
        "seq_lens_decoder",
        "seq_lens_this_time",
        "decoder_batch_ids",
        "decoder_tile_ids_per_batch",
        "decoder_num_blocks_cpu",
        "decoder_num_blocks_device",
        "decoder_chunk_size_device",
        "max_len_tensor_cpu",
        "encoder_batch_ids",
        "encoder_tile_ids_per_batch",
        "encoder_num_blocks_x_cpu",
        "kv_batch_ids",
        "kv_tile_ids_per_batch",
        "kv_num_blocks_x_cpu",
        "max_len_kv_cpu"
    })
    .Outputs({
    })
    .Attrs({
        "encoder_block_shape_q: int",
        "decoder_block_shape_q: int",
        "group_size: int",
        "block_size: int",
        "decoder_step_token_num: int"
    })
    .SetKernelFn(PD_KERNEL(GetBlockShapeAndSplitKVBlock));
    .Inputs({"seq_lens_encoder", "seq_lens_decoder", "seq_lens_this_time"})
    .Outputs({paddle::Optional("encoder_batch_ids"),
              paddle::Optional("encoder_tile_ids_per_batch"),
              paddle::Optional("encoder_num_blocks"),
              paddle::Optional("kv_batch_ids"),
              paddle::Optional("kv_tile_ids_per_batch"),
              paddle::Optional("kv_num_blocks"),
              paddle::Optional("decoder_batch_ids"),
              paddle::Optional("decoder_tile_ids_per_batch"),
              paddle::Optional("decoder_num_blocks"),
              paddle::Optional("max_len_kv"), "set_max_lengths"})
    .Attrs({"encoder_block_shape_q: int", "decoder_block_shape_q: int",
            "group_size: int", "block_size: int",
            "decoder_step_token_num: int"})
    .SetKernelFn(PD_KERNEL(GetBlockShapeAndSplitKVBlock))
    .SetInferShapeFn(PD_INFER_SHAPE(GetBlockShapeAndSplitKVBlockInferShape))
    .SetInferDtypeFn(PD_INFER_DTYPE(GetBlockShapeAndSplitKVBlockInferDtype));

@@ -37,8 +37,7 @@ __global__ void GQAVariableLengthRotarySplitKernel(
    const int q_num_head,
    const int kv_num_head,
    const int seq_len,
    const int last_dim,
    const bool rope_3d) {
    const int last_dim) {
  using LoadT = AlignedVector<T, VecSize>;
  constexpr int HalfVecSize = VecSize / 2;
  using LoadEmbT = AlignedVector<float, HalfVecSize>;
@@ -63,7 +62,6 @@ __global__ void GQAVariableLengthRotarySplitKernel(
    const int kv_write_idx = cu_seqlens_k[ori_bi] + ori_seq_id;

    const int64_t emb_idx = ori_seq_id * half_lastdim + h_bias / 2;
    int64_t new_emb_idx = rope_3d ? emb_idx + ori_bi * last_dim * seq_len : emb_idx;
    const int64_t base_idx =
        token_idx * (q_num_head + 2 * kv_num_head) * last_dim + hi * last_dim +
        h_bias;
@@ -82,8 +80,8 @@ __global__ void GQAVariableLengthRotarySplitKernel(
    Load<T, VecSize>(&qkv[base_idx], &src_vec);
    // do rope
    if (hi < q_num_head + kv_num_head) {
      Load<float, HalfVecSize>(&cos_emb[new_emb_idx], &cos_emb_vec);
      Load<float, HalfVecSize>(&sin_emb[new_emb_idx], &sin_emb_vec);
      Load<float, HalfVecSize>(&cos_emb[emb_idx], &cos_emb_vec);
      Load<float, HalfVecSize>(&sin_emb[emb_idx], &sin_emb_vec);
#pragma unroll
      for (int i = 0; i < HalfVecSize; i++) {
        const float input_left = static_cast<float>(src_vec[2 * i]);
@@ -120,7 +118,6 @@ void gqa_rotary_qk_split_variable(
    const int seq_len,
    const int input_output_len,
    const int dim_head,
    const bool rope_3d,
    const cudaStream_t &stream) {
  int64_t elem_nums = token_num * (num_heads + 2 * kv_num_heads) * dim_head;
  constexpr int PackSize = 16 / sizeof(T);
@@ -149,8 +146,7 @@ void gqa_rotary_qk_split_variable(
      num_heads,
      kv_num_heads,
      seq_len,
      dim_head,
      rope_3d);
      dim_head);
}
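For reference, the rotation in the kernel diff above is applied to interleaved (even, odd) element pairs of each head. Below is a minimal host-side sketch of that per-pair math; the angles and inputs are made up, and the indexing into cos_emb/sin_emb (including the rope_3d per-batch offset) is deliberately left out:

#include <cmath>
#include <cstdio>

// Pairwise rotation the kernel performs per head element:
// (left, right) -> (left*cos - right*sin, right*cos + left*sin).
void rope_pair(float& left, float& right, float cos_v, float sin_v) {
  const float l = left, r = right;
  left  = l * cos_v - r * sin_v;
  right = r * cos_v + l * sin_v;
}

int main() {
  float q[4] = {1.0f, 0.0f, 0.5f, -0.5f};             // two (left, right) pairs
  const float theta0 = 0.1f, theta1 = 0.01f;           // illustrative angles
  rope_pair(q[0], q[1], std::cos(theta0), std::sin(theta0));
  rope_pair(q[2], q[3], std::cos(theta1), std::sin(theta1));
  for (float v : q) std::printf("%f ", v);
  std::printf("\n");
  return 0;
}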
template <typename T,
|
||||
@@ -217,7 +213,7 @@ __global__ void append_cache_kv_c16(
|
||||
|
||||
// load k_smem 64 rows 128 cols
|
||||
for (int fz = 0; fz < 4; fz++) { // 4 rows pre warp once, 16 rows all 4 warps once, need 4 iter
|
||||
for (int fy = 0; fy < 2; fy++) { // 8 * 128b = 64 * bf16 once, need 2 iter
|
||||
for (int fy = 0; fy < 2; fy++) { // 8 * 128b = 64 * bf16 noce, need 2 iter
|
||||
k_smem.load_128b_async<SharedMemFillMode::kNoFill>(
|
||||
k_smem_offset_w, cur_cache_k + k_read_idx, end_idx > 0);
|
||||
k_smem_offset_w =
|
||||
@@ -235,7 +231,7 @@ __global__ void append_cache_kv_c16(
|
||||
// deal k_smem 64 rows 128 cols
|
||||
for (int fz = 0; fz < 1; fz++) { // 16 rows pre warp once, 64 rows all 4 warps once, need 1 iter
|
||||
uint32_t row_idx = wid * 16 + tid / 4;
|
||||
for (int fy = 0; fy < 8; fy++) { // 2 * 128b = 16 * bf16 once, need 8 iter
|
||||
for (int fy = 0; fy < 8; fy++) { // 2 * 128b = 16 * bf16 noce, need 8 iter
|
||||
uint32_t col_idx = fy * 16 + tid % 4 * 2;
|
||||
k_smem.ldmatrix_m8n8x4(k_smem_offset_r, kv_frag);
|
||||
// layout
|
||||
@@ -278,7 +274,7 @@ __global__ void append_cache_kv_c16(
|
||||
|
||||
// load v_smem 64 rows 128 cols
|
||||
for (int fz = 0; fz < 4; fz++) { // // 4 rows pre warp once, 16 rows all 4 warps once, need 4 iter
|
||||
for (int fy = 0; fy < 2; fy++) { // 8 * 128b = 64 * bf16 once, need 2 iter
|
||||
for (int fy = 0; fy < 2; fy++) { // 8 * 128b = 64 * bf16 noce, need 2 iter
|
||||
v_smem.load_128b_async<SharedMemFillMode::kNoFill>(
|
||||
v_smem_offset_w, cur_cache_v + v_read_idx, end_idx > 0);
|
||||
v_smem_offset_w =
|
||||
@@ -296,7 +292,7 @@ __global__ void append_cache_kv_c16(
|
||||
// deal v_smem 64 rows 128 cols
|
||||
for (int fz = 0; fz < 1; fz++) { // 16 rows pre warp once, 64 rows all 4 warps once, need 1 iter
|
||||
uint32_t row_idx = wid * 16 + tid / 4;
|
||||
for (int fy = 0; fy < 8; fy++) { // 2 * 128b = 16 * bf16 once, need 8 iter
|
||||
for (int fy = 0; fy < 8; fy++) { // 2 * 128b = 16 * bf16 noce, need 8 iter
|
||||
uint32_t col_idx = fy * 16 + tid % 4 * 2;
|
||||
v_smem.ldmatrix_m8n8x4(v_smem_offset_r, kv_frag);
|
||||
// layout
|
||||
@@ -400,7 +396,7 @@ __global__ void append_cache_kv_c8(
|
||||
|
||||
// load v_smem 64 rows, 128 cols
|
||||
for (int fz = 0; fz < 4; fz++) { // 4 rows pre warp once, 16 rows all 4 warps once, need 4 iter
|
||||
for (int fy = 0; fy < 1; fy++) { // 8 * 128b = 128 * uint8 once, need 1 iter
|
||||
for (int fy = 0; fy < 1; fy++) { // 8 * 128b = 128 * uint8 noce, need 1 iter
|
||||
k_smem.load_128b_async<SharedMemFillMode::kNoFill>(
|
||||
k_smem_offset_w, cur_cache_k + k_read_idx, end_idx > 0);
|
||||
k_smem_offset_w =
|
||||
@@ -418,7 +414,7 @@ __global__ void append_cache_kv_c8(
|
||||
// deal k_smem 64 rows, 128 cols
|
||||
for (int fz = 0; fz < 1; fz++) { // 16 rows pre warp once, 64 rows all 4 warps once, need 1 iter
|
||||
uint32_t row_idx = wid * 16 + tid / 4;
|
||||
for (int fy = 0; fy < 4; fy++) { // 2 * 128b = 32 * uint8 once, need 4 iter
|
||||
for (int fy = 0; fy < 4; fy++) { // 2 * 128b = 32 * uint8 noce, need 4 iter
|
||||
uint32_t col_idx = fy * 32 + tid % 4 * 2;
|
||||
k_smem.ldmatrix_m8n8x4(k_smem_offset_r, k_frag);
|
||||
// layout
|
||||
@@ -466,7 +462,7 @@ __global__ void append_cache_kv_c8(
|
||||
tid % 4 * num_elems_per_128b<CacheT>();
|
||||
// load v_smem 128 rows 64 cols
|
||||
for (int fy = 0; fy < 4; fy++) { // 8 rows pre warp once, 32 rows all 4 warps once, need 4 iter
|
||||
for (int fz = 0; fz < 1; fz++) { // 4 * 128b = 64 * uint8 once, need 1 iter
|
||||
for (int fz = 0; fz < 1; fz++) { // 4 * 128b = 64 * uint8 noce, need 1 iter
|
||||
v_smem.load_128b_async<SharedMemFillMode::kNoFill>(
|
||||
v_smem_offset_w, cur_cache_v + v_read_idx, end_idx > 0);
|
||||
v_smem_offset_w =
|
||||
@@ -485,7 +481,7 @@ __global__ void append_cache_kv_c8(
|
||||
// deal v_smem 128 rows 64 cols
|
||||
for (int fy = 0; fy < 2; fy++) { // 16 rows pre warp once, 64 rows all 4 warps once, need 2 iter
|
||||
uint32_t dim_idx = fy * NUM_WARPS * 16 + wid * 16 + tid / 4;
|
||||
for (int fz = 0; fz < 2; fz++) { // 2 * 128b = 32 * uint8 once, need 2 iter
|
||||
for (int fz = 0; fz < 2; fz++) { // 2 * 128b = 32 * uint8 noce, need 2 iter
|
||||
uint32_t kv_idx = fz * 32 + tid % 4 * 2;
|
||||
v_smem.ldmatrix_m8n8x4(v_smem_offset_r, v_frag);
|
||||
// layout
|
||||
@@ -590,9 +586,9 @@ __global__ void append_cache_kv_c4(
|
||||
#pragma unroll
|
||||
for (uint32_t i = wid * 32 + tid; i < HEAD_DIM; i += 128) {
|
||||
cache_k_scale_smem[i] = cache_k_scale_now[i];
|
||||
cache_k_zero_point_smem[i] = cache_k_zp_now[i] + static_cast<T>(136.f);
|
||||
cache_k_zero_point_smem[i] = cache_k_zp_now[i] - static_cast<T>(136.f);
|
||||
cache_v_scale_smem[i] = cache_v_scale_now[i];
|
||||
cache_v_zero_point_smem[i] = cache_v_zp_now[i] + static_cast<T>(136.f);
|
||||
cache_v_zero_point_smem[i] = cache_v_zp_now[i] - static_cast<T>(136.f);
|
||||
}
|
||||
|
||||
smem_t k_smem(smem);
|
||||
@@ -614,7 +610,7 @@ __global__ void append_cache_kv_c4(
|
||||
|
||||
// load k_smem 64 rows 128 cols
|
||||
for (int fz = 0; fz < 2; fz++) { // 4 rows pre warp once, 16 rows all 4 warps once, need 4 iter
|
||||
for (int fy = 0; fy < 1; fy++) { // 4 * 128b = 128 * int4 once, need 1 iter
|
||||
for (int fy = 0; fy < 1; fy++) { // 4 * 128b = 128 * int4 noce, need 1 iter
|
||||
k_smem.load_128b_async<SharedMemFillMode::kNoFill>(
|
||||
k_smem_offset_w, cur_cache_k + k_read_idx, end_idx > 0);
|
||||
k_smem_offset_w =
|
||||
@@ -632,7 +628,7 @@ __global__ void append_cache_kv_c4(
|
||||
// deal k_smem 64 rows 128 cols
|
||||
for (int fz = 0; fz < 1; fz++) { // 16 rows pre warp once, 64 rows all 4 warps once, need 1 iter
|
||||
uint32_t row_idx = wid * 16 + tid / 4;
|
||||
for (int fy = 0; fy < 2; fy++) { // 2 * 128b = 64 * int4 once, need 2 iter
|
||||
for (int fy = 0; fy < 2; fy++) { // 2 * 128b = 64 * int4 noce, need 2 iter
|
||||
uint32_t col_idx = fy * 64 + tid % 4 * 2;
|
||||
k_smem.ldmatrix_m8n8x4(k_smem_offset_r, k_frag);
|
||||
|
||||
@@ -644,25 +640,25 @@ __global__ void append_cache_kv_c4(
|
||||
convert_int4(frag_dq_T + 8, k_frag[2 * i + 1]);
|
||||
|
||||
if (row_idx < end_idx) {
|
||||
k_tile_ptr0[0] = (frag_dq_T[0] - cache_k_zero_point_smem[col_idx]) * cache_k_scale_smem[col_idx];
|
||||
k_tile_ptr0[1] = (frag_dq_T[1] - cache_k_zero_point_smem[col_idx + 1]) * cache_k_scale_smem[col_idx + 1];
|
||||
k_tile_ptr0[8] = (frag_dq_T[2] - cache_k_zero_point_smem[col_idx + 8]) * cache_k_scale_smem[col_idx + 8];
|
||||
k_tile_ptr0[9] = (frag_dq_T[3] - cache_k_zero_point_smem[col_idx + 9]) * cache_k_scale_smem[col_idx + 9];
|
||||
k_tile_ptr0[16] = (frag_dq_T[8] - cache_k_zero_point_smem[col_idx + 16]) * cache_k_scale_smem[col_idx + 16];
|
||||
k_tile_ptr0[17] = (frag_dq_T[9] - cache_k_zero_point_smem[col_idx + 17]) * cache_k_scale_smem[col_idx + 17];
|
||||
k_tile_ptr0[24] = (frag_dq_T[10] - cache_k_zero_point_smem[col_idx + 24]) * cache_k_scale_smem[col_idx + 24];
|
||||
k_tile_ptr0[25] = (frag_dq_T[11] - cache_k_zero_point_smem[col_idx + 25]) * cache_k_scale_smem[col_idx + 25];
|
||||
k_tile_ptr0[0] = frag_dq_T[0] * cache_k_scale_smem[col_idx] + cache_k_zero_point_smem[col_idx];
|
||||
k_tile_ptr0[1] = frag_dq_T[1] * cache_k_scale_smem[col_idx + 1] + cache_k_zero_point_smem[col_idx + 1];
|
||||
k_tile_ptr0[8] = frag_dq_T[2] * cache_k_scale_smem[col_idx + 8] + cache_k_zero_point_smem[col_idx + 8];
|
||||
k_tile_ptr0[9] = frag_dq_T[3] * cache_k_scale_smem[col_idx + 9] + cache_k_zero_point_smem[col_idx + 9];
|
||||
k_tile_ptr0[16] = frag_dq_T[8] * cache_k_scale_smem[col_idx + 16] + cache_k_zero_point_smem[col_idx + 16];
|
||||
k_tile_ptr0[17] = frag_dq_T[9] * cache_k_scale_smem[col_idx + 17] + cache_k_zero_point_smem[col_idx + 17];
|
||||
k_tile_ptr0[24] = frag_dq_T[10] * cache_k_scale_smem[col_idx + 24] + cache_k_zero_point_smem[col_idx + 24];
|
||||
k_tile_ptr0[25] = frag_dq_T[11] * cache_k_scale_smem[col_idx + 25] + cache_k_zero_point_smem[col_idx + 25];
|
||||
}
|
||||
|
||||
if (row_idx + 8 < end_idx) {
|
||||
k_tile_ptr1[0] = (frag_dq_T[4] - cache_k_zero_point_smem[col_idx]) * cache_k_scale_smem[col_idx];
|
||||
k_tile_ptr1[1] = (frag_dq_T[5] - cache_k_zero_point_smem[col_idx + 1]) * cache_k_scale_smem[col_idx + 1];
|
||||
k_tile_ptr1[8] = (frag_dq_T[6] - cache_k_zero_point_smem[col_idx + 8]) * cache_k_scale_smem[col_idx + 8];
|
||||
k_tile_ptr1[9] = (frag_dq_T[7] - cache_k_zero_point_smem[col_idx + 9]) * cache_k_scale_smem[col_idx + 9];
|
||||
k_tile_ptr1[16] = (frag_dq_T[12] - cache_k_zero_point_smem[col_idx + 16]) * cache_k_scale_smem[col_idx + 16];
|
||||
k_tile_ptr1[17] = (frag_dq_T[13] - cache_k_zero_point_smem[col_idx + 17]) * cache_k_scale_smem[col_idx + 17];
|
||||
k_tile_ptr1[24] = (frag_dq_T[14] - cache_k_zero_point_smem[col_idx + 24]) * cache_k_scale_smem[col_idx + 24];
|
||||
k_tile_ptr1[25] = (frag_dq_T[15] - cache_k_zero_point_smem[col_idx + 25]) * cache_k_scale_smem[col_idx + 25];
|
||||
k_tile_ptr1[0] = frag_dq_T[4] * cache_k_scale_smem[col_idx] + cache_k_zero_point_smem[col_idx];
|
||||
k_tile_ptr1[1] = frag_dq_T[5] * cache_k_scale_smem[col_idx + 1] + cache_k_zero_point_smem[col_idx + 1];
|
||||
k_tile_ptr1[8] = frag_dq_T[6] * cache_k_scale_smem[col_idx + 8] + cache_k_zero_point_smem[col_idx + 8];
|
||||
k_tile_ptr1[9] = frag_dq_T[7] * cache_k_scale_smem[col_idx + 9] + cache_k_zero_point_smem[col_idx + 9];
|
||||
k_tile_ptr1[16] = frag_dq_T[12] * cache_k_scale_smem[col_idx + 16] + cache_k_zero_point_smem[col_idx + 16];
|
||||
k_tile_ptr1[17] = frag_dq_T[13] * cache_k_scale_smem[col_idx + 17] + cache_k_zero_point_smem[col_idx + 17];
|
||||
k_tile_ptr1[24] = frag_dq_T[14] * cache_k_scale_smem[col_idx + 24] + cache_k_zero_point_smem[col_idx + 24];
|
||||
k_tile_ptr1[25] = frag_dq_T[15] * cache_k_scale_smem[col_idx + 25] + cache_k_zero_point_smem[col_idx + 25];
|
||||
}
|
||||
col_idx += 32;
|
||||
}
|
||||
@@ -685,7 +681,7 @@ __global__ void append_cache_kv_c4(
|
||||
tid % 2 * num_elems_per_128b<CacheT>();
|
||||
// load v_smem 128 rows 64 rows
|
||||
for (int fy = 0; fy < 2; fy++) { // 16 rows pre warp once, 64 rows all 4 warps once, need 2 iter
|
||||
for (int fz = 0; fz < 1; fz++) { // 2 * 128b = 64 * int4 once, need 1 iter
|
||||
for (int fz = 0; fz < 1; fz++) { // 2 * 128b = 64 * int4 noce, need 1 iter
|
||||
v_smem.load_128b_async<SharedMemFillMode::kNoFill>(
|
||||
v_smem_offset_w, cur_cache_v + v_read_idx, end_idx > 0);
|
||||
v_smem_offset_w =
|
||||
@@ -704,7 +700,7 @@ __global__ void append_cache_kv_c4(
|
||||
// deal v_smem 128 rows 64 cols
|
||||
for (int fy = 0; fy < 2; fy++) { // 16 rows pre warp once, 64 rows all 4 warps once, need 2 iter
|
||||
uint32_t dim_idx = fy * NUM_WARPS * 16 + wid * 16 + tid / 4;
|
||||
for (int fz = 0; fz < 1; fz++) { // 2 * 128b = 64 * int4 once, need 1 iter
|
||||
for (int fz = 0; fz < 1; fz++) { // 2 * 128b = 64 * int4 noce, need 1 iter
|
||||
uint32_t kv_idx = fz * 64 + tid % 4 * 2;
|
||||
v_smem.ldmatrix_m8n8x4(v_smem_offset_r, v_frag);
|
||||
// layout
|
||||
@@ -715,36 +711,36 @@ __global__ void append_cache_kv_c4(
|
||||
convert_int4(frag_dq_T, v_frag[2 * i]);
|
||||
convert_int4(frag_dq_T + 8, v_frag[2 * i + 1]);
|
||||
if (kv_idx < end_idx) {
|
||||
v_tile_ptr0[0] = (frag_dq_T[0] - cache_v_zero_point_smem[dim_idx]) * cache_v_scale_smem[dim_idx];
|
||||
v_tile_ptr1[0] = (frag_dq_T[4] - cache_v_zero_point_smem[dim_idx + 8]) * cache_v_scale_smem[dim_idx + 8];
|
||||
v_tile_ptr0[0] = frag_dq_T[0] * cache_v_scale_smem[dim_idx] + cache_v_zero_point_smem[dim_idx];
|
||||
v_tile_ptr1[0] = frag_dq_T[4] * cache_v_scale_smem[dim_idx + 8] + cache_v_zero_point_smem[dim_idx + 8];
|
||||
}
|
||||
if (kv_idx + 1 < end_idx) {
|
||||
v_tile_ptr0[kv_t_stride] = (frag_dq_T[1] - cache_v_zero_point_smem[dim_idx]) * cache_v_scale_smem[dim_idx];
|
||||
v_tile_ptr1[kv_t_stride] = (frag_dq_T[5] - cache_v_zero_point_smem[dim_idx + 8]) * cache_v_scale_smem[dim_idx + 8];
|
||||
v_tile_ptr0[kv_t_stride] = frag_dq_T[1] * cache_v_scale_smem[dim_idx] + cache_v_zero_point_smem[dim_idx];
|
||||
v_tile_ptr1[kv_t_stride] = frag_dq_T[5] * cache_v_scale_smem[dim_idx + 8] + cache_v_zero_point_smem[dim_idx + 8];
|
||||
}
|
||||
if (kv_idx + 8 < end_idx) {
|
||||
v_tile_ptr0[8 * kv_t_stride] = (frag_dq_T[2] - cache_v_zero_point_smem[dim_idx]) * cache_v_scale_smem[dim_idx];
|
||||
v_tile_ptr1[8 * kv_t_stride] = (frag_dq_T[6] - cache_v_zero_point_smem[dim_idx + 8]) * cache_v_scale_smem[dim_idx + 8];
|
||||
v_tile_ptr0[8 * kv_t_stride] = frag_dq_T[2] * cache_v_scale_smem[dim_idx] + cache_v_zero_point_smem[dim_idx];
|
||||
v_tile_ptr1[8 * kv_t_stride] = frag_dq_T[6] * cache_v_scale_smem[dim_idx + 8] + cache_v_zero_point_smem[dim_idx + 8];
|
||||
}
|
||||
if (kv_idx + 9 < end_idx) {
|
||||
v_tile_ptr0[9 * kv_t_stride] = (frag_dq_T[3] - cache_v_zero_point_smem[dim_idx]) * cache_v_scale_smem[dim_idx];
|
||||
v_tile_ptr1[9 * kv_t_stride] = (frag_dq_T[7] - cache_v_zero_point_smem[dim_idx + 8]) * cache_v_scale_smem[dim_idx + 8];
|
||||
v_tile_ptr0[9 * kv_t_stride] = frag_dq_T[3] * cache_v_scale_smem[dim_idx] + cache_v_zero_point_smem[dim_idx];
|
||||
v_tile_ptr1[9 * kv_t_stride] = frag_dq_T[7] * cache_v_scale_smem[dim_idx + 8] + cache_v_zero_point_smem[dim_idx + 8];
|
||||
}
|
||||
if (kv_idx + 16 < end_idx) {
|
||||
v_tile_ptr0[16 * kv_t_stride] = (frag_dq_T[8] - cache_v_zero_point_smem[dim_idx]) * cache_v_scale_smem[dim_idx];
|
||||
v_tile_ptr1[16 * kv_t_stride] = (frag_dq_T[12] - cache_v_zero_point_smem[dim_idx + 8]) * cache_v_scale_smem[dim_idx + 8];
|
||||
v_tile_ptr0[16 * kv_t_stride] = frag_dq_T[8] * cache_v_scale_smem[dim_idx] + cache_v_zero_point_smem[dim_idx];
|
||||
v_tile_ptr1[16 * kv_t_stride] = frag_dq_T[12] * cache_v_scale_smem[dim_idx + 8] + cache_v_zero_point_smem[dim_idx + 8];
|
||||
}
|
||||
if (kv_idx + 17 < end_idx) {
|
||||
v_tile_ptr0[17 * kv_t_stride] = (frag_dq_T[9] - cache_v_zero_point_smem[dim_idx]) * cache_v_scale_smem[dim_idx];
|
||||
v_tile_ptr1[17 * kv_t_stride] = (frag_dq_T[13] - cache_v_zero_point_smem[dim_idx + 8]) * cache_v_scale_smem[dim_idx + 8];
|
||||
v_tile_ptr0[17 * kv_t_stride] = frag_dq_T[9] * cache_v_scale_smem[dim_idx] + cache_v_zero_point_smem[dim_idx];
|
||||
v_tile_ptr1[17 * kv_t_stride] = frag_dq_T[13] * cache_v_scale_smem[dim_idx + 8] + cache_v_zero_point_smem[dim_idx + 8];
|
||||
}
|
||||
if (kv_idx + 24 < end_idx) {
|
||||
v_tile_ptr0[24 * kv_t_stride] = (frag_dq_T[10] - cache_v_zero_point_smem[dim_idx]) * cache_v_scale_smem[dim_idx];
|
||||
v_tile_ptr1[24 * kv_t_stride] = (frag_dq_T[14] - cache_v_zero_point_smem[dim_idx + 8]) * cache_v_scale_smem[dim_idx + 8];
|
||||
v_tile_ptr0[24 * kv_t_stride] = frag_dq_T[10] * cache_v_scale_smem[dim_idx] + cache_v_zero_point_smem[dim_idx];
|
||||
v_tile_ptr1[24 * kv_t_stride] = frag_dq_T[14] * cache_v_scale_smem[dim_idx + 8] + cache_v_zero_point_smem[dim_idx + 8];
|
||||
}
|
||||
if (kv_idx + 25 < end_idx) {
|
||||
v_tile_ptr0[25 * kv_t_stride] = (frag_dq_T[11] - cache_v_zero_point_smem[dim_idx]) * cache_v_scale_smem[dim_idx];
|
||||
v_tile_ptr1[25 * kv_t_stride] = (frag_dq_T[15] - cache_v_zero_point_smem[dim_idx + 8]) * cache_v_scale_smem[dim_idx + 8];
|
||||
v_tile_ptr0[25 * kv_t_stride] = frag_dq_T[11] * cache_v_scale_smem[dim_idx] + cache_v_zero_point_smem[dim_idx];
|
||||
v_tile_ptr1[25 * kv_t_stride] = frag_dq_T[15] * cache_v_scale_smem[dim_idx + 8] + cache_v_zero_point_smem[dim_idx + 8];
|
||||
}
|
||||
kv_idx += 32;
|
||||
}
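The c4 path above writes the affine dequantization in two different forms on either side of the diff. As a sanity check on the arithmetic only (scale and zero-point values below are made up, not the kernels'), the two forms describe the same mapping once the zero point is folded into real units:

#include <cstdio>

// Form (a): real = (q - zp) * scale
// Form (b): real = q * scale + zp'
// These coincide when zp' == -zp * scale.
int main() {
  const float scale = 0.25f;
  const float zp    = 8.0f;          // zero point in quantized units (form a)
  const float zp_b  = -zp * scale;   // equivalent offset in real units (form b)
  for (int q = 0; q < 16; ++q) {     // int4 code points 0..15
    const float a = (static_cast<float>(q) - zp) * scale;
    const float b = static_cast<float>(q) * scale + zp_b;
    std::printf("q=%2d  form_a=%6.3f  form_b=%6.3f\n", q, a, b);
  }
  return 0;
}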
@@ -894,8 +890,7 @@ std::vector<paddle::Tensor> GQARopeWriteCacheKernel(
|
||||
const paddle::optional<paddle::Tensor>& kv_signal_data,
|
||||
const int kv_token_num,
|
||||
const int max_seq_len,
|
||||
const std::string& cache_quant_type,
|
||||
const bool rope_3d) {
|
||||
const std::string& cache_quant_type) {
|
||||
typedef PDTraits<paddle::DataType::BFLOAT16> traits_;
|
||||
typedef typename traits_::DataType DataType_;
|
||||
typedef typename traits_::data_t data_t;
|
||||
@@ -958,34 +953,9 @@ std::vector<paddle::Tensor> GQARopeWriteCacheKernel(
|
||||
num_heads,
|
||||
kv_num_heads,
|
||||
max_seq_len,
|
||||
rope_3d ? rotary_embs.dims()[3] : rotary_embs.dims()[2],
|
||||
rotary_embs.dims()[2],
|
||||
head_dim,
|
||||
rope_3d,
|
||||
stream);
|
||||
|
||||
if (token_num < kv_token_num) {
|
||||
AppendCacheKV<data_t, 128, 64>(
|
||||
key_cache,
|
||||
value_cache,
|
||||
cache_k_dequant_scales.get(),
|
||||
cache_v_dequant_scales.get(),
|
||||
cache_k_zp.get(),
|
||||
cache_v_zp.get(),
|
||||
seq_lens_this_time,
|
||||
seq_lens_decoder,
|
||||
cu_seqlens_k,
|
||||
block_tables,
|
||||
cache_batch_ids,
|
||||
cache_tile_ids,
|
||||
cache_num_blocks,
|
||||
max_blocks_per_seq,
|
||||
kv_num_heads,
|
||||
cache_quant_type,
|
||||
&k,
|
||||
&v,
|
||||
stream
|
||||
);
|
||||
}
|
||||
// write cache
|
||||
if (cache_quant_type == "none") {
|
||||
CascadeAppendWriteCacheKVQKV<data_t>(
|
||||
@@ -1000,7 +970,7 @@ std::vector<paddle::Tensor> GQARopeWriteCacheKernel(
|
||||
stream,
|
||||
const_cast<paddle::Tensor*>(&key_cache),
|
||||
const_cast<paddle::Tensor*>(&value_cache));
|
||||
} else if (cache_quant_type == "cache_int8" || cache_quant_type == "cache_fp8" || cache_quant_type == "block_wise_fp8") {
|
||||
} else if (cache_quant_type == "cache_int8" || cache_quant_type == "cache_fp8") {
|
||||
CascadeAppendWriteCacheKVC8QKV<data_t, 128, 64>(
|
||||
meta_data,
|
||||
*const_cast<paddle::Tensor*>(&key_cache),
|
||||
@@ -1018,7 +988,7 @@ std::vector<paddle::Tensor> GQARopeWriteCacheKernel(
|
||||
kv_num_blocks_data,
|
||||
max_seq_len,
|
||||
false, // is_scale_channel_wise
|
||||
cache_quant_type,
|
||||
cache_quant_type == "cache_fp8", // is_fp8
|
||||
stream,
|
||||
const_cast<paddle::Tensor*>(&key_cache),
|
||||
const_cast<paddle::Tensor*>(&value_cache));
|
||||
@@ -1068,6 +1038,30 @@ std::vector<paddle::Tensor> GQARopeWriteCacheKernel(
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if (token_num < kv_token_num) {
|
||||
AppendCacheKV<data_t, 128, 64>(
|
||||
key_cache,
|
||||
value_cache,
|
||||
cache_k_dequant_scales.get(),
|
||||
cache_v_dequant_scales.get(),
|
||||
cache_k_zp.get(),
|
||||
cache_v_zp.get(),
|
||||
seq_lens_this_time,
|
||||
seq_lens_decoder,
|
||||
cu_seqlens_k,
|
||||
block_tables,
|
||||
cache_batch_ids,
|
||||
cache_tile_ids,
|
||||
cache_num_blocks,
|
||||
max_blocks_per_seq,
|
||||
kv_num_heads,
|
||||
cache_quant_type,
|
||||
&k,
|
||||
&v,
|
||||
stream
|
||||
);
|
||||
}
|
||||
return {q, k, v, qkv_out};
|
||||
}
|
||||
|
||||
|
@@ -18,168 +18,6 @@
|
||||
#include "mma_tensor_op.cuh"
|
||||
#include "utils.cuh"
|
||||
|
||||
template <typename T, int VecSize = 1, typename InT = T>
|
||||
__global__ void append_speculate_cache_T_rope_qk_norm_kernel(
|
||||
const InT* __restrict__ qkv, // [token_num, num_heads + 2 * gqa_group_size,
|
||||
// head_size]
|
||||
T* __restrict__ key_cache, // [num_blocks, gqa_group_size, block_size,
|
||||
// head_size // 2]
|
||||
T* __restrict__ value_cache, // [num_blocks, gqa_group_size, block_size,
|
||||
// head_size // 2]
|
||||
T* __restrict__ q_out,
|
||||
const int* __restrict__ block_tables, // [bsz, max_blocks_per_seq]
|
||||
const int* __restrict__ batch_id_per_token, // [num_tokens]
|
||||
const int* __restrict__ cu_seqlens_q,
|
||||
const int* __restrict__ seq_lens_decoder, // [bsz]
|
||||
const float* __restrict__ cos_emb,
|
||||
const float* __restrict__ sin_emb,
|
||||
const float*
|
||||
qkv_out_scales, // [(num_heads + 2 * gqa_group_size) * head_size]
|
||||
const T* qkv_biases, // [num_head + 2 * gqa_group_size, dim_head]
|
||||
const int max_seq_len,
|
||||
const int max_blocks_per_seq,
|
||||
const int num_heads,
|
||||
const int output_inner_dim,
|
||||
const int head_size,
|
||||
const int block_size,
|
||||
const int elem_cnt,
|
||||
const int gqa_group_size,
|
||||
const float* q_norm_weight,
|
||||
const float* k_norm_weight,
|
||||
const float rms_norm_eps,
|
||||
const bool rope_3d) {
|
||||
using LoadT = AlignedVector<T, VecSize>;
|
||||
using LoadFloat = AlignedVector<float, VecSize>;
|
||||
using LoadInT = AlignedVector<InT, VecSize>;
|
||||
constexpr int HalfVecSize = VecSize / 2;
|
||||
using LoadEmbT = AlignedVector<float, HalfVecSize>;
|
||||
LoadInT src_vec;
|
||||
LoadFloat scale_vec;
|
||||
LoadT bias_vec;
|
||||
LoadEmbT cos_emb_vec;
|
||||
LoadEmbT sin_emb_vec;
|
||||
LoadFloat tmp_vec;
|
||||
LoadFloat q_norm_vec;
|
||||
LoadFloat k_norm_vec;
|
||||
|
||||
int64_t global_warp_idx = blockDim.y * blockIdx.x + threadIdx.y;
|
||||
int64_t all_warp_num = gridDim.x * blockDim.y;
|
||||
int64_t all_head_dim = elem_cnt / head_size;
|
||||
|
||||
const int64_t hidden_size = (num_heads + 2 * gqa_group_size) * head_size;
|
||||
const int half_head_size = head_size / 2;
|
||||
for (int global_hi = global_warp_idx; global_hi < all_head_dim; global_hi += all_warp_num) {
|
||||
int64_t linear_index = global_hi * head_size + threadIdx.x * VecSize;
|
||||
const int token_id = linear_index / hidden_size;
|
||||
const int ori_bi = batch_id_per_token[token_id];
|
||||
if (seq_lens_decoder[ori_bi] == 0) continue;
|
||||
const int bias = linear_index % hidden_size;
|
||||
const int hi = bias / head_size; // q + k + v
|
||||
const int h_bias = bias % head_size;
|
||||
const int start_token_idx = cu_seqlens_q[ori_bi];
|
||||
const int write_seq_id =
|
||||
seq_lens_decoder[ori_bi] + token_id - start_token_idx;
|
||||
if (write_seq_id == 0) continue;
|
||||
|
||||
const int* block_table_now = block_tables + ori_bi * max_blocks_per_seq;
|
||||
const int block_idx = block_table_now[write_seq_id / block_size];
|
||||
if (block_idx < 0) {
|
||||
printf(
|
||||
"Fatal Error!!!, block idx %d when write_seq_id is %d\n some key var "
|
||||
"%d %d %d %d\n",
|
||||
block_idx,
|
||||
write_seq_id,
|
||||
ori_bi,
|
||||
seq_lens_decoder[ori_bi],
|
||||
token_id,
|
||||
cu_seqlens_q[ori_bi]);
|
||||
}
|
||||
const int block_offset = write_seq_id % block_size;
|
||||
|
||||
const int write_q_idx =
|
||||
token_id * output_inner_dim * head_size + hi * head_size + h_bias;
|
||||
|
||||
const int bias_idx = hi * head_size + h_bias;
|
||||
Load<InT, VecSize>(&qkv[linear_index], &src_vec);
|
||||
if (qkv_biases) {
|
||||
Load<T, VecSize>(&qkv_biases[bias_idx], &bias_vec);
|
||||
}
|
||||
if (qkv_out_scales) {
|
||||
Load<float, VecSize>(&qkv_out_scales[bias_idx], &scale_vec);
|
||||
}
|
||||
if (hi < num_heads + gqa_group_size) {
|
||||
// q k rope
|
||||
const int64_t emb_idx = write_seq_id * half_head_size + h_bias / 2;
|
||||
uint32_t new_emb_idx = rope_3d ? emb_idx + ori_bi * max_seq_len * head_size : emb_idx;
|
||||
Load<float, HalfVecSize>(&cos_emb[new_emb_idx], &cos_emb_vec);
|
||||
Load<float, HalfVecSize>(&sin_emb[new_emb_idx], &sin_emb_vec);
|
||||
}
|
||||
float thread_m2 = 0.0f;
|
||||
float warp_m2 = 0.0f;
|
||||
#pragma unroll
|
||||
for (int i = 0; i < HalfVecSize; i++) {
|
||||
// add_bias + rope
|
||||
float input_left = static_cast<float>(src_vec[2 * i]);
|
||||
float input_right = static_cast<float>(src_vec[2 * i + 1]);
|
||||
if (qkv_out_scales) {
|
||||
input_left *= scale_vec[2 * i];
|
||||
input_right *= scale_vec[2 * i + 1];
|
||||
}
|
||||
if (qkv_biases) {
|
||||
input_left = input_left + static_cast<float>(bias_vec[2 * i]);
|
||||
input_right = input_right + static_cast<float>(bias_vec[2 * i + 1]);
|
||||
}
|
||||
if (hi < num_heads + gqa_group_size) {
|
||||
const float cos_tmp = cos_emb_vec[i];
|
||||
const float sin_tmp = sin_emb_vec[i];
|
||||
float tmp1 = input_left * cos_tmp - input_right * sin_tmp;
|
||||
float tmp2 = input_right * cos_tmp + input_left * sin_tmp;
|
||||
thread_m2 += tmp1 * tmp1 + tmp2 * tmp2;
|
||||
tmp_vec[2 * i] = tmp1;
|
||||
tmp_vec[2 * i + 1] = tmp2;
|
||||
} else {
|
||||
bias_vec[2 * i] = static_cast<T>(input_left);
|
||||
bias_vec[2 * i + 1] = static_cast<T>(input_right);
|
||||
}
|
||||
}
|
||||
if (hi < (num_heads + gqa_group_size)) {
|
||||
WelfordWarpAllReduce<float, 32>(thread_m2, &warp_m2);
|
||||
float row_variance =
|
||||
max(warp_m2 / head_size, 0.0f);
|
||||
float row_inv_var = Rsqrt(row_variance + rms_norm_eps);
|
||||
if (hi < num_heads) {
|
||||
Load<float, VecSize>(&q_norm_weight[threadIdx.x * VecSize], &q_norm_vec);
|
||||
#pragma unroll
|
||||
for (int i = 0; i < VecSize; i++) {
|
||||
bias_vec[i] = static_cast<T>(tmp_vec[i] * row_inv_var * q_norm_vec[i]);
|
||||
}
|
||||
} else {
|
||||
Load<float, VecSize>(&k_norm_weight[threadIdx.x * VecSize], &k_norm_vec);
|
||||
#pragma unroll
|
||||
for (int i = 0; i < VecSize; i++) {
|
||||
bias_vec[i] = static_cast<T>(tmp_vec[i] * row_inv_var * k_norm_vec[i]);
|
||||
}
|
||||
}
|
||||
}
|
||||
if (hi < num_heads) {
|
||||
// write q
|
||||
Store<T, VecSize>(bias_vec, &q_out[write_q_idx]);
|
||||
} else {
|
||||
// write k/v
|
||||
const int kv_head_idx = (hi - num_heads) % gqa_group_size;
|
||||
const int tgt_idx = (block_idx * gqa_group_size * block_size * head_size +
|
||||
kv_head_idx * block_size * head_size +
|
||||
block_offset * head_size + h_bias);
|
||||
// write
|
||||
if (hi < num_heads + gqa_group_size) {
|
||||
Store<T, VecSize>(bias_vec, &key_cache[tgt_idx]);
|
||||
} else {
|
||||
Store<T, VecSize>(bias_vec, &value_cache[tgt_idx]);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
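The q/k RMS-norm step in the kernel above is easier to read in scalar form. A minimal host-side sketch follows, assuming a small illustrative head_dim and unit norm weights; the kernel itself reads the real weights from q_norm_weight/k_norm_weight and accumulates the sum of squares with a warp-level Welford all-reduce rather than a plain loop:

#include <algorithm>
#include <cmath>
#include <cstdio>
#include <vector>

int main() {
  const int head_dim = 8;
  const float eps = 1e-6f;
  std::vector<float> x = {0.4f, -1.2f, 0.9f, 0.1f, -0.3f, 2.0f, -0.7f, 0.5f};
  std::vector<float> w(head_dim, 1.0f);  // stand-in for q_norm_weight / k_norm_weight

  float m2 = 0.f;
  for (float v : x) m2 += v * v;                       // thread_m2 + warp reduce
  const float variance = std::max(m2 / head_dim, 0.0f);
  const float inv_rms = 1.0f / std::sqrt(variance + eps);

  for (int i = 0; i < head_dim; ++i)
    std::printf("%f ", x[i] * inv_rms * w[i]);         // value written to q_out / key_cache
  std::printf("\n");
  return 0;
}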
template <int VecSize = 4, int HeadDim = 128>
|
||||
__global__ void append_clear_cache_int8_block(
|
||||
uint8_t* __restrict__ key_cache, // [num_blocks, gqa_group_size,
|
||||
@@ -355,8 +193,7 @@ __global__ void append_speculate_cache_rope_kernel(
|
||||
const int head_size,
|
||||
const int block_size,
|
||||
const int elem_cnt,
|
||||
const int gqa_group_size,
|
||||
const bool rope_3d) {
|
||||
const int gqa_group_size) {
|
||||
using LoadT = AlignedVector<T, VecSize>;
|
||||
using LoadFloat = AlignedVector<float, VecSize>;
|
||||
using LoadInT = AlignedVector<InT, VecSize>;
|
||||
@@ -416,9 +253,8 @@ __global__ void append_speculate_cache_rope_kernel(
|
||||
if (hi < num_heads + gqa_group_size) {
|
||||
// q k rope
|
||||
const int64_t emb_idx = write_seq_id * half_head_size + h_bias / 2;
|
||||
int64_t new_emb_idx = rope_3d ? emb_idx + ori_bi * max_seq_len * head_size : emb_idx;
|
||||
Load<float, HalfVecSize>(&cos_emb[new_emb_idx], &cos_emb_vec);
|
||||
Load<float, HalfVecSize>(&sin_emb[new_emb_idx], &sin_emb_vec);
|
||||
Load<float, HalfVecSize>(&cos_emb[emb_idx], &cos_emb_vec);
|
||||
Load<float, HalfVecSize>(&sin_emb[emb_idx], &sin_emb_vec);
|
||||
}
|
||||
#pragma unroll
|
||||
for (int i = 0; i < HalfVecSize; i++) {
|
||||
@@ -490,8 +326,7 @@ __global__ void append_speculate_cache_neox_rope_kernel(
|
||||
const int head_size,
|
||||
const int block_size,
|
||||
const int elem_cnt,
|
||||
const int gqa_group_size,
|
||||
const bool rope_3d) {
|
||||
const int gqa_group_size) {
|
||||
using LoadT = AlignedVector<T, VecSize>;
|
||||
using LoadFloat = AlignedVector<float, VecSize>;
|
||||
using LoadInT = AlignedVector<InT, VecSize>;
|
||||
@@ -555,9 +390,8 @@ __global__ void append_speculate_cache_neox_rope_kernel(
|
||||
if (hi < num_heads + gqa_group_size) {
|
||||
// q k rope
|
||||
const int64_t emb_idx = write_seq_id * head_size + h_bias;
|
||||
int64_t new_emb_idx = rope_3d ? emb_idx + ori_bi * max_seq_len * head_size * 2: emb_idx;
|
||||
Load<float, VecSize>(&cos_emb[new_emb_idx], &cos_emb_vec);
|
||||
Load<float, VecSize>(&sin_emb[new_emb_idx], &sin_emb_vec);
|
||||
Load<float, VecSize>(&cos_emb[emb_idx], &cos_emb_vec);
|
||||
Load<float, VecSize>(&sin_emb[emb_idx], &sin_emb_vec);
|
||||
}
|
||||
#pragma unroll
|
||||
for (int i = 0; i < VecSize; i++) {
|
||||
@@ -642,8 +476,7 @@ __global__ void append_speculate_cache_int8_rope_kernel(
|
||||
const int block_size,
|
||||
const float max_bound,
|
||||
const float min_bound,
|
||||
const int gqa_group_size,
|
||||
const bool rope_3d) {
|
||||
const int gqa_group_size) {
|
||||
static_assert(HeadDim == 128, "just support HeadDim be 128 now!");
|
||||
static_assert(VecSize == 4, "just support VecSize be 4 now, 32 * 4!");
|
||||
constexpr int NUM_WARPS = 4;
|
||||
@@ -689,9 +522,8 @@ __global__ void append_speculate_cache_int8_rope_kernel(
|
||||
|
||||
// q rope
|
||||
const uint32_t emb_idx = write_seq_id * half_head_size + head_bias / 2;
|
||||
uint32_t new_emb_idx = rope_3d ? emb_idx + bid * max_seq_len * HeadDim : emb_idx;
|
||||
Load<float, HalfVecSize>(&cos_emb[new_emb_idx], &cos_emb_vec);
|
||||
Load<float, HalfVecSize>(&sin_emb[new_emb_idx], &sin_emb_vec);
|
||||
Load<float, HalfVecSize>(&cos_emb[emb_idx], &cos_emb_vec);
|
||||
Load<float, HalfVecSize>(&sin_emb[emb_idx], &sin_emb_vec);
|
||||
if (qkv_out_scales) {
|
||||
Load<float, VecSize>(&qkv_out_scales[bias_idx], &out_scale_vec);
|
||||
}
|
||||
@@ -751,11 +583,10 @@ __global__ void append_speculate_cache_int8_rope_kernel(
|
||||
T scale;
|
||||
if (head_idx < num_heads + gqa_group_size) {
|
||||
const uint32_t emb_idx = write_seq_id * half_head_size + head_bias / 2;
|
||||
uint32_t new_emb_idx = rope_3d ? emb_idx + bid * max_seq_len * HeadDim : emb_idx;
|
||||
Load<float, 1>(&cos_emb[new_emb_idx], &cos_emb_vec1);
|
||||
Load<float, 1>(&cos_emb[new_emb_idx + 4], &cos_emb_vec2);
|
||||
Load<float, 1>(&sin_emb[new_emb_idx], &sin_emb_vec1);
|
||||
Load<float, 1>(&sin_emb[new_emb_idx + 4], &sin_emb_vec2);
|
||||
Load<float, 1>(&cos_emb[emb_idx], &cos_emb_vec1);
|
||||
Load<float, 1>(&cos_emb[emb_idx + 4], &cos_emb_vec2);
|
||||
Load<float, 1>(&sin_emb[emb_idx], &sin_emb_vec1);
|
||||
Load<float, 1>(&sin_emb[emb_idx + 4], &sin_emb_vec2);
|
||||
scale = __ldg(&cache_k_scales[kv_head_idx]);
|
||||
} else {
|
||||
scale = __ldg(&cache_v_scales[kv_head_idx]);
|
||||
@@ -877,8 +708,7 @@ __global__ void append_speculate_cache_int8_neox_rope_kernel(
|
||||
const int block_size,
|
||||
const float max_bound,
|
||||
const float min_bound,
|
||||
const int gqa_group_size,
|
||||
const bool rope_3d) {
|
||||
const int gqa_group_size) {
|
||||
static_assert(HeadDim == 128, "just support HeadDim be 128 now!");
|
||||
static_assert(VecSize == 4, "just support VecSize be 4 now, 32 * 4!");
|
||||
constexpr int NUM_WARPS = 4;
|
||||
@@ -927,9 +757,8 @@ __global__ void append_speculate_cache_int8_neox_rope_kernel(
|
||||
|
||||
// q rope
|
||||
const uint32_t emb_idx = write_seq_id * HeadDim + head_bias;
|
||||
uint32_t new_emb_idx = rope_3d ? emb_idx + bid * max_seq_len * HeadDim * 2 : emb_idx;
|
||||
Load<float, VecSize>(&cos_emb[new_emb_idx], &cos_emb_vec);
|
||||
Load<float, VecSize>(&sin_emb[new_emb_idx], &sin_emb_vec);
|
||||
Load<float, VecSize>(&cos_emb[emb_idx], &cos_emb_vec);
|
||||
Load<float, VecSize>(&sin_emb[emb_idx], &sin_emb_vec);
|
||||
if (qkv_out_scales) {
|
||||
Load<float, VecSize>(&qkv_out_scales[bias_idx_left],
|
||||
&left_out_scale_vec);
|
||||
@@ -1024,11 +853,10 @@ __global__ void append_speculate_cache_int8_neox_rope_kernel(
|
||||
|
||||
T scale;
|
||||
const uint32_t emb_idx = write_seq_id * HeadDim + head_bias;
|
||||
uint32_t new_emb_idx = rope_3d ? emb_idx + bid * max_seq_len * HeadDim * 2 : emb_idx;
|
||||
Load<float, HALF_K_VEC_SIZE>(&cos_emb[new_emb_idx], &cos_emb_vec1);
|
||||
Load<float, HALF_K_VEC_SIZE>(&cos_emb[new_emb_idx + 8], &cos_emb_vec2);
|
||||
Load<float, HALF_K_VEC_SIZE>(&sin_emb[new_emb_idx], &sin_emb_vec1);
|
||||
Load<float, HALF_K_VEC_SIZE>(&sin_emb[new_emb_idx + 8], &sin_emb_vec2);
|
||||
Load<float, HALF_K_VEC_SIZE>(&cos_emb[emb_idx], &cos_emb_vec1);
|
||||
Load<float, HALF_K_VEC_SIZE>(&cos_emb[emb_idx + 8], &cos_emb_vec2);
|
||||
Load<float, HALF_K_VEC_SIZE>(&sin_emb[emb_idx], &sin_emb_vec1);
|
||||
Load<float, HALF_K_VEC_SIZE>(&sin_emb[emb_idx + 8], &sin_emb_vec2);
|
||||
scale = __ldg(&cache_k_scales[kv_head_idx]);
|
||||
#pragma unroll
|
||||
for (int i = 0; i < HALF_K_VEC_SIZE; i++) {
|
||||
@@ -1260,8 +1088,7 @@ __global__ void append_speculate_cache_int4_rope_kernel(
|
||||
const int block_size,
|
||||
const float max_bound,
|
||||
const float min_bound,
|
||||
const int gqa_group_size,
|
||||
const bool rope_3d) {
|
||||
const int gqa_group_size) {
|
||||
static_assert(HeadDim == 128, "just support HeadDim be 128 now!");
|
||||
static_assert(VecSize == 4, "just support VecSize be 4 now, 32 * 4!");
|
||||
constexpr int NUM_WARPS = 4;
|
||||
@@ -1318,9 +1145,8 @@ __global__ void append_speculate_cache_int4_rope_kernel(
|
||||
// Load<float, VecSize>(&qkv_out_scales[bias_idx], &out_scale_vec);
|
||||
// q rope
|
||||
const uint32_t emb_idx = write_seq_id * half_head_size + head_bias / 2;
|
||||
uint32_t new_emb_idx = rope_3d ? emb_idx + bid * max_seq_len * HeadDim : emb_idx;
|
||||
Load<float, HalfVecSize>(&cos_emb[new_emb_idx], &cos_emb_vec);
|
||||
Load<float, HalfVecSize>(&sin_emb[new_emb_idx], &sin_emb_vec);
|
||||
Load<float, HalfVecSize>(&cos_emb[emb_idx], &cos_emb_vec);
|
||||
Load<float, HalfVecSize>(&sin_emb[emb_idx], &sin_emb_vec);
|
||||
#pragma unroll
|
||||
for (int i = 0; i < HalfVecSize; i++) {
|
||||
// dequant + add_bias + rope
|
||||
@@ -1409,11 +1235,10 @@ __global__ void append_speculate_cache_int4_rope_kernel(
|
||||
// &out_scale_vec2);
|
||||
if (head_idx < num_heads + gqa_group_size) {
|
||||
const uint32_t emb_idx = write_seq_id * half_head_size + head_bias / 2;
|
||||
uint32_t new_emb_idx = rope_3d ? emb_idx + bid * max_seq_len * HeadDim : emb_idx;
|
||||
Load<float, 1>(&cos_emb[new_emb_idx], &cos_emb_vec1);
|
||||
Load<float, 1>(&cos_emb[new_emb_idx + 4], &cos_emb_vec2);
|
||||
Load<float, 1>(&sin_emb[new_emb_idx], &sin_emb_vec1);
|
||||
Load<float, 1>(&sin_emb[new_emb_idx + 4], &sin_emb_vec2);
|
||||
Load<float, 1>(&cos_emb[emb_idx], &cos_emb_vec1);
|
||||
Load<float, 1>(&cos_emb[emb_idx + 4], &cos_emb_vec2);
|
||||
Load<float, 1>(&sin_emb[emb_idx], &sin_emb_vec1);
|
||||
Load<float, 1>(&sin_emb[emb_idx + 4], &sin_emb_vec2);
|
||||
Load<T, HALF_K_VEC_SIZE>(&cache_k_scales[cache_idx], &scale_vec1);
|
||||
Load<T, HALF_K_VEC_SIZE>(&cache_k_scales[cache_idx + 8], &scale_vec2);
|
||||
Load<T, HALF_K_VEC_SIZE>(&cache_k_zero_points[cache_idx], &zp_vec1);
|
||||
@@ -1606,8 +1431,7 @@ __global__ void append_speculate_cache_int4_neox_rope_kernel(
|
||||
const int block_size,
|
||||
const float max_bound,
|
||||
const float min_bound,
|
||||
const int gqa_group_size,
|
||||
const bool rope_3d) {
|
||||
const int gqa_group_size) {
|
||||
static_assert(HeadDim == 128, "just support HeadDim be 128 now!");
|
||||
static_assert(VecSize == 4, "just support VecSize be 4 now, 32 * 4!");
|
||||
constexpr int NUM_WARPS = 4;
|
||||
@@ -1757,11 +1581,10 @@ __global__ void append_speculate_cache_int4_neox_rope_kernel(
|
||||
&right_out_scale_vec2);
|
||||
|
||||
const uint32_t emb_idx = write_seq_id * HeadDim + head_bias;
|
||||
uint32_t new_emb_idx = rope_3d ? emb_idx + bid * max_seq_len * HeadDim : emb_idx;
|
||||
Load<float, HALF_K_VEC_SIZE>(&cos_emb[new_emb_idx], &cos_emb_vec1);
|
||||
Load<float, HALF_K_VEC_SIZE>(&cos_emb[new_emb_idx + 8], &cos_emb_vec2);
|
||||
Load<float, HALF_K_VEC_SIZE>(&sin_emb[new_emb_idx], &sin_emb_vec1);
|
||||
Load<float, HALF_K_VEC_SIZE>(&sin_emb[new_emb_idx + 8], &sin_emb_vec2);
|
||||
Load<float, HALF_K_VEC_SIZE>(&cos_emb[emb_idx], &cos_emb_vec1);
|
||||
Load<float, HALF_K_VEC_SIZE>(&cos_emb[emb_idx + 8], &cos_emb_vec2);
|
||||
Load<float, HALF_K_VEC_SIZE>(&sin_emb[emb_idx], &sin_emb_vec1);
|
||||
Load<float, HALF_K_VEC_SIZE>(&sin_emb[emb_idx + 8], &sin_emb_vec2);
|
||||
Load<T, HALF_K_VEC_SIZE>(&cache_k_scales[left_cache_idx],
|
||||
&left_scale_vec1);
|
||||
Load<T, HALF_K_VEC_SIZE>(&cache_k_scales[left_cache_idx + 8],
|
||||
|
@@ -15,78 +15,6 @@
|
||||
#include "speculate_write_cache_with_rope_kernel.h"
|
||||
#include "utils.cuh"
|
||||
|
||||
template <typename T, typename QKV_TYPE>
|
||||
void append_speculate_cache_rope_qk_norm(const QKV_TYPE* qkv,
|
||||
T* key_cache,
|
||||
T* value_cache,
|
||||
T* qkv_out,
|
||||
const int* block_tables,
|
||||
const int* batch_id_per_token,
|
||||
const int* cu_seqlens_q,
|
||||
const int* seq_lens,
|
||||
const int* seq_lens_encoder,
|
||||
const float* cos_emb,
|
||||
const float* sin_emb,
|
||||
const float* qkv_out_scales,
|
||||
const T* qkv_biases,
|
||||
const int max_seq_len,
|
||||
const int max_blocks_per_seq,
|
||||
const int num_heads,
|
||||
const int kv_num_heads,
|
||||
const int dim_head,
|
||||
const int block_size,
|
||||
const int bsz,
|
||||
const int token_num,
|
||||
const cudaStream_t& stream,
|
||||
const bool use_neox_style,
|
||||
const float* q_norm_weight,
|
||||
const float* k_norm_weight,
|
||||
const float rms_norm_eps,
|
||||
const bool rope_3d) {
|
||||
int output_inner_dim = num_heads + 2 * kv_num_heads;
|
||||
const uint32_t elem_nums =
|
||||
use_neox_style ? token_num * (num_heads + 2 * kv_num_heads) * dim_head / 2
|
||||
: token_num * (num_heads + 2 * kv_num_heads) * dim_head;
|
||||
constexpr int HEAD_DIM = 128;
|
||||
|
||||
constexpr int PackSize = HEAD_DIM / kWarpSize;
|
||||
const int pack_num = elem_nums / PackSize;
|
||||
const int blocksize = 128;
|
||||
int grid_size = 1;
|
||||
GetNumBlocks<128>(pack_num, &grid_size);
|
||||
if (use_neox_style) {
|
||||
PD_THROW(
|
||||
"append_speculate_cache_rope_qk_norm not support neox rope yet");
|
||||
} else {
|
||||
dim3 block_dim(kWarpSize, blocksize / kWarpSize, 1);
|
||||
append_speculate_cache_T_rope_qk_norm_kernel<T, PackSize>
|
||||
<<<grid_size, block_dim, 0, stream>>>(qkv,
|
||||
key_cache,
|
||||
value_cache,
|
||||
qkv_out,
|
||||
block_tables,
|
||||
batch_id_per_token,
|
||||
cu_seqlens_q,
|
||||
seq_lens,
|
||||
cos_emb,
|
||||
sin_emb,
|
||||
qkv_out_scales,
|
||||
qkv_biases,
|
||||
max_seq_len,
|
||||
max_blocks_per_seq,
|
||||
num_heads,
|
||||
output_inner_dim,
|
||||
dim_head,
|
||||
block_size,
|
||||
elem_nums,
|
||||
kv_num_heads,
|
||||
q_norm_weight,
|
||||
k_norm_weight,
|
||||
rms_norm_eps,
|
||||
rope_3d);
|
||||
}
|
||||
}
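The wrapper above derives its launch size from the packed element count. A rough standalone sketch of that arithmetic with made-up sizes follows; the simple div_up here is only a stand-in for GetNumBlocks<128>, whose real policy may differ:

#include <cstdint>
#include <cstdio>

static int div_up(int64_t a, int64_t b) { return static_cast<int>((a + b - 1) / b); }

int main() {
  const int token_num = 256, num_heads = 32, kv_num_heads = 4, dim_head = 128;
  const int kWarpSize = 32, HEAD_DIM = 128, blocksize = 128;

  const int64_t elem_nums =
      static_cast<int64_t>(token_num) * (num_heads + 2 * kv_num_heads) * dim_head;
  const int PackSize = HEAD_DIM / kWarpSize;   // elements handled per thread
  const int64_t pack_num = elem_nums / PackSize;

  // Stand-in grid computation (assumption, not the real GetNumBlocks<128>).
  const int grid_size = div_up(pack_num, blocksize);

  std::printf("elem_nums=%lld pack_num=%lld grid_size=%d\n",
              static_cast<long long>(elem_nums),
              static_cast<long long>(pack_num), grid_size);
  return 0;
}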
// rope + write
|
||||
template <typename T, typename QKV_TYPE>
|
||||
void append_speculate_cache_rope(const QKV_TYPE* qkv,
|
||||
@@ -111,8 +39,7 @@ void append_speculate_cache_rope(const QKV_TYPE* qkv,
|
||||
const int bsz,
|
||||
const int token_num,
|
||||
const cudaStream_t& stream,
|
||||
const bool use_neox_style,
|
||||
const bool rope_3d) {
|
||||
const bool use_neox_style) {
|
||||
int output_inner_dim = num_heads + 2 * kv_num_heads;
|
||||
|
||||
const uint32_t elem_nums =
|
||||
@@ -146,8 +73,7 @@ void append_speculate_cache_rope(const QKV_TYPE* qkv,
|
||||
dim_head,
|
||||
block_size,
|
||||
elem_nums,
|
||||
kv_num_heads,
|
||||
rope_3d);
|
||||
kv_num_heads);
|
||||
} else {
|
||||
append_speculate_cache_rope_kernel<T, PackSize>
|
||||
<<<grid_size, threads_per_block, 0, stream>>>(
|
||||
@@ -170,8 +96,7 @@ void append_speculate_cache_rope(const QKV_TYPE* qkv,
|
||||
dim_head,
|
||||
block_size,
|
||||
elem_nums,
|
||||
kv_num_heads,
|
||||
rope_3d);
|
||||
kv_num_heads);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -200,8 +125,7 @@ void append_speculate_cache_int8_rope(const QKV_TYPE* qkv,
|
||||
const int bsz,
|
||||
const int token_num,
|
||||
const cudaStream_t& stream,
|
||||
const bool use_neox_style,
|
||||
const bool rope_3d) {
|
||||
const bool use_neox_style) {
|
||||
constexpr int num_warps = 4;
|
||||
const int all_warps =
|
||||
((num_heads + 2 * kv_num_heads) + num_warps - 1) / num_warps * num_warps;
|
||||
@@ -243,8 +167,7 @@ void append_speculate_cache_int8_rope(const QKV_TYPE* qkv,
|
||||
block_size,
|
||||
127.0f,
|
||||
-127.0f,
|
||||
kv_num_heads,
|
||||
rope_3d);
|
||||
kv_num_heads);
|
||||
} else {
|
||||
append_speculate_cache_int8_rope_kernel<T, 4, 0, 128, QKV_TYPE, IsFP8>
|
||||
<<<grids, num_warps * 32, 0, stream>>>(qkv,
|
||||
@@ -268,8 +191,7 @@ void append_speculate_cache_int8_rope(const QKV_TYPE* qkv,
|
||||
block_size,
|
||||
127.0f,
|
||||
-127.0f,
|
||||
kv_num_heads,
|
||||
rope_3d);
|
||||
kv_num_heads);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -300,8 +222,7 @@ void append_speculate_cache_int4_rope(const QKV_TYPE* qkv,
|
||||
const int bsz,
|
||||
const int token_num,
|
||||
const cudaStream_t& stream,
|
||||
const bool use_neox_style,
|
||||
const bool rope_3d) {
|
||||
const bool use_neox_style) {
|
||||
constexpr int num_warps = 4;
|
||||
const int all_warps =
|
||||
((num_heads + 2 * kv_num_heads) + num_warps - 1) / num_warps * num_warps;
|
||||
@@ -345,8 +266,7 @@ void append_speculate_cache_int4_rope(const QKV_TYPE* qkv,
|
||||
block_size,
|
||||
7.0f,
|
||||
-8.0f,
|
||||
kv_num_heads,
|
||||
rope_3d);
|
||||
kv_num_heads);
|
||||
} else {
|
||||
append_speculate_cache_int4_rope_kernel<T, 4>
|
||||
<<<grids, num_warps * 32, 0, stream>>>(qkv,
|
||||
@@ -372,8 +292,7 @@ void append_speculate_cache_int4_rope(const QKV_TYPE* qkv,
|
||||
block_size,
|
||||
7.0f,
|
||||
-8.0f,
|
||||
kv_num_heads,
|
||||
rope_3d);
|
||||
kv_num_heads);
|
||||
}
|
||||
}
|
||||
template <typename T, typename QKV_TYPE>
|
||||
@@ -394,15 +313,11 @@ void SpeculateWriteCacheWithRoPEKernel(
|
||||
const paddle::optional<paddle::Tensor>& cache_v_zp,
|
||||
const std::string& cache_quant_type_str,
|
||||
const bool use_neox_rotary_style,
|
||||
const bool rope_3d,
|
||||
const int max_seq_len,
|
||||
cudaStream_t& stream,
|
||||
paddle::Tensor* qkv_out,
|
||||
paddle::Tensor* key_cache_out,
|
||||
paddle::Tensor* value_cache_out,
|
||||
const paddle::optional<paddle::Tensor>& q_norm_weight,
|
||||
const paddle::optional<paddle::Tensor>& k_norm_weight,
|
||||
const float rms_norm_eps) {
|
||||
paddle::Tensor* value_cache_out) {
|
||||
typedef cascade_attn_type_traits<T> traits_;
|
||||
typedef cascade_attn_type_traits<QKV_TYPE> qkt_nv_type_;
|
||||
typedef typename traits_::type DataType_;
|
||||
@@ -427,185 +342,142 @@ void SpeculateWriteCacheWithRoPEKernel(
|
||||
? rotary_embs.get().data<float>() + max_seq_len * dim_head
|
||||
: rotary_embs.get().data<float>() + max_seq_len * dim_head / 2;
|
||||
}
|
||||
if (q_norm_weight && k_norm_weight) {
|
||||
if (cache_quant_type_str == "none") {
|
||||
append_speculate_cache_rope_qk_norm(
|
||||
reinterpret_cast<const QKV_TYPE*>(qkv_ptr),
|
||||
reinterpret_cast<DataType_*>(key_cache_out->data<T>()),
|
||||
reinterpret_cast<DataType_*>(value_cache_out->data<T>()),
|
||||
reinterpret_cast<DataType_*>(qkv_out->data<T>()),
|
||||
block_tables.data<int>(),
|
||||
batch_id_per_token.data<int>(),
|
||||
cu_seqlens_q.data<int>(),
|
||||
seq_lens.data<int>(),
|
||||
seq_lens_encoder.data<int>(),
|
||||
cos_emb,
|
||||
sin_emb,
|
||||
qkv_out_scales ? qkv_out_scales.get().data<float>() : nullptr,
|
||||
qkv_biases ? reinterpret_cast<DataType_*>(
|
||||
const_cast<T*>(qkv_biases.get().data<T>()))
|
||||
: nullptr,
|
||||
max_seq_len,
|
||||
max_blocks_per_seq,
|
||||
num_heads,
|
||||
kv_num_heads,
|
||||
dim_head,
|
||||
block_size,
|
||||
bsz,
|
||||
token_nums,
|
||||
stream,
|
||||
use_neox_rotary_style,
|
||||
reinterpret_cast<const float*>(q_norm_weight.get().data<float>()),
|
||||
reinterpret_cast<const float*>(k_norm_weight.get().data<float>()),
|
||||
rms_norm_eps,
|
||||
rope_3d);
|
||||
} else {
|
||||
PD_THROW(
|
||||
"append_decode_cache_rope_qk_norm not support cachekv quant yet");
|
||||
}
|
||||
|
||||
if (cache_quant_type_str == "none") {
|
||||
append_speculate_cache_rope(
|
||||
reinterpret_cast<const QKV_TYPE*>(qkv_ptr),
|
||||
reinterpret_cast<DataType_*>(key_cache_out->data<T>()),
|
||||
reinterpret_cast<DataType_*>(value_cache_out->data<T>()),
|
||||
reinterpret_cast<DataType_*>(qkv_out->data<T>()),
|
||||
block_tables.data<int>(),
|
||||
batch_id_per_token.data<int>(),
|
||||
cu_seqlens_q.data<int>(),
|
||||
seq_lens.data<int>(),
|
||||
seq_lens_encoder.data<int>(),
|
||||
cos_emb,
|
||||
sin_emb,
|
||||
qkv_out_scales ? qkv_out_scales.get().data<float>() : nullptr,
|
||||
qkv_biases ? reinterpret_cast<DataType_*>(
|
||||
const_cast<T*>(qkv_biases.get().data<T>()))
|
||||
: nullptr,
|
||||
max_seq_len,
|
||||
max_blocks_per_seq,
|
||||
num_heads,
|
||||
kv_num_heads,
|
||||
dim_head,
|
||||
block_size,
|
||||
bsz,
|
||||
token_nums,
|
||||
stream,
|
||||
use_neox_rotary_style);
|
||||
} else if (cache_quant_type_str == "cache_int8") {
|
||||
append_speculate_cache_int8_rope(
|
||||
reinterpret_cast<const QKV_TYPE*>(qkv_ptr),
|
||||
key_cache_out->data<uint8_t>(),
|
||||
value_cache_out->data<uint8_t>(),
|
||||
reinterpret_cast<DataType_*>(qkv_out->data<T>()),
|
||||
block_tables.data<int>(),
|
||||
batch_id_per_token.data<int>(),
|
||||
cu_seqlens_q.data<int>(),
|
||||
seq_lens.data<int>(),
|
||||
seq_lens_encoder.data<int>(),
|
||||
cos_emb,
|
||||
sin_emb,
|
||||
qkv_out_scales ? qkv_out_scales.get().data<float>() : nullptr,
|
||||
qkv_biases ? reinterpret_cast<DataType_*>(
|
||||
const_cast<T*>(qkv_biases.get().data<T>()))
|
||||
: nullptr,
|
||||
cache_k_scale ? reinterpret_cast<DataType_*>(
|
||||
const_cast<T*>(cache_k_scale.get().data<T>()))
|
||||
: nullptr,
|
||||
cache_v_scale ? reinterpret_cast<DataType_*>(
|
||||
const_cast<T*>(cache_v_scale.get().data<T>()))
|
||||
: nullptr,
|
||||
max_seq_len,
|
||||
max_blocks_per_seq,
|
||||
num_heads,
|
||||
kv_num_heads,
|
||||
dim_head,
|
||||
block_size,
|
||||
bsz,
|
||||
token_nums,
|
||||
stream,
|
||||
use_neox_rotary_style);
|
||||
} else if (cache_quant_type_str == "cache_fp8") {
|
||||
append_speculate_cache_int8_rope<DataType_, QKV_TYPE, true>(
|
||||
reinterpret_cast<const QKV_TYPE*>(qkv_ptr),
|
||||
key_cache_out->data<uint8_t>(),
|
||||
value_cache_out->data<uint8_t>(),
|
||||
reinterpret_cast<DataType_*>(qkv_out->data<T>()),
|
||||
block_tables.data<int>(),
|
||||
batch_id_per_token.data<int>(),
|
||||
cu_seqlens_q.data<int>(),
|
||||
seq_lens.data<int>(),
|
||||
seq_lens_encoder.data<int>(),
|
||||
cos_emb,
|
||||
sin_emb,
|
||||
qkv_out_scales ? qkv_out_scales.get().data<float>() : nullptr,
|
||||
qkv_biases ? reinterpret_cast<DataType_*>(
|
||||
const_cast<T*>(qkv_biases.get().data<T>()))
|
||||
: nullptr,
|
||||
cache_k_scale ? reinterpret_cast<DataType_*>(
|
||||
const_cast<T*>(cache_k_scale.get().data<T>()))
|
||||
: nullptr,
|
||||
cache_v_scale ? reinterpret_cast<DataType_*>(
|
||||
const_cast<T*>(cache_v_scale.get().data<T>()))
|
||||
: nullptr,
|
||||
max_seq_len,
|
||||
max_blocks_per_seq,
|
||||
num_heads,
|
||||
kv_num_heads,
|
||||
dim_head,
|
||||
block_size,
|
||||
bsz,
|
||||
token_nums,
|
||||
stream,
|
||||
use_neox_rotary_style);
|
||||
} else if (cache_quant_type_str == "cache_int4_zp") {
|
||||
append_speculate_cache_int4_rope(
|
||||
reinterpret_cast<const QKV_TYPE*>(qkv_ptr),
|
||||
key_cache_out->data<uint8_t>(),
|
||||
value_cache_out->data<uint8_t>(),
|
||||
reinterpret_cast<DataType_*>(const_cast<T*>(qkv_out->data<T>())),
|
||||
block_tables.data<int>(),
|
||||
batch_id_per_token.data<int>(),
|
||||
cu_seqlens_q.data<int>(),
|
||||
seq_lens.data<int>(),
|
||||
seq_lens_encoder.data<int>(),
|
||||
cos_emb,
|
||||
sin_emb,
|
||||
qkv_out_scales ? qkv_out_scales.get().data<float>() : nullptr,
|
||||
qkv_biases ? reinterpret_cast<DataType_*>(
|
||||
const_cast<T*>(qkv_biases.get().data<T>()))
|
||||
: nullptr,
|
||||
cache_k_scale ? reinterpret_cast<DataType_*>(
|
||||
const_cast<T*>(cache_k_scale.get().data<T>()))
|
||||
: nullptr,
|
||||
cache_v_scale ? reinterpret_cast<DataType_*>(
|
||||
const_cast<T*>(cache_v_scale.get().data<T>()))
|
||||
: nullptr,
|
||||
cache_k_zp ? reinterpret_cast<DataType_*>(
|
||||
const_cast<T*>(cache_k_zp.get().data<T>()))
|
||||
: nullptr,
|
||||
cache_v_zp ? reinterpret_cast<DataType_*>(
|
||||
const_cast<T*>(cache_v_zp.get().data<T>()))
|
||||
: nullptr,
|
||||
max_seq_len,
|
||||
max_blocks_per_seq,
|
||||
num_heads,
|
||||
kv_num_heads,
|
||||
dim_head,
|
||||
block_size,
|
||||
bsz,
|
||||
token_nums,
|
||||
stream,
|
||||
use_neox_rotary_style);
|
||||
} else {
|
||||
if (cache_quant_type_str == "none") {
|
||||
append_speculate_cache_rope(
|
||||
reinterpret_cast<const QKV_TYPE*>(qkv_ptr),
|
||||
reinterpret_cast<DataType_*>(key_cache_out->data<T>()),
|
||||
reinterpret_cast<DataType_*>(value_cache_out->data<T>()),
|
||||
reinterpret_cast<DataType_*>(qkv_out->data<T>()),
|
||||
block_tables.data<int>(),
|
||||
batch_id_per_token.data<int>(),
|
||||
cu_seqlens_q.data<int>(),
|
||||
seq_lens.data<int>(),
|
||||
seq_lens_encoder.data<int>(),
|
||||
cos_emb,
|
||||
sin_emb,
|
||||
qkv_out_scales ? qkv_out_scales.get().data<float>() : nullptr,
|
||||
qkv_biases ? reinterpret_cast<DataType_*>(
|
||||
const_cast<T*>(qkv_biases.get().data<T>()))
|
||||
: nullptr,
|
||||
max_seq_len,
|
||||
max_blocks_per_seq,
|
||||
num_heads,
|
||||
kv_num_heads,
|
||||
dim_head,
|
||||
block_size,
|
||||
bsz,
|
||||
token_nums,
|
||||
stream,
|
||||
use_neox_rotary_style,
|
||||
rope_3d);
|
||||
} else if (cache_quant_type_str == "cache_int8") {
|
||||
append_speculate_cache_int8_rope(
|
||||
reinterpret_cast<const QKV_TYPE*>(qkv_ptr),
|
||||
key_cache_out->data<uint8_t>(),
|
||||
value_cache_out->data<uint8_t>(),
|
||||
reinterpret_cast<DataType_*>(qkv_out->data<T>()),
|
||||
block_tables.data<int>(),
|
||||
batch_id_per_token.data<int>(),
|
||||
cu_seqlens_q.data<int>(),
|
||||
seq_lens.data<int>(),
|
||||
seq_lens_encoder.data<int>(),
|
||||
cos_emb,
|
||||
sin_emb,
|
||||
qkv_out_scales ? qkv_out_scales.get().data<float>() : nullptr,
|
||||
qkv_biases ? reinterpret_cast<DataType_*>(
|
||||
const_cast<T*>(qkv_biases.get().data<T>()))
|
||||
: nullptr,
|
||||
cache_k_scale ? reinterpret_cast<DataType_*>(
|
||||
const_cast<T*>(cache_k_scale.get().data<T>()))
|
||||
: nullptr,
|
||||
cache_v_scale ? reinterpret_cast<DataType_*>(
|
||||
const_cast<T*>(cache_v_scale.get().data<T>()))
|
||||
: nullptr,
|
||||
max_seq_len,
|
||||
max_blocks_per_seq,
|
||||
num_heads,
|
||||
kv_num_heads,
|
||||
dim_head,
|
||||
block_size,
|
||||
bsz,
|
||||
token_nums,
|
||||
stream,
|
||||
use_neox_rotary_style,
|
||||
rope_3d);
|
||||
} else if (cache_quant_type_str == "cache_fp8") {
|
||||
append_speculate_cache_int8_rope<DataType_, QKV_TYPE, true>(
|
||||
reinterpret_cast<const QKV_TYPE*>(qkv_ptr),
|
||||
key_cache_out->data<uint8_t>(),
|
||||
value_cache_out->data<uint8_t>(),
|
||||
reinterpret_cast<DataType_*>(qkv_out->data<T>()),
|
||||
block_tables.data<int>(),
|
||||
batch_id_per_token.data<int>(),
|
||||
cu_seqlens_q.data<int>(),
|
||||
seq_lens.data<int>(),
|
||||
seq_lens_encoder.data<int>(),
|
||||
cos_emb,
|
||||
sin_emb,
|
||||
qkv_out_scales ? qkv_out_scales.get().data<float>() : nullptr,
|
||||
qkv_biases ? reinterpret_cast<DataType_*>(
|
||||
const_cast<T*>(qkv_biases.get().data<T>()))
|
||||
: nullptr,
|
||||
cache_k_scale ? reinterpret_cast<DataType_*>(
|
||||
const_cast<T*>(cache_k_scale.get().data<T>()))
|
||||
: nullptr,
|
||||
cache_v_scale ? reinterpret_cast<DataType_*>(
|
||||
const_cast<T*>(cache_v_scale.get().data<T>()))
|
||||
: nullptr,
|
||||
max_seq_len,
|
||||
max_blocks_per_seq,
|
||||
num_heads,
|
||||
kv_num_heads,
|
||||
dim_head,
|
||||
block_size,
|
||||
bsz,
|
||||
token_nums,
|
||||
stream,
|
||||
use_neox_rotary_style,
|
||||
rope_3d);
|
||||
} else if (cache_quant_type_str == "cache_int4_zp") {
|
||||
append_speculate_cache_int4_rope(
|
||||
reinterpret_cast<const QKV_TYPE*>(qkv_ptr),
|
||||
key_cache_out->data<uint8_t>(),
|
||||
value_cache_out->data<uint8_t>(),
|
||||
reinterpret_cast<DataType_*>(const_cast<T*>(qkv_out->data<T>())),
|
||||
block_tables.data<int>(),
|
||||
batch_id_per_token.data<int>(),
|
||||
cu_seqlens_q.data<int>(),
|
||||
seq_lens.data<int>(),
|
||||
seq_lens_encoder.data<int>(),
|
||||
cos_emb,
|
||||
sin_emb,
|
||||
qkv_out_scales ? qkv_out_scales.get().data<float>() : nullptr,
|
||||
qkv_biases ? reinterpret_cast<DataType_*>(
|
||||
const_cast<T*>(qkv_biases.get().data<T>()))
|
||||
: nullptr,
|
||||
cache_k_scale ? reinterpret_cast<DataType_*>(
|
||||
const_cast<T*>(cache_k_scale.get().data<T>()))
|
||||
: nullptr,
|
||||
cache_v_scale ? reinterpret_cast<DataType_*>(
|
||||
const_cast<T*>(cache_v_scale.get().data<T>()))
|
||||
: nullptr,
|
||||
cache_k_zp ? reinterpret_cast<DataType_*>(
|
||||
const_cast<T*>(cache_k_zp.get().data<T>()))
|
||||
: nullptr,
|
||||
cache_v_zp ? reinterpret_cast<DataType_*>(
|
||||
const_cast<T*>(cache_v_zp.get().data<T>()))
|
||||
: nullptr,
|
||||
max_seq_len,
|
||||
max_blocks_per_seq,
|
||||
num_heads,
|
||||
kv_num_heads,
|
||||
dim_head,
|
||||
block_size,
|
||||
bsz,
|
||||
token_nums,
|
||||
stream,
|
||||
use_neox_rotary_style,
|
||||
rope_3d);
|
||||
      } else {
        PD_THROW(
            "cache_quant_type_str should be one of [none, cache_int8, "
            "cache_fp8, cache_int4_zp]");
      }
      PD_THROW(
          "cache_quant_type_str should be one of [none, cache_int8, "
          "cache_fp8, cache_int4_zp]");
    }
  }

@@ -628,15 +500,11 @@ template void SpeculateWriteCacheWithRoPEKernel<paddle::bfloat16, int>(
    const paddle::optional<paddle::Tensor>& cache_v_zp,
    const std::string& cache_quant_type_str,
    const bool use_neox_rotary_style,
    const bool rope_3d,
    const int max_seq_len,
    cudaStream_t& stream,
    paddle::Tensor* qkv_out,
    paddle::Tensor* key_cache_out,
    paddle::Tensor* value_cache_out,
    const paddle::optional<paddle::Tensor>& q_norm_weight,
    const paddle::optional<paddle::Tensor>& k_norm_weight,
    const float rms_norm_eps);
    paddle::Tensor* value_cache_out);

template void
|
||||
SpeculateWriteCacheWithRoPEKernel<paddle::bfloat16, paddle::bfloat16>(
|
||||
@@ -658,15 +526,11 @@ SpeculateWriteCacheWithRoPEKernel<paddle::bfloat16, paddle::bfloat16>(
|
||||
const paddle::optional<paddle::Tensor>& cache_v_zp,
|
||||
const std::string& cache_quant_type_str,
|
||||
const bool use_neox_rotary_style,
|
||||
const bool rope_3d,
|
||||
const int max_seq_len,
|
||||
cudaStream_t& stream,
|
||||
paddle::Tensor* qkv_out,
|
||||
paddle::Tensor* key_cache_out,
|
||||
paddle::Tensor* value_cache_out,
|
||||
const paddle::optional<paddle::Tensor>& q_norm_weight,
|
||||
const paddle::optional<paddle::Tensor>& k_norm_weight,
|
||||
const float rms_norm_eps);
|
||||
paddle::Tensor* value_cache_out);
|
||||
|
||||
template void SpeculateWriteCacheWithRoPEKernel<paddle::float16, int>(
|
||||
const AppendAttnMetaData& meta_data,
|
||||
@@ -687,15 +551,11 @@ template void SpeculateWriteCacheWithRoPEKernel<paddle::float16, int>(
|
||||
const paddle::optional<paddle::Tensor>& cache_v_zp,
|
||||
const std::string& cache_quant_type_str,
|
||||
const bool use_neox_rotary_style,
|
||||
const bool rope_3d,
|
||||
const int max_seq_len,
|
||||
cudaStream_t& stream,
|
||||
paddle::Tensor* qkv_out,
|
||||
paddle::Tensor* key_cache_out,
|
||||
paddle::Tensor* value_cache_out,
|
||||
const paddle::optional<paddle::Tensor>& q_norm_weight,
|
||||
const paddle::optional<paddle::Tensor>& k_norm_weight,
|
||||
const float rms_norm_eps);
|
||||
paddle::Tensor* value_cache_out);
|
||||
|
||||
|
||||
template void
|
||||
@@ -718,12 +578,8 @@ SpeculateWriteCacheWithRoPEKernel<paddle::float16, paddle::float16>(
|
||||
const paddle::optional<paddle::Tensor>& cache_v_zp,
|
||||
const std::string& cache_quant_type_str,
|
||||
const bool use_neox_rotary_style,
|
||||
const bool rope_3d,
|
||||
const int max_seq_len,
|
||||
cudaStream_t& stream,
|
||||
paddle::Tensor* qkv_out,
|
||||
paddle::Tensor* key_cache_out,
|
||||
paddle::Tensor* value_cache_out,
|
||||
const paddle::optional<paddle::Tensor>& q_norm_weight,
|
||||
const paddle::optional<paddle::Tensor>& k_norm_weight,
|
||||
const float rms_norm_eps);
|
||||
paddle::Tensor* value_cache_out);
|
||||
|
@@ -35,12 +35,8 @@ void SpeculateWriteCacheWithRoPEKernel(
|
||||
const paddle::optional<paddle::Tensor>& cache_v_zp,
|
||||
const std::string& cache_quant_type_str,
|
||||
const bool use_neox_rotary_style,
|
||||
const bool rope_3d,
|
||||
const int max_seq_len,
|
||||
cudaStream_t& stream,
|
||||
paddle::Tensor* qkv_out,
|
||||
paddle::Tensor* key_cache_out,
|
||||
paddle::Tensor* value_cache_out,
|
||||
const paddle::optional<paddle::Tensor>& q_norm_weight,
|
||||
const paddle::optional<paddle::Tensor>& k_norm_weight,
|
||||
const float rms_norm_eps);
|
||||
paddle::Tensor* value_cache_out);
|
||||
|
@@ -56,7 +56,6 @@ CascadeAppendAttentionC8Kernel<paddle::bfloat16, paddle::bfloat16, false>(
|
||||
const bool causal,
|
||||
const bool is_decoder,
|
||||
const bool enable_prefill,
|
||||
const std::string& cache_quant_type_str,
|
||||
cudaStream_t& stream,
|
||||
paddle::Tensor* out);
|
||||
|
||||
@@ -104,6 +103,5 @@ CascadeAppendAttentionC8Kernel<paddle::bfloat16, paddle::bfloat16, true>(
|
||||
const bool causal,
|
||||
const bool is_decoder,
|
||||
const bool enable_prefill,
|
||||
const std::string& cache_quant_type_str,
|
||||
cudaStream_t& stream,
|
||||
paddle::Tensor* out);
|
||||
|
@@ -54,7 +54,6 @@ template void CascadeAppendAttentionC8Kernel<paddle::bfloat16, paddle::float8_e4
|
||||
const bool causal,
|
||||
const bool is_decoder,
|
||||
const bool enable_prefill,
|
||||
const std::string& cache_quant_type_str,
|
||||
cudaStream_t& stream,
|
||||
paddle::Tensor* out);
|
||||
|
||||
@@ -99,6 +98,5 @@ template void CascadeAppendAttentionC8Kernel<paddle::bfloat16, paddle::float8_e4
|
||||
const bool causal,
|
||||
const bool is_decoder,
|
||||
const bool enable_prefill,
|
||||
const std::string& cache_quant_type_str,
|
||||
cudaStream_t& stream,
|
||||
paddle::Tensor* out);
|
||||
|
@@ -54,7 +54,6 @@ template void CascadeAppendAttentionC8Kernel<paddle::bfloat16, int8_t, false>(
|
||||
const bool causal,
|
||||
const bool is_decoder,
|
||||
const bool enable_prefill,
|
||||
const std::string& cache_quant_type_str,
|
||||
cudaStream_t& stream,
|
||||
paddle::Tensor* out);
|
||||
|
||||
@@ -101,6 +100,5 @@ template void CascadeAppendAttentionC8Kernel<paddle::bfloat16, int8_t, true>(
|
||||
const bool causal,
|
||||
const bool is_decoder,
|
||||
const bool enable_prefill,
|
||||
const std::string& cache_quant_type_str,
|
||||
cudaStream_t& stream,
|
||||
paddle::Tensor* out);
|
||||
|
@@ -54,7 +54,6 @@ template void CascadeAppendAttentionC8Kernel<paddle::float16, paddle::float16, f
|
||||
const bool causal,
|
||||
const bool is_decoder,
|
||||
const bool enable_prefill,
|
||||
const std::string& cache_quant_type_str,
|
||||
cudaStream_t& stream,
|
||||
paddle::Tensor* out);
|
||||
|
||||
@@ -101,6 +100,5 @@ template void CascadeAppendAttentionC8Kernel<paddle::float16, paddle::float16, t
|
||||
const bool causal,
|
||||
const bool is_decoder,
|
||||
const bool enable_prefill,
|
||||
const std::string& cache_quant_type_str,
|
||||
cudaStream_t& stream,
|
||||
paddle::Tensor* out);
|
||||
|
@@ -54,7 +54,6 @@ template void CascadeAppendAttentionC8Kernel<paddle::float16, paddle::float8_e4m
|
||||
const bool causal,
|
||||
const bool is_decoder,
|
||||
const bool enable_prefill,
|
||||
const std::string& cache_quant_type_str,
|
||||
cudaStream_t& stream,
|
||||
paddle::Tensor* out);
|
||||
|
||||
@@ -100,6 +99,5 @@ template void CascadeAppendAttentionC8Kernel<paddle::float16, paddle::float8_e4m
|
||||
const bool causal,
|
||||
const bool is_decoder,
|
||||
const bool enable_prefill,
|
||||
const std::string& cache_quant_type_str,
|
||||
cudaStream_t& stream,
|
||||
paddle::Tensor* out);
|
||||
|
@@ -54,7 +54,6 @@ template void CascadeAppendAttentionC8Kernel<paddle::float16, int8_t, false>(
|
||||
const bool causal,
|
||||
const bool is_decoder,
|
||||
const bool enable_prefill,
|
||||
const std::string& cache_quant_type_str,
|
||||
cudaStream_t& stream,
|
||||
paddle::Tensor* out);
|
||||
|
||||
@@ -100,6 +99,5 @@ template void CascadeAppendAttentionC8Kernel<paddle::float16, int8_t, true>(
|
||||
const bool causal,
|
||||
const bool is_decoder,
|
||||
const bool enable_prefill,
|
||||
const std::string& cache_quant_type_str,
|
||||
cudaStream_t& stream,
|
||||
paddle::Tensor* out);
|
||||
|
@@ -43,7 +43,4 @@ EncoderWriteCacheWithRopeKernel<paddle::bfloat16, paddle::bfloat16>(
|
||||
cudaStream_t& stream,
|
||||
paddle::Tensor* qkv_out,
|
||||
paddle::Tensor* key_cache_out,
|
||||
paddle::Tensor* value_cache_out,
|
||||
const paddle::optional<paddle::Tensor>& q_norm_weight,
|
||||
const paddle::optional<paddle::Tensor>& k_norm_weight,
|
||||
const float rms_norm_eps);
|
||||
paddle::Tensor* value_cache_out);
|
||||
|
@@ -42,7 +42,4 @@ template void EncoderWriteCacheWithRopeKernel<paddle::bfloat16, int>(
|
||||
cudaStream_t& stream,
|
||||
paddle::Tensor* qkv_out,
|
||||
paddle::Tensor* key_cache_out,
|
||||
paddle::Tensor* value_cache_out,
|
||||
const paddle::optional<paddle::Tensor>& q_norm_weight,
|
||||
const paddle::optional<paddle::Tensor>& k_norm_weight,
|
||||
const float rms_norm_eps);
|
||||
paddle::Tensor* value_cache_out);
|
||||
|
@@ -42,7 +42,4 @@ template void EncoderWriteCacheWithRopeKernel<paddle::float16, paddle::float16>(
|
||||
cudaStream_t& stream,
|
||||
paddle::Tensor* qkv_out,
|
||||
paddle::Tensor* key_cache_out,
|
||||
paddle::Tensor* value_cache_out,
|
||||
const paddle::optional<paddle::Tensor>& q_norm_weight,
|
||||
const paddle::optional<paddle::Tensor>& k_norm_weight,
|
||||
const float rms_norm_eps);
|
||||
paddle::Tensor* value_cache_out);
|
||||
|
@@ -42,7 +42,4 @@ template void EncoderWriteCacheWithRopeKernel<paddle::float16, int>(
|
||||
cudaStream_t& stream,
|
||||
paddle::Tensor* qkv_out,
|
||||
paddle::Tensor* key_cache_out,
|
||||
paddle::Tensor* value_cache_out,
|
||||
const paddle::optional<paddle::Tensor>& q_norm_weight,
|
||||
const paddle::optional<paddle::Tensor>& k_norm_weight,
|
||||
const float rms_norm_eps);
|
||||
paddle::Tensor* value_cache_out);
|
||||
|
@@ -27,7 +27,6 @@ struct AppendAttnMetaData {
  int head_dims;
  int head_dims_v;
  int max_blocks_per_seq;
  const int *mask_offset = nullptr;
};

__forceinline__ __host__ __device__ int div_up(int a, int b) {
@@ -431,9 +430,6 @@ __forceinline__ __host__ __device__ void vec_cast<nv_bfloat16, float>(
  } else if (group_size == 12) {              \
    constexpr size_t GROUP_SIZE = 12;         \
    __VA_ARGS__                               \
  } else if (group_size == 14) {              \
    constexpr size_t GROUP_SIZE = 14;         \
    __VA_ARGS__                               \
  } else if (group_size == 16) {              \
    constexpr size_t GROUP_SIZE = 16;         \
    __VA_ARGS__                               \
@@ -441,15 +437,6 @@ __forceinline__ __host__ __device__ void vec_cast<nv_bfloat16, float>(
    PD_THROW("not support the group_size", group_size); \
  }

#define DISPATCH_DyCfp8(is_dynamic_cfp8, IsDynamicC8, ...) \
  if (is_dynamic_cfp8) {                                   \
    constexpr bool IsDynamicC8 = true;                     \
    __VA_ARGS__                                            \
  } else {                                                 \
    constexpr bool IsDynamicC8 = false;                    \
    __VA_ARGS__                                            \
  }
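// Illustrative usage only (not part of this change): the dispatch macros
// above turn a runtime flag into a compile-time constant so both
// specializations of a templated launcher get instantiated. `LaunchKernel`
// and `DispatchExample` are hypothetical names.
template <bool IsDynamicC8>
void LaunchKernel() { /* launch the kernel specialization chosen above */ }

inline void DispatchExample(bool is_dynamic_cfp8) {
  DISPATCH_DyCfp8(is_dynamic_cfp8, IsDynamicC8,
                  { LaunchKernel<IsDynamicC8>(); })
}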

#define DISPATCH_MLA_GROUP_SIZE(group_size, GROUP_SIZE, ...) \
  if (group_size == 8) {                                      \
    constexpr size_t GROUP_SIZE = 8;                          \
@@ -487,9 +474,6 @@ __forceinline__ __host__ __device__ void vec_cast<nv_bfloat16, float>(
  if (causal) {                                               \
    constexpr bool CAUSAL = true;                             \
    __VA_ARGS__                                               \
  } else {                                                    \
    constexpr bool CAUSAL = false;                            \
    __VA_ARGS__                                               \
  }

#define DISPATCH_ENABLE_PREFILL(enable_prefill, ENABLE_PREFILL, ...) \
@@ -575,37 +559,3 @@ template <typename T, bool IsFP8>inline __device__ static void convert_c8(T * re
    convert_int8(result, source);
  }
}

constexpr int kWarpSize = 32;

template<typename T>
inline __device__ void WelfordCombine1(T b_m2, T* m2) {
  *m2 += b_m2;
}

template<typename T, int thread_group_width = kWarpSize>
__inline__ __device__ void WelfordWarpReduce(T thread_m2, T* m2) {
  *m2 = thread_m2;
  for (int mask = thread_group_width / 2; mask > 0; mask >>= 1) {
    T b_m2 = __shfl_xor_sync(0xffffffff, *m2, mask);
    WelfordCombine1(b_m2, m2);
  }
}

template<typename T, int thread_group_width = kWarpSize>
__inline__ __device__ void WelfordWarpAllReduce(T thread_m2, T* m2) {
  WelfordWarpReduce<T, thread_group_width>(thread_m2, m2);
}

template <typename T>
__inline__ __device__ T Rsqrt(T x);

template <>
__inline__ __device__ float Rsqrt<float>(float x) {
  return rsqrt(x);
}

template <>
__inline__ __device__ double Rsqrt<double>(double x) {
  return rsqrt(x);
}
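// Illustrative sketch only (not from this diff): how warp-reduction helpers
// like WelfordWarpAllReduce and Rsqrt above typically combine into an
// RMSNorm-style scale, one row value per warp lane. `RmsNormScale` and `dim`
// are hypothetical; `rms_norm_eps` mirrors the kernel parameter that appears
// elsewhere in this change.
__device__ float RmsNormScale(float val, int dim, float rms_norm_eps) {
  float m2 = val * val;                          // per-thread sum of squares
  WelfordWarpAllReduce<float>(m2, &m2);          // warp-wide sum via __shfl_xor_sync
  return Rsqrt<float>(m2 / dim + rms_norm_eps);  // 1 / sqrt(mean + eps)
}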
@@ -63,7 +63,7 @@ std::vector<paddle::Tensor> AppendAttention(
|
||||
const paddle::Tensor &kv_num_blocks,
|
||||
const paddle::Tensor &decoder_batch_ids,
|
||||
const paddle::Tensor &decoder_tile_ids_per_batch,
|
||||
const paddle::Tensor &decoder_num_blocks_cpu,
|
||||
const paddle::Tensor &decoder_num_blocks,
|
||||
const paddle::Tensor &set_max_lengths, const paddle::Tensor &max_len_kv,
|
||||
const paddle::optional<paddle::Tensor> &rotary_embs,
|
||||
const paddle::optional<paddle::Tensor> &attn_mask,
|
||||
@@ -77,54 +77,7 @@ std::vector<paddle::Tensor> AppendAttention(
|
||||
const paddle::optional<paddle::Tensor> &cache_v_zp,
|
||||
const paddle::optional<paddle::Tensor> &out_linear_shifts,
|
||||
const paddle::optional<paddle::Tensor> &out_linear_smooths,
|
||||
const paddle::optional<paddle::Tensor> &mask_offset,
|
||||
const paddle::optional<paddle::Tensor> &kv_signal_data,
|
||||
const paddle::optional<paddle::Tensor>& q_norm_weight,
|
||||
const paddle::optional<paddle::Tensor>& k_norm_weight,
|
||||
const float rms_norm_eps,
|
||||
const std::string &compute_dtype, const std::string &cache_quant_type_str,
|
||||
const bool use_neox_rotary_style, const bool rope_3d,
|
||||
const int max_input_length, const float quant_max_bound,
|
||||
const float quant_min_bound, const float out_linear_in_scale,
|
||||
const int encoder_block_shape_q, const int decoder_block_shape_q,
|
||||
const int max_partition_size, const int encoder_max_partition_size,
|
||||
const int speculate_max_draft_token_num, const bool causal,
|
||||
const bool speculate_decoder);
|
||||
|
||||
void AppendAttentionWithOutput(
|
||||
const paddle::Tensor &qkv, const paddle::Tensor &key_cache,
|
||||
const paddle::Tensor &value_cache, const paddle::Tensor &seq_lens_encoder,
|
||||
const paddle::Tensor &seq_lens_decoder,
|
||||
const paddle::Tensor &seq_lens_this_time,
|
||||
const paddle::Tensor &batch_id_per_token, const paddle::Tensor &cu_seqlens_q,
|
||||
const paddle::Tensor &block_tables, const paddle::Tensor &encoder_batch_ids,
|
||||
const paddle::Tensor &encoder_tile_ids_per_batch,
|
||||
const paddle::Tensor &encoder_num_blocks,
|
||||
const paddle::Tensor &kv_batch_ids,
|
||||
const paddle::Tensor &kv_tile_ids_per_batch,
|
||||
const paddle::Tensor &kv_num_blocks,
|
||||
const paddle::Tensor &decoder_batch_ids,
|
||||
const paddle::Tensor &decoder_tile_ids_per_batch,
|
||||
const paddle::Tensor &decoder_num_blocks_cpu,
|
||||
const paddle::Tensor &set_max_lengths, const paddle::Tensor &max_len_kv,
|
||||
paddle::Tensor &fmha_out,
|
||||
const paddle::optional<paddle::Tensor> &rotary_embs,
|
||||
const paddle::optional<paddle::Tensor> &attn_mask,
|
||||
const paddle::optional<paddle::Tensor> &qkv_bias,
|
||||
const paddle::optional<paddle::Tensor> &qkv_out_scales,
|
||||
const paddle::optional<paddle::Tensor> &cache_k_quant_scales,
|
||||
const paddle::optional<paddle::Tensor> &cache_v_quant_scales,
|
||||
const paddle::optional<paddle::Tensor> &cache_k_dequant_scales,
|
||||
const paddle::optional<paddle::Tensor> &cache_v_dequant_scales,
|
||||
const paddle::optional<paddle::Tensor> &cache_k_zp,
|
||||
const paddle::optional<paddle::Tensor> &cache_v_zp,
|
||||
const paddle::optional<paddle::Tensor> &out_linear_shifts,
|
||||
const paddle::optional<paddle::Tensor> &out_linear_smooths,
|
||||
const paddle::optional<paddle::Tensor> &mask_offset,
|
||||
const paddle::optional<paddle::Tensor> &kv_signal_data,
|
||||
const paddle::optional<paddle::Tensor>& q_norm_weight,
|
||||
const paddle::optional<paddle::Tensor>& k_norm_weight,
|
||||
const float rms_norm_eps,
|
||||
const std::string &compute_dtype, const std::string &cache_quant_type_str,
|
||||
const bool use_neox_rotary_style, const bool rope_3d,
|
||||
const int max_input_length, const float quant_max_bound,
|
||||
@@ -154,8 +107,7 @@ std::vector<paddle::Tensor> GQARopeWriteCacheKernel(
|
||||
const paddle::optional<paddle::Tensor> &cache_v_zp,
|
||||
const paddle::optional<paddle::Tensor> &kv_signal_data,
|
||||
const int kv_token_num, const int max_seq_len,
|
||||
const std::string &cache_quant_type,
|
||||
const bool rope_3d);
|
||||
const std::string &cache_quant_type);
|
||||
|
||||
std::vector<paddle::Tensor>
|
||||
PreCacheLenConcat(const paddle::Tensor &seq_lens_decoder,
|
||||
@@ -172,29 +124,11 @@ paddle::Tensor FusedExpertMoeFunc(
|
||||
const std::string &quant_method, const int moe_topk,
|
||||
const bool norm_topk_prob, const bool group_moe);
|
||||
|
||||
std::vector<paddle::Tensor> MacheteMMKernel(
|
||||
paddle::Tensor const& A, paddle::Tensor const& B,
|
||||
paddle::optional<paddle::Tensor> const& maybe_group_scales,
|
||||
paddle::optional<paddle::Tensor> const& maybe_group_zeros,
|
||||
paddle::optional<paddle::Tensor> const& maybe_channel_scales,
|
||||
paddle::optional<paddle::Tensor> const& maybe_token_scales,
|
||||
std::string const& b_type_str,
|
||||
std::string const& maybe_out_type_str,
|
||||
int64_t const& maybe_group_size,
|
||||
std::string const& maybe_schedule);
|
||||
|
||||
std::vector<paddle::Tensor> MachetePrepackBKernel(
|
||||
paddle::Tensor const& B, std::string const& a_type_str, std::string const& b_type_str,
|
||||
std::string const& maybe_group_scales_type_str);
|
||||
|
||||
std::vector<std::string> MacheteSupportedSchedules(
|
||||
std::string const& a_type_str, std::string const& b_type_str);
|
||||
|
||||
std::vector<paddle::Tensor> MoeExpertDispatch(
|
||||
const paddle::Tensor &input, const paddle::Tensor &gating_output,
|
||||
const paddle::optional<paddle::Tensor> &gating_correction_bias,
|
||||
const paddle::optional<paddle::Tensor> &w4a8_in_scale, const int moe_topk,
|
||||
const bool group_moe, const std::string &moe_quant_type, const bool topk_only_mode);
|
||||
const bool group_moe, const bool topk_only_mode);
|
||||
|
||||
std::vector<paddle::Tensor>
|
||||
MoETopKSelectKernel(const paddle::Tensor &gating_logits,
|
||||
@@ -254,9 +188,7 @@ paddle::Tensor MoeExpertFFNFunc(
|
||||
const paddle::optional<paddle::Tensor>& down_proj_scale,
|
||||
const paddle::optional<paddle::Tensor>& down_proj_in_scale,
|
||||
const paddle::optional<paddle::Tensor>& expert_idx_per_token,
|
||||
const std::string& quant_method, const bool used_in_ep_low_latency,
|
||||
const int estimate_total_token_nums,
|
||||
const int hadamard_block_size);
|
||||
const std::string& quant_method, const bool used_in_ep_low_latency);
|
||||
|
||||
paddle::Tensor MoeExpertFFNWint2Func(
|
||||
const paddle::Tensor& permute_input,
|
||||
@@ -299,27 +231,12 @@ paddle::Tensor OpenShmAndGetMetaSignalFunc(const int rank, const int device_id,
|
||||
paddle::Tensor InitSignalLayerwiseFunc(const paddle::Tensor &kv_signal_metadata,
|
||||
const int layer_id);
|
||||
|
||||
void GetBlockShapeAndSplitKVBlock(
|
||||
std::vector<paddle::Tensor> GetBlockShapeAndSplitKVBlock(
|
||||
const paddle::Tensor &seq_lens_encoder,
|
||||
const paddle::Tensor &seq_lens_decoder,
|
||||
const paddle::Tensor &seq_lens_this_time,
|
||||
paddle::Tensor &decoder_batch_ids, // Inplace
|
||||
paddle::Tensor &decoder_tile_ids_per_batch, // Inplace
|
||||
paddle::Tensor &decoder_num_blocks_cpu, // Inplace, Pinned Memory
|
||||
paddle::Tensor &decoder_num_blocks_device, // Inplace
|
||||
paddle::Tensor &decoder_chunk_size_device, // Inplace
|
||||
paddle::Tensor &max_len_tensor_cpu, // Inplace, Pinned Memory
|
||||
paddle::Tensor &encoder_batch_ids, // Inplace
|
||||
paddle::Tensor &encoder_tile_ids_per_batch, // Inplace
|
||||
paddle::Tensor &encoder_num_blocks_x_cpu, // Inplace, Pinned Memory
|
||||
paddle::Tensor &kv_batch_ids, // Inplace
|
||||
paddle::Tensor &kv_tile_ids_per_batch, // Inplace
|
||||
paddle::Tensor &kv_num_blocks_x_cpu, // Inplace, Pinned Memory
|
||||
paddle::Tensor &max_len_kv_cpu, // Inplace, Pinned Memory
|
||||
const int encoder_block_shape_q,
|
||||
const int decoder_block_shape_q,
|
||||
const int group_size,
|
||||
const int block_size,
|
||||
const int encoder_block_shape_q, const int decoder_block_shape_q,
|
||||
const int group_size, const int block_size,
|
||||
const int decoder_step_token_num);
|
||||
|
||||
std::vector<paddle::Tensor> GetPaddingOffset(const paddle::Tensor &input_ids,
|
||||
@@ -349,12 +266,13 @@ void GetStopFlagsMulti(const paddle::Tensor &topk_ids,
|
||||
const paddle::Tensor &seq_lens,
|
||||
const paddle::Tensor &end_ids,
|
||||
const paddle::Tensor &next_tokens,
|
||||
const paddle::Tensor &pre_ids,
|
||||
const paddle::Tensor &step_idx,
|
||||
const paddle::Tensor &stop_seqs,
|
||||
const paddle::Tensor &stop_seqs_len,
|
||||
const bool beam_search);
|
||||
|
||||
void GetStopFlagsMultiSeqs(
|
||||
const paddle::Tensor &topk_ids, const paddle::Tensor &pre_ids,
|
||||
const paddle::Tensor &step_idx, const paddle::Tensor &stop_flags,
|
||||
const paddle::Tensor &seq_lens, const paddle::Tensor &stop_seqs,
|
||||
const paddle::Tensor &stop_seqs_len, const paddle::Tensor &end_ids);
|
||||
|
||||
void UpdateInputes(const paddle::Tensor &stop_flags,
|
||||
const paddle::Tensor ¬_need_stop, // only on cpu
|
||||
@@ -388,11 +306,9 @@ void RecoverDecodeTask(const paddle::Tensor &stop_flags,
|
||||
const paddle::Tensor &step_seq_lens_decoder,
|
||||
const paddle::Tensor &block_tables,
|
||||
const paddle::Tensor &is_block_step,
|
||||
const paddle::optional<paddle::Tensor> &draft_tokens,
|
||||
const paddle::optional<paddle::Tensor> &step_draft_tokens,
|
||||
const paddle::optional<paddle::Tensor> &step_seq_lens_this_time,
|
||||
const int block_size,
|
||||
const int max_draft_tokens);
|
||||
const int block_size);
|
||||
|
||||
|
||||
|
||||
paddle::Tensor
|
||||
GroupSwigluWithMasked(const paddle::Tensor &fc1_out_tensor,
|
||||
@@ -402,7 +318,7 @@ std::vector<paddle::Tensor> ExtractTextTokenOutput(
|
||||
const paddle::Tensor &max_seq_len, const paddle::Tensor &max_seq_len_index,
|
||||
const paddle::Tensor &mm_token_num_len,
|
||||
const paddle::Tensor &seq_lens_this_time,
|
||||
const paddle::Tensor &cu_seqlens_q, const paddle::Tensor &hidden_states);
|
||||
const paddle::Tensor &cu_seqlens_q, const paddle::Tensor &score_text);
|
||||
|
||||
std::vector<paddle::Tensor> MoEDeepGEMMPermute(const paddle::Tensor &x,
|
||||
const paddle::Tensor &topk_idx,
|
||||
@@ -416,8 +332,8 @@ std::vector<paddle::Tensor> MoEDeepGEMMDePermute(
|
||||
const paddle::Tensor &topk_idx, const paddle::Tensor &topk_weights);
|
||||
|
||||
void TextImageIndexOut(const paddle::Tensor &token_type_ids,
|
||||
paddle::Tensor &text_input,
|
||||
paddle::Tensor &image_input);
|
||||
const paddle::Tensor &text_input,
|
||||
const paddle::Tensor &image_input);
|
||||
|
||||
void TextImageGatherScatter(paddle::Tensor &input, paddle::Tensor &text_input,
|
||||
paddle::Tensor &image_input,
|
||||
@@ -475,18 +391,23 @@ std::vector<paddle::Tensor> MultiHeadLatentAttention(
|
||||
const paddle::Tensor& query,
|
||||
const paddle::Tensor& key_cache,
|
||||
const paddle::Tensor& value_cache,
|
||||
const paddle::Tensor& seq_lens_encoder,
|
||||
const paddle::Tensor& seq_lens_decoder,
|
||||
const paddle::Tensor& seq_lens_this_time,
|
||||
const paddle::Tensor& cu_seqlens_q,
|
||||
const paddle::Tensor& batch_id_per_token,
|
||||
const paddle::Tensor& block_tables,
|
||||
const paddle::Tensor& encoder_batch_ids,
|
||||
const paddle::Tensor& encoder_tile_ids_per_batch,
|
||||
const paddle::Tensor& encoder_num_blocks,
|
||||
const paddle::Tensor& kv_batch_ids,
|
||||
const paddle::Tensor& kv_tile_ids_per_batch,
|
||||
const paddle::Tensor& kv_num_blocks,
|
||||
const paddle::Tensor& decoder_batch_ids,
|
||||
const paddle::Tensor& decoder_tile_ids_per_batch,
|
||||
const paddle::Tensor& decoder_num_blocks_device,
|
||||
const paddle::Tensor& decoder_chunk_size_device,
|
||||
const paddle::Tensor& decoder_num_blocks,
|
||||
const paddle::Tensor& decoder_num_blocks_cpu,
|
||||
const paddle::Tensor& max_enc_len_this_time,
|
||||
const paddle::Tensor& max_dec_len_this_time,
|
||||
const paddle::Tensor& max_len_kv,
|
||||
const paddle::optional<paddle::Tensor>& attn_mask,
|
||||
@@ -600,7 +521,7 @@ paddle::Tensor FusedHadamardQuantFp8Func(
|
||||
int64_t init_custom_all_reduce(const std::vector<int64_t>& fake_ipc_ptrs,
|
||||
paddle::Tensor& rank_data, int64_t rank, bool full_nvlink);
|
||||
|
||||
void all_reduce(paddle::Tensor& inp, paddle::Tensor& out, int64_t _fa,
|
||||
void all_reduce(int64_t _fa, paddle::Tensor& inp, paddle::Tensor& out,
|
||||
int64_t reg_buffer, int64_t reg_buffer_sz_bytes);
|
||||
|
||||
void dispose(int64_t _fa);
|
||||
@@ -683,7 +604,7 @@ void SpeculateVerify(
|
||||
const paddle::Tensor &actual_draft_token_nums, const paddle::Tensor &topp,
|
||||
int max_seq_len, int verify_window, bool enable_topp, bool benchmark_mode);
|
||||
|
||||
void SpeculateUpdate(const paddle::Tensor &seq_lens_encoder,
|
||||
void SpeculateUpdateV3(const paddle::Tensor &seq_lens_encoder,
|
||||
const paddle::Tensor &seq_lens_decoder,
|
||||
const paddle::Tensor ¬_need_stop,
|
||||
const paddle::Tensor &draft_tokens,
|
||||
@@ -714,22 +635,6 @@ void SpeculateSaveWithOutputMsgStatic(const paddle::Tensor& accept_tokens,
|
||||
void SpeculateClearAcceptNums(const paddle::Tensor& accept_num,
|
||||
const paddle::Tensor& seq_lens_decoder);
|
||||
|
||||
void SpeculateScheduleCache(const paddle::Tensor &draft_tokens,
|
||||
const paddle::Tensor &block_tables,
|
||||
const paddle::Tensor &stop_flags,
|
||||
const paddle::Tensor &seq_lens_this_time,
|
||||
const paddle::Tensor &seq_lens_decoder,
|
||||
const paddle::Tensor &step_seq_lens_decoder,
|
||||
const paddle::Tensor &step_draft_tokens,
|
||||
const paddle::Tensor &step_seq_lens_this_time,
|
||||
const paddle::Tensor &accept_num,
|
||||
const paddle::Tensor &accept_tokens,
|
||||
const paddle::Tensor &is_block_step,
|
||||
const paddle::Tensor ¬_need_stop,
|
||||
const paddle::Tensor &stop_nums,
|
||||
const int block_size,
|
||||
const int max_draft_tokens);
|
||||
|
||||
void NgramMatch(const paddle::Tensor &input_ids,
|
||||
const paddle::Tensor &input_ids_len,
|
||||
const paddle::Tensor &pre_ids,
|
||||
@@ -744,20 +649,6 @@ void NgramMatch(const paddle::Tensor &input_ids,
|
||||
const int max_draft_tokens);
|
||||
|
||||
|
||||
void HybridMtpNgram(const paddle::Tensor &input_ids,
|
||||
const paddle::Tensor &input_ids_len,
|
||||
const paddle::Tensor &pre_ids,
|
||||
const paddle::Tensor &step_idx,
|
||||
const paddle::Tensor &draft_token_num,
|
||||
const paddle::Tensor &draft_tokens,
|
||||
const paddle::Tensor &seq_lens_this_time,
|
||||
const paddle::Tensor &seq_lens_decoder,
|
||||
const paddle::Tensor &max_dec_len,
|
||||
const int max_ngram_size,
|
||||
const int min_ngram_size,
|
||||
const int max_draft_tokens);
|
||||
|
||||
|
||||
// MTP
|
||||
void DraftModelPostprocess(const paddle::Tensor& base_model_draft_tokens,
|
||||
const paddle::Tensor& base_model_seq_lens_this_time,
|
||||
@@ -773,12 +664,9 @@ void DraftModelPreprocess(const paddle::Tensor& draft_tokens,
|
||||
const paddle::Tensor& seq_lens_decoder,
|
||||
const paddle::Tensor& step_idx,
|
||||
const paddle::Tensor& not_need_stop,
|
||||
const paddle::Tensor& is_block_step,
|
||||
const paddle::Tensor& batch_drop,
|
||||
const paddle::Tensor& pre_ids,
|
||||
const paddle::Tensor& accept_tokens,
|
||||
const paddle::Tensor& accept_num,
|
||||
const paddle::Tensor& base_model_seq_lens_this_time,
|
||||
const paddle::Tensor& base_model_seq_lens_encoder,
|
||||
const paddle::Tensor& base_model_seq_lens_decoder,
|
||||
const paddle::Tensor& base_model_step_idx,
|
||||
@@ -787,8 +675,7 @@ void DraftModelPreprocess(const paddle::Tensor& draft_tokens,
|
||||
const paddle::Tensor& base_model_draft_tokens,
|
||||
const int max_draft_token,
|
||||
const bool truncate_first_token,
|
||||
const bool splitwise_prefill,
|
||||
const bool kvcache_scheduler_v1);
|
||||
const bool splitwise_prefill);
|
||||
|
||||
|
||||
void DraftModelUpdate(const paddle::Tensor& inter_next_tokens,
|
||||
@@ -869,33 +756,6 @@ void SpeculateStepPaddle(
|
||||
const int encoder_decoder_block_num,
|
||||
const int max_draft_tokens);
|
||||
|
||||
void MergePrefillDecodeOutput(
|
||||
const paddle::Tensor &encoder_res,
|
||||
const paddle::Tensor &decoder_res,
|
||||
const paddle::Tensor &seq_lens_encoder,
|
||||
const paddle::Tensor &seq_lens_decoder,
|
||||
const paddle::Tensor &seq_lens_this_time,
|
||||
const paddle::Tensor &cu_seq_q,
|
||||
const int head_num,
|
||||
const int head_dim,
|
||||
const int max_token);
|
||||
|
||||
std::vector<paddle::Tensor> TopPSamplingReject(const paddle::Tensor &probs,
|
||||
const paddle::Tensor &top_p,
|
||||
const paddle::optional<paddle::Tensor> &top_k,
|
||||
int64_t seed);
|
||||
|
||||
std::vector<paddle::Tensor> TopKRenorm(const paddle::Tensor &probs,
|
||||
const paddle::Tensor &top_k);
|
||||
|
||||
std::vector<paddle::Tensor> MinPSamplingFromProbs(const paddle::Tensor &probs,
|
||||
const paddle::Tensor &min_p);
|
||||
|
||||
void SaveOutMmsgStatic(const paddle::Tensor& x,
|
||||
const paddle::Tensor& not_need_stop,
|
||||
int64_t rank_id,
|
||||
bool save_each_rank);
|
||||
|
||||
PYBIND11_MODULE(fastdeploy_ops, m) {
|
||||
|
||||
m.def("get_expert_token_num", &GetExpertTokenNum, py::arg("topk_ids"),
|
||||
@@ -949,7 +809,6 @@ PYBIND11_MODULE(fastdeploy_ops, m) {
|
||||
* append_attention
|
||||
*/
|
||||
m.def("append_attention", &AppendAttention, "append attention function");
|
||||
m.def("append_attention_with_output", &AppendAttentionWithOutput, "append attention with output function");
|
||||
/**
|
||||
* gqa_rope_write_cache.cu
|
||||
* gqa_rope_write_cache
|
||||
@@ -981,7 +840,7 @@ PYBIND11_MODULE(fastdeploy_ops, m) {
|
||||
m.def("moe_expert_dispatch", &MoeExpertDispatch, py::arg("input"),
|
||||
py::arg("gating_output"), py::arg("gating_correction_bias"),
|
||||
py::arg("w4a8_in_scale"), py::arg("moe_topk"), py::arg("group_moe"),
|
||||
py::arg("moe_quant_type"), py::arg("topk_only_mode"), "moe export dispatch function");
|
||||
py::arg("topk_only_mode"), "moe export dispatch function");
|
||||
|
||||
/**
|
||||
* moe/fused_moe/ep_moe_prefill_func.cu
|
||||
@@ -1005,33 +864,12 @@ PYBIND11_MODULE(fastdeploy_ops, m) {
|
||||
|
||||
m.def("per_token_quant_padding", &PerTokenQuantPadding, py::arg("input"),
|
||||
py::arg("block_size"),
|
||||
"per token per block quant and padding transpose scale");
|
||||
"per token per block quant and padding tranpose scale");
|
||||
|
||||
m.def("masked_per_token_quant", &MaskedPerTokenQuant, py::arg("input"),
|
||||
py::arg("recv_expert_count"), py::arg("block_size"),
|
||||
"per token per block quant");
|
||||
|
||||
#ifdef ENABLE_MACHETE
|
||||
/*machete/machete_mm.cu
|
||||
* machete_mm
|
||||
*/
|
||||
m.def("machete_mm", &MacheteMMKernel, py::arg("A"), py::arg("B"), py::arg("maybe_group_scale"),
|
||||
py::arg("maybe_group_zeros"), py::arg("maybe_channel_scales"), py::arg("maybe_token_scales"),
|
||||
py::arg("b_type_str"), py::arg("maybe_out_type_str"), py::arg("maybe_group_size"),
|
||||
py::arg("maybe_schedule"),
|
||||
"machete mm function");
|
||||
|
||||
/*machete/machete_prepack_B.cu
|
||||
* machete_prepack_B
|
||||
*/
|
||||
m.def("machete_prepack_B", &MachetePrepackBKernel, "machete prepacked B function");
|
||||
|
||||
/*machete/machete_supported_schedules.cu
|
||||
* machete_supported_schedules
|
||||
*/
|
||||
m.def("machete_supported_schedules", &MacheteSupportedSchedules, "machete supported schedules function");
|
||||
#endif
|
||||
|
||||
/**
|
||||
* moe/fused_moe/moe_topk_select.cu
|
||||
* moe_topk_select
|
||||
@@ -1048,7 +886,7 @@ PYBIND11_MODULE(fastdeploy_ops, m) {
|
||||
m.def("moe_expert_ffn", &MoeExpertFFNFunc, "moe export ffn function");
|
||||
|
||||
/**
|
||||
* moe/fused_moe/moe_expert_ffn_wint2.cu
|
||||
* moe/fused_moe/moe_ffn_wint2.cu
|
||||
* moe_expert_ffn_wint2
|
||||
*/
|
||||
m.def("moe_expert_ffn_wint2", &MoeExpertFFNWint2Func, "moe export ffn wint2 function");
|
||||
@@ -1116,6 +954,12 @@ PYBIND11_MODULE(fastdeploy_ops, m) {
|
||||
m.def("set_stop_value_multi_ends", &GetStopFlagsMulti,
|
||||
"update_inputs function");
|
||||
|
||||
/**
|
||||
* stop_generation_multi_stop_seqs.cu
|
||||
* set_stop_value_multi_seqs
|
||||
*/
|
||||
m.def("set_stop_value_multi_seqs", &GetStopFlagsMultiSeqs,
|
||||
"update_inputs function");
|
||||
|
||||
/**
|
||||
* update_inputs.cu
|
||||
@@ -1245,7 +1089,7 @@ PYBIND11_MODULE(fastdeploy_ops, m) {
|
||||
|
||||
m.def("speculate_verify",&SpeculateVerify, "speculate_verify function");
|
||||
|
||||
m.def("speculate_update",&SpeculateUpdate, "Speculate Update Kernel");
|
||||
m.def("speculate_update_v3",&SpeculateUpdateV3, "noaux_tc for Deepseekv3 MoE compute function");
|
||||
|
||||
m.def("speculate_set_value_by_flags_and_idx",&SpeculateSetValueByFlagsAndIdx, "speculate_set_value_by_flags_and_idx function");
|
||||
|
||||
@@ -1253,12 +1097,8 @@ PYBIND11_MODULE(fastdeploy_ops, m) {
|
||||
|
||||
m.def("speculate_clear_accept_nums",&SpeculateClearAcceptNums, "speculate_clear_accept_nums function");
|
||||
|
||||
m.def("speculate_schedule_cache",&SpeculateScheduleCache, "SpeculateScheduleCache function");
|
||||
|
||||
m.def("ngram_match", &NgramMatch, "ngram_match function");
|
||||
|
||||
m.def("hybird_mtp_ngram", &HybridMtpNgram, "ngram_match_mixed function");
|
||||
|
||||
m.def("draft_model_postprocess",&DraftModelPostprocess, "draft_model_postprocess function");
|
||||
|
||||
m.def("draft_model_preprocess",&DraftModelPreprocess, "draft_model_preprocess function");
|
||||
@@ -1272,14 +1112,4 @@ PYBIND11_MODULE(fastdeploy_ops, m) {
|
||||
m.def("mtp_step_paddle",&MTPStepPaddle, "mtp_step_paddle function");
|
||||
|
||||
m.def("speculate_step_paddle",&SpeculateStepPaddle, "speculate_step_paddle function");
|
||||
|
||||
m.def("merge_prefill_decode_output", &MergePrefillDecodeOutput, "merge_prefill_decode_output function");
|
||||
|
||||
m.def("rejection_top_p_sampling", &TopPSamplingReject, "rejection_top_p_sampling function");
|
||||
|
||||
m.def("top_k_renorm_probs", &TopKRenorm, "top_k_renorm_probs function");
|
||||
|
||||
m.def("min_p_sampling", &MinPSamplingFromProbs, "min_p_sampling function");
|
||||
|
||||
m.def("save_output", &SaveOutMmsgStatic, "save_output function");
|
||||
}
|
||||
|
@@ -49,7 +49,7 @@ fptr_t init_custom_all_reduce(const std::vector<fptr_t>& fake_ipc_ptrs,
 * Otherwise, _reg_buffer is assumed to be IPC-registered and inp is first
 * copied into _reg_buffer.
 */
void all_reduce(paddle::Tensor& inp, paddle::Tensor& out, fptr_t _fa,
void all_reduce(fptr_t _fa, paddle::Tensor& inp, paddle::Tensor& out,
                fptr_t _reg_buffer, int64_t reg_buffer_sz_bytes) {
  auto fa = reinterpret_cast<paddle::CustomAllreduce*>(_fa);
  auto stream = inp.stream();
@@ -163,12 +163,3 @@ fptr_t open_mem_handle(paddle::Tensor& mem_handle) {
void free_shared_buffer(fptr_t buffer) {
  CUDACHECK(cudaFree(reinterpret_cast<void*>(buffer)));
}


PD_BUILD_STATIC_OP(all_reduce)
    .Inputs({"inp",
             "out"})
    .Outputs({"new_out"})
    .Attrs({"_fa: int64_t", "_reg_buffer: int64_t", "reg_buffer_sz_bytes: int64_t"})
    .SetInplaceMap({{"out", "new_out"}})
    .SetKernelFn(PD_KERNEL(all_reduce));
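// Call sketch (assumed usage, not from this diff): the handle returned by
// init_custom_all_reduce is passed back on every reduction. All argument
// names below are placeholders; the ordering follows the handle-first
// overload shown above, and dispose() is declared earlier in this header.
void custom_all_reduce_example(const std::vector<fptr_t>& fake_ipc_ptrs,
                               paddle::Tensor& rank_data, int64_t rank,
                               bool full_nvlink, paddle::Tensor& inp,
                               paddle::Tensor& out, fptr_t reg_buffer,
                               int64_t reg_buffer_sz_bytes) {
  fptr_t fa = init_custom_all_reduce(fake_ipc_ptrs, rank_data, rank, full_nvlink);
  all_reduce(fa, inp, out, reg_buffer, reg_buffer_sz_bytes);
  dispose(fa);  // release the handle when done
}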
||||
|
@@ -6,8 +6,6 @@
|
||||
// clang-format off
|
||||
#include "cutlass_extensions/epilogue/broadcast_load_epilogue_c3x.hpp"
|
||||
#include "cutlass_extensions/epilogue/broadcast_load_epilogue_array_c3x.hpp"
|
||||
|
||||
#include "helper.h"
|
||||
// clang-format on
|
||||
|
||||
/*
|
||||
|
@@ -133,18 +133,10 @@ public:
|
||||
template <typename TypeA, typename Arch>
|
||||
struct LayoutDetailsB<TypeA, uint2b_t, Arch, typename platform::enable_if<Arch::kMinComputeCapability >= 75>::type>
|
||||
{
|
||||
static constexpr int ThreadblockK = 128 * 8 / cutlass::sizeof_bits<TypeA>::value; // 64
|
||||
|
||||
private:
|
||||
static constexpr int ElementsPerCacheLine = 128 * 8 / sizeof_bits<uint2b_t>::value;
|
||||
static constexpr int ColumnsInterleaved = ElementsPerCacheLine / ThreadblockK; // 8
|
||||
|
||||
public:
|
||||
// using Layout = layout::ColumnMajor;
|
||||
// static constexpr int ElementsPerAccess = 16; // at least 4-bytes
|
||||
using Layout = layout::ColumnMajorTileInterleave<ThreadblockK, ColumnsInterleaved>;
|
||||
static constexpr int ElementsPerAccess = 128 / cutlass::sizeof_bits<uint2b_t>::value; // 64
|
||||
using Operator = cutlass::arch::OpMultiplyAddDequantizeInterleavedBToA;
|
||||
static constexpr int ThreadblockK = 128 * 8 / cutlass::sizeof_bits<TypeA>::value;
|
||||
using Layout = layout::RowMajor;
|
||||
static constexpr int ElementsPerAccess = 128 / cutlass::sizeof_bits<TypeA>::value;
|
||||
using Operator = cutlass::arch::OpMultiplyAdd;
|
||||
};
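// Worked check of the interleave constants above (illustrative only, not part
// of the original file), taking TypeA = half_t (16-bit) and uint2b_t B:
//   ThreadblockK         = 128 * 8 / 16 = 64   A elements per 128-byte line
//   ElementsPerCacheLine = 128 * 8 / 2  = 512  2-bit B elements per 128-byte line
//   ColumnsInterleaved   = 512 / 64     = 8
//   ElementsPerAccess    = 128 / 2      = 64   2-bit B elements per 128-bit access
static_assert(128 * 8 / 16 == 64 && (128 * 8 / 2) / (128 * 8 / 16) == 8,
              "illustrative arithmetic for the uint2b_t layout above");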
|
||||
|
||||
template <typename TypeA, typename Arch>
|
||||
|
@@ -18,12 +18,14 @@
|
||||
|
||||
#include "cutlass_extensions/gemm/threadblock/default_dq_mma_multistage.h"
|
||||
#include "cutlass_extensions/gemm/threadblock/default_dq_mma_pipelined.h"
|
||||
#include "cutlass_extensions/gemm/threadblock/default_wint2x_mma.h"
|
||||
#include "cutlass_extensions/gemm/threadblock/default_mma_bf16.h"
|
||||
|
||||
namespace cutlass {
|
||||
namespace gemm {
|
||||
namespace threadblock {
|
||||
namespace cutlass
|
||||
{
|
||||
namespace gemm
|
||||
{
|
||||
namespace threadblock
|
||||
{
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
@@ -376,23 +378,38 @@ template <
|
||||
struct DefaultMma<cutlass::half_t, LayoutA, kAlignmentA, uint2b_t, LayoutB, kAlignmentB, ElementAccumulator,
|
||||
layout::RowMajor, arch::OpClassTensorOp, ArchTag, ThreadblockShape, WarpShape, InstructionShape, 2, Operator>
|
||||
{
|
||||
private:
|
||||
using Mma = DefaultWint2xMma<half_t, LayoutA, kAlignmentA, uint2b_t, LayoutB, kAlignmentB,
|
||||
ElementAccumulator, layout::RowMajor, arch::OpClassTensorOp, ArchTag, ThreadblockShape,
|
||||
WarpShape, InstructionShape, 2, Operator>;
|
||||
static cutlass::arch::CacheOperation::Kind const CacheOpA =
|
||||
((sizeof_bits<half_t>::value * kAlignmentA) == 128) ? cutlass::arch::CacheOperation::Global
|
||||
: cutlass::arch::CacheOperation::Always;
|
||||
|
||||
static cutlass::arch::CacheOperation::Kind const CacheOpB =
|
||||
((sizeof_bits<half_t>::value * kAlignmentB) == 128) ? cutlass::arch::CacheOperation::Global
|
||||
: cutlass::arch::CacheOperation::Always;
|
||||
|
||||
public:
|
||||
// Define the MmaCore components
|
||||
using MmaCore = typename Mma::MmaCore;
|
||||
using MmaCore =
|
||||
typename cutlass::gemm::threadblock::DefaultMmaCore<ThreadblockShape, WarpShape, InstructionShape, half_t,
|
||||
LayoutA, half_t, LayoutB, ElementAccumulator, layout::RowMajor, arch::OpClassTensorOp, 3, Operator,
|
||||
false, CacheOpA, CacheOpB>;
|
||||
|
||||
// Define iterators over tiles from the A operand
|
||||
using IteratorA = typename Mma::IteratorA;
|
||||
using ThreadMapA = typename MmaCore::IteratorThreadMapA;
|
||||
using AccessTypeA = cutlass::Array<half_t, kAlignmentA>;
|
||||
using IteratorA = cutlass::transform::threadblock::PredicatedTileAccessIterator<
|
||||
cutlass::MatrixShape<ThreadblockShape::kM, ThreadblockShape::kK>, half_t, LayoutA, 1, ThreadMapA,
|
||||
AccessTypeA>;
|
||||
|
||||
// Define iterators over tiles from the B operand
|
||||
using IteratorB = typename Mma::IteratorB;
|
||||
using ThreadMapB = typename MmaCore::IteratorThreadMapB;
|
||||
using AccessTypeB = cutlass::Array<half_t, kAlignmentB>;
|
||||
using IteratorB = cutlass::transform::threadblock::PredicatedTileAccessIterator<
|
||||
cutlass::MatrixShape<ThreadblockShape::kK, ThreadblockShape::kN>, half_t, LayoutB, 0, ThreadMapB,
|
||||
AccessTypeB>;
|
||||
|
||||
// Define the threadblock-scoped multistage matrix multiply
|
||||
using ThreadblockMma = typename Mma::ThreadblockMma;
|
||||
using ThreadblockMma = cutlass::gemm::threadblock::Wint2xMmaMultistage<typename MmaCore::Shape, IteratorA,
|
||||
typename MmaCore::SmemIteratorA, MmaCore::kCacheOpA, IteratorB, typename MmaCore::SmemIteratorB,
|
||||
MmaCore::kCacheOpB, ElementAccumulator, layout::RowMajor, typename MmaCore::MmaPolicy, 2>;
|
||||
};
|
||||
|
||||
template <
|
||||
@@ -424,23 +441,38 @@ struct DefaultMma<half_t, LayoutA, kAlignmentA, uint2b_t, LayoutB, kAlignmentB,
|
||||
layout::RowMajor, arch::OpClassTensorOp, ArchTag, ThreadblockShape, WarpShape, InstructionShape, kStages, Operator,
|
||||
false, SharedMemoryClear>
|
||||
{
|
||||
private:
|
||||
using Mma = DefaultWint2xMma<half_t, LayoutA, kAlignmentA, uint2b_t, LayoutB, kAlignmentB,
|
||||
ElementAccumulator, layout::RowMajor, arch::OpClassTensorOp, ArchTag, ThreadblockShape,
|
||||
WarpShape, InstructionShape, kStages, Operator, SharedMemoryClear>;
|
||||
static cutlass::arch::CacheOperation::Kind const CacheOpA =
|
||||
((sizeof_bits<half_t>::value * kAlignmentA) == 128) ? cutlass::arch::CacheOperation::Global
|
||||
: cutlass::arch::CacheOperation::Always;
|
||||
|
||||
static cutlass::arch::CacheOperation::Kind const CacheOpB =
|
||||
((sizeof_bits<half_t>::value * kAlignmentB) == 128) ? cutlass::arch::CacheOperation::Global
|
||||
: cutlass::arch::CacheOperation::Always;
|
||||
|
||||
public:
|
||||
// Define the MmaCore components
|
||||
using MmaCore = typename Mma::MmaCore;
|
||||
using MmaCore =
|
||||
typename cutlass::gemm::threadblock::DefaultMmaCore<ThreadblockShape, WarpShape, InstructionShape, half_t,
|
||||
LayoutA, half_t, LayoutB, ElementAccumulator, layout::RowMajor, arch::OpClassTensorOp, kStages, Operator,
|
||||
false, CacheOpA, CacheOpB>;
|
||||
|
||||
// Define iterators over tiles from the A operand
|
||||
using IteratorA = typename Mma::IteratorA;
|
||||
using ThreadMapA = typename MmaCore::IteratorThreadMapA;
|
||||
using AccessTypeA = cutlass::Array<half_t, kAlignmentA>;
|
||||
using IteratorA = cutlass::transform::threadblock::PredicatedTileAccessIterator<
|
||||
cutlass::MatrixShape<ThreadblockShape::kM, ThreadblockShape::kK>, half_t, LayoutA, 1, ThreadMapA,
|
||||
AccessTypeA>;
|
||||
|
||||
// Define iterators over tiles from the B operand
|
||||
using IteratorB = typename Mma::IteratorB;
|
||||
using ThreadMapB = typename MmaCore::IteratorThreadMapB;
|
||||
using AccessTypeB = cutlass::Array<half_t, kAlignmentB>;
|
||||
using IteratorB = cutlass::transform::threadblock::PredicatedTileAccessIterator<
|
||||
cutlass::MatrixShape<ThreadblockShape::kK, ThreadblockShape::kN>, half_t, LayoutB, 0, ThreadMapB,
|
||||
AccessTypeB>;
|
||||
|
||||
// Define the threadblock-scoped multistage matrix multiply
|
||||
using ThreadblockMma = typename Mma::ThreadblockMma;
|
||||
using ThreadblockMma = cutlass::gemm::threadblock::Wint2xMmaMultistage<typename MmaCore::Shape, IteratorA,
|
||||
typename MmaCore::SmemIteratorA, MmaCore::kCacheOpA, IteratorB, typename MmaCore::SmemIteratorB,
|
||||
MmaCore::kCacheOpB, ElementAccumulator, layout::RowMajor, typename MmaCore::MmaPolicy, kStages, SharedMemoryClear>;
|
||||
};
|
||||
|
||||
} // namespace threadblock
|
||||
|
@@ -19,7 +19,7 @@
|
||||
#include "cutlass/gemm/threadblock/default_mma.h"
|
||||
#include "cutlass_extensions/gemm/threadblock/default_dq_mma_multistage.h"
|
||||
#include "cutlass_extensions/gemm/threadblock/default_dq_mma_pipelined.h"
|
||||
#include "cutlass_extensions/gemm/threadblock/default_wint2x_mma.h"
|
||||
#include "cutlass_extensions/gemm/threadblock/wint2x_mma_multistage.h"
|
||||
|
||||
namespace cutlass {
|
||||
namespace gemm {
|
||||
@@ -379,23 +379,38 @@ template <
|
||||
struct DefaultMma<cutlass::bfloat16_t, LayoutA, kAlignmentA, uint2b_t, LayoutB, kAlignmentB, ElementAccumulator,
|
||||
layout::RowMajor, arch::OpClassTensorOp, ArchTag, ThreadblockShape, WarpShape, InstructionShape, 2, Operator>
|
||||
{
|
||||
private:
|
||||
using Mma = DefaultWint2xMma<bfloat16_t, LayoutA, kAlignmentA, uint2b_t, LayoutB, kAlignmentB,
|
||||
ElementAccumulator, layout::RowMajor, arch::OpClassTensorOp, ArchTag, ThreadblockShape,
|
||||
WarpShape, InstructionShape, 2, Operator>;
|
||||
static cutlass::arch::CacheOperation::Kind const CacheOpA =
|
||||
((sizeof_bits<bfloat16_t>::value * kAlignmentA) == 128) ? cutlass::arch::CacheOperation::Global
|
||||
: cutlass::arch::CacheOperation::Always;
|
||||
|
||||
static cutlass::arch::CacheOperation::Kind const CacheOpB =
|
||||
((sizeof_bits<bfloat16_t>::value * kAlignmentB) == 128) ? cutlass::arch::CacheOperation::Global
|
||||
: cutlass::arch::CacheOperation::Always;
|
||||
|
||||
public:
|
||||
// Define the MmaCore components
|
||||
using MmaCore = typename Mma::MmaCore;
|
||||
using MmaCore =
|
||||
typename cutlass::gemm::threadblock::DefaultMmaCore<ThreadblockShape, WarpShape, InstructionShape, bfloat16_t,
|
||||
LayoutA, bfloat16_t, LayoutB, ElementAccumulator, layout::RowMajor, arch::OpClassTensorOp, 3, Operator,
|
||||
false, CacheOpA, CacheOpB>;
|
||||
|
||||
// Define iterators over tiles from the A operand
|
||||
using IteratorA = typename Mma::IteratorA;
|
||||
using ThreadMapA = typename MmaCore::IteratorThreadMapA;
|
||||
using AccessTypeA = cutlass::Array<bfloat16_t, kAlignmentA>;
|
||||
using IteratorA = cutlass::transform::threadblock::PredicatedTileAccessIterator<
|
||||
cutlass::MatrixShape<ThreadblockShape::kM, ThreadblockShape::kK>, bfloat16_t, LayoutA, 1, ThreadMapA,
|
||||
AccessTypeA>;
|
||||
|
||||
// Define iterators over tiles from the B operand
|
||||
using IteratorB = typename Mma::IteratorB;
|
||||
using ThreadMapB = typename MmaCore::IteratorThreadMapB;
|
||||
using AccessTypeB = cutlass::Array<bfloat16_t, kAlignmentB>;
|
||||
using IteratorB = cutlass::transform::threadblock::PredicatedTileAccessIterator<
|
||||
cutlass::MatrixShape<ThreadblockShape::kK, ThreadblockShape::kN>, bfloat16_t, LayoutB, 0, ThreadMapB,
|
||||
AccessTypeB>;
|
||||
|
||||
// Define the threadblock-scoped multistage matrix multiply
|
||||
using ThreadblockMma = typename Mma::ThreadblockMma;
|
||||
using ThreadblockMma = cutlass::gemm::threadblock::Wint2xMmaMultistage<typename MmaCore::Shape, IteratorA,
|
||||
typename MmaCore::SmemIteratorA, MmaCore::kCacheOpA, IteratorB, typename MmaCore::SmemIteratorB,
|
||||
MmaCore::kCacheOpB, ElementAccumulator, layout::RowMajor, typename MmaCore::MmaPolicy, 2>;
|
||||
};
|
||||
|
||||
template <
|
||||
@@ -427,23 +442,38 @@ struct DefaultMma<bfloat16_t, LayoutA, kAlignmentA, uint2b_t, LayoutB, kAlignmen
|
||||
layout::RowMajor, arch::OpClassTensorOp, ArchTag, ThreadblockShape, WarpShape, InstructionShape, kStages, Operator,
|
||||
false, SharedMemoryClear>
|
||||
{
|
||||
private:
|
||||
using Mma = DefaultWint2xMma<bfloat16_t, LayoutA, kAlignmentA, uint2b_t, LayoutB, kAlignmentB,
|
||||
ElementAccumulator, layout::RowMajor, arch::OpClassTensorOp, ArchTag, ThreadblockShape,
|
||||
WarpShape, InstructionShape, kStages, Operator, SharedMemoryClear>;
|
||||
static cutlass::arch::CacheOperation::Kind const CacheOpA =
|
||||
((sizeof_bits<bfloat16_t>::value * kAlignmentA) == 128) ? cutlass::arch::CacheOperation::Global
|
||||
: cutlass::arch::CacheOperation::Always;
|
||||
|
||||
static cutlass::arch::CacheOperation::Kind const CacheOpB =
|
||||
((sizeof_bits<bfloat16_t>::value * kAlignmentB) == 128) ? cutlass::arch::CacheOperation::Global
|
||||
: cutlass::arch::CacheOperation::Always;
|
||||
|
||||
public:
|
||||
// Define the MmaCore components
|
||||
using MmaCore = typename Mma::MmaCore;
|
||||
using MmaCore =
|
||||
typename cutlass::gemm::threadblock::DefaultMmaCore<ThreadblockShape, WarpShape, InstructionShape, bfloat16_t,
|
||||
LayoutA, bfloat16_t, LayoutB, ElementAccumulator, layout::RowMajor, arch::OpClassTensorOp, kStages, Operator,
|
||||
false, CacheOpA, CacheOpB>;
|
||||
|
||||
// Define iterators over tiles from the A operand
|
||||
using IteratorA = typename Mma::IteratorA;
|
||||
using ThreadMapA = typename MmaCore::IteratorThreadMapA;
|
||||
using AccessTypeA = cutlass::Array<bfloat16_t, kAlignmentA>;
|
||||
using IteratorA = cutlass::transform::threadblock::PredicatedTileAccessIterator<
|
||||
cutlass::MatrixShape<ThreadblockShape::kM, ThreadblockShape::kK>, bfloat16_t, LayoutA, 1, ThreadMapA,
|
||||
AccessTypeA>;
|
||||
|
||||
// Define iterators over tiles from the B operand
|
||||
using IteratorB = typename Mma::IteratorB;
|
||||
using ThreadMapB = typename MmaCore::IteratorThreadMapB;
|
||||
using AccessTypeB = cutlass::Array<bfloat16_t, kAlignmentB>;
|
||||
using IteratorB = cutlass::transform::threadblock::PredicatedTileAccessIterator<
|
||||
cutlass::MatrixShape<ThreadblockShape::kK, ThreadblockShape::kN>, bfloat16_t, LayoutB, 0, ThreadMapB,
|
||||
AccessTypeB>;
|
||||
|
||||
// Define the threadblock-scoped multistage matrix multiply
|
||||
using ThreadblockMma = typename Mma::ThreadblockMma;
|
||||
using ThreadblockMma = cutlass::gemm::threadblock::Wint2xMmaMultistage<typename MmaCore::Shape, IteratorA,
|
||||
typename MmaCore::SmemIteratorA, MmaCore::kCacheOpA, IteratorB, typename MmaCore::SmemIteratorB,
|
||||
MmaCore::kCacheOpB, ElementAccumulator, layout::RowMajor, typename MmaCore::MmaPolicy, kStages, SharedMemoryClear>;
|
||||
};
|
||||
|
||||
} // namespace threadblock
|
||||
|
@@ -1,182 +0,0 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
*
|
||||
* 1. Redistributions of source code must retain the above copyright notice, this
|
||||
* list of conditions and the following disclaimer.
|
||||
*
|
||||
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
* this list of conditions and the following disclaimer in the documentation
|
||||
* and/or other materials provided with the distribution.
|
||||
*
|
||||
* 3. Neither the name of the copyright holder nor the names of its
|
||||
* contributors may be used to endorse or promote products derived from
|
||||
* this software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "cutlass/gemm/threadblock/default_mma_core_sm80.h"
|
||||
|
||||
namespace cutlass {
|
||||
namespace gemm {
|
||||
namespace threadblock {
|
||||
|
||||
/// Partial specialization:
|
||||
///
|
||||
/// A: row-major
|
||||
/// B: uint2b_t, column-major
|
||||
/// Operator: tensor op class
|
||||
///
|
||||
/// This uses the default warp-level operator given tile sizes
|
||||
template <
|
||||
/// Shape of threadblock-scoped matrix multiply operator (concept:
|
||||
/// GemmShape)
|
||||
typename Shape_,
|
||||
/// Shape of warp-level matrix multiply operator (concept: GemmShape)
|
||||
typename WarpShape_,
|
||||
/// Shape of one matrix production operation (concept: GemmShape)
|
||||
typename InstructionShape_,
|
||||
/// Data type of A operand
|
||||
typename ElementA_,
|
||||
/// Data type of accumulator
|
||||
typename ElementC_,
|
||||
/// Layout of accumulator
|
||||
typename LayoutC_,
|
||||
/// Number of stages
|
||||
int Stages,
|
||||
/// Operation performed by MMA
|
||||
typename Operator_,
|
||||
/// Cache operation of operand A
|
||||
cutlass::arch::CacheOperation::Kind CacheOpA,
|
||||
/// Cache operation of operand B
|
||||
cutlass::arch::CacheOperation::Kind CacheOpB>
|
||||
struct DefaultMmaCore<Shape_, WarpShape_, InstructionShape_, ElementA_,
|
||||
layout::RowMajor, uint2b_t, layout::ColumnMajor,
|
||||
ElementC_, LayoutC_, arch::OpClassTensorOp, Stages,
|
||||
Operator_, false, CacheOpA, CacheOpB> {
|
||||
using Shape = Shape_;
|
||||
using WarpShape = WarpShape_;
|
||||
using InstructionShape = InstructionShape_;
|
||||
using ElementA = ElementA_;
|
||||
using LayoutA = layout::RowMajor;
|
||||
using ElementB = uint2b_t;
|
||||
using LayoutB = layout::ColumnMajor;
|
||||
using ElementC = ElementC_;
|
||||
using LayoutC = LayoutC_;
|
||||
static int const kStages = Stages;
|
||||
static cutlass::arch::CacheOperation::Kind const kCacheOpA = CacheOpA;
|
||||
static cutlass::arch::CacheOperation::Kind const kCacheOpB = CacheOpB;
|
||||
|
||||
/// Number of warps present
|
||||
using WarpCount = GemmShape<Shape::kM / WarpShape::kM,
|
||||
Shape::kN / WarpShape::kN,
|
||||
Shape::kK / WarpShape::kK>;
|
||||
|
||||
// Divisibility requirements
|
||||
static_assert(
|
||||
!(Shape::kM % WarpShape::kM) && !(Shape::kN % WarpShape::kN),
|
||||
"Threadblock-scoped GEMM should be divisible by warp-scoped GEMM size.");
|
||||
|
||||
/// Number of threads per warp
|
||||
static int const kWarpSize = warp::WarpSize<arch::OpClassTensorOp>::value;
|
||||
|
||||
/// Size of a threadblock-scoped access
|
||||
static int const kAccessSizeInBits = 128;
|
||||
|
||||
/// Number of threads total
|
||||
static int const kThreads = WarpCount::kCount * kWarpSize;
|
||||
|
||||
/// Size of a threadblock-scoped access of B
|
||||
static constexpr int kMaxThreadsForB =
|
||||
(Shape::kK * Shape::kN * sizeof_bits<ElementB>::value) / kAccessSizeInBits;
|
||||
static constexpr int kThreadsForB =
|
||||
kMaxThreadsForB > kThreads ? kThreads : kMaxThreadsForB;
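// Worked example (illustrative only, not part of the diff; assumes a 64x64
// K-by-N threadblock tile): with 2-bit ElementB and 128-bit accesses,
// kMaxThreadsForB = 64 * 64 * 2 / 128 = 64, so loads of B are capped at 64
// participating threads even when the threadblock itself has more.
static_assert((64 * 64 * 2) / 128 == 64, "illustrative arithmetic only");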
|
||||
|
||||
/// Default Operator
|
||||
using Operator = Operator_;
|
||||
|
||||
// Warp thread arrangement
|
||||
static int const kWarpThreadArrangementContiguousA =
|
||||
Shape::kK / (kAccessSizeInBits / sizeof_bits<ElementA>::value);
|
||||
|
||||
static int const kWarpThreadArrangementStridedA =
|
||||
kWarpSize / kWarpThreadArrangementContiguousA;
|
||||
|
||||
static int const kWarpThreadArrangementContiguousB =
|
||||
Shape::kK / (kAccessSizeInBits / sizeof_bits<ElementB>::value);
|
||||
|
||||
static int const kWarpThreadArrangementStridedB =
|
||||
kWarpSize / kWarpThreadArrangementContiguousB;
|
||||
|
||||
//
|
||||
// Shared memory layouts
|
||||
//
|
||||
|
||||
using SmemLayoutA = layout::RowMajorTensorOpMultiplicandCrosswise<
|
||||
sizeof_bits<ElementA>::value, Shape::kK>;
|
||||
|
||||
// Shared memory layout
|
||||
using SmemLayoutB = layout::ColumnMajorTensorOpMultiplicandCrosswise<
|
||||
sizeof_bits<ElementB>::value, Shape::kK>;
|
||||
|
||||
//
|
||||
// Iterators to write to shared memory
|
||||
//
|
||||
|
||||
/// ThreadMap of iterator A
|
||||
using IteratorThreadMapA = transform::PitchLinearWarpRakedThreadMap<
|
||||
layout::PitchLinearShape<Shape::kK, Shape::kM>, kThreads,
|
||||
layout::PitchLinearShape<kWarpThreadArrangementContiguousA,
|
||||
kWarpThreadArrangementStridedA>,
|
||||
kAccessSizeInBits / sizeof_bits<ElementA>::value>;
|
||||
|
||||
/// Shared memory iterator to A operand
|
||||
using SmemIteratorA = transform::threadblock::RegularTileAccessIterator<
|
||||
MatrixShape<Shape::kM, Shape::kK>, ElementA, SmemLayoutA, 0,
|
||||
IteratorThreadMapA>;
|
||||
|
||||
/// ThreadMap of iterator B
|
||||
using IteratorThreadMapB = transform::PitchLinearWarpRakedThreadMap<
|
||||
layout::PitchLinearShape<Shape::kK, Shape::kN>, kThreadsForB,
|
||||
layout::PitchLinearShape<kWarpThreadArrangementContiguousB,
|
||||
kWarpThreadArrangementStridedB>,
|
||||
kAccessSizeInBits / sizeof_bits<ElementB>::value>;
|
||||
|
||||
/// Shared memory iterator to B operand
|
||||
using SmemIteratorB = transform::threadblock::RegularTileAccessIterator<
|
||||
MatrixShape<Shape::kK, Shape::kN>, ElementB, SmemLayoutB, 1,
|
||||
IteratorThreadMapB>;
|
||||
|
||||
//
|
||||
// Warp-level matrix multiply operator
|
||||
//
|
||||
|
||||
// Define the warp-level tensor op
|
||||
using MmaTensorOp = typename cutlass::gemm::warp::DefaultMmaTensorOp<
|
||||
WarpShape, InstructionShape, ElementA, SmemLayoutA, ElementB, SmemLayoutB,
|
||||
ElementC, LayoutC, Operator, WarpCount::kK>::Type;
|
||||
|
||||
/// Policy used to define MmaPipelined
|
||||
using MmaPolicy = MmaPolicy<MmaTensorOp, MatrixShape<0, 0>,
|
||||
MatrixShape<0, 0>, WarpCount::kK>;
|
||||
};
|
||||
|
||||
} // namespace threadblock
|
||||
} // namespace gemm
|
||||
} // namespace cutlass
|
@@ -1,246 +0,0 @@
|
||||
/*
|
||||
* SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "cutlass_extensions/arch/mma.h"
|
||||
#include "cutlass_extensions/gemm/threadblock/default_dq_mma.h"
|
||||
#include "cutlass_extensions/gemm/threadblock/default_mma_core.h"
|
||||
#include "cutlass_extensions/gemm/threadblock/wint2x_mma_multistage.h"
|
||||
#include "cutlass_extensions/gemm/threadblock/wint2x_params_accessor.h"
|
||||
|
||||
namespace cutlass {
|
||||
namespace gemm {
|
||||
namespace threadblock {
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template <typename ThreadblockShape, typename ElementT, int GroupSize>
|
||||
struct DefaultQuantParamsIterators {
|
||||
private:
|
||||
static constexpr int kAlignment = 128 / sizeof_bits<ElementT>::value;
|
||||
static_assert((ThreadblockShape::kN % kAlignment) == 0, "");
|
||||
|
||||
static constexpr int kRows =
|
||||
(GroupSize == -1) ? 1 : (ThreadblockShape::kK + GroupSize - 1) / GroupSize;
|
||||
static constexpr int kColumns = ThreadblockShape::kN;
|
||||
|
||||
using IteratorThreadMap = transform::PitchLinearStripminedThreadMap<
|
||||
layout::PitchLinearShape<kColumns, kRows>,
|
||||
kColumns / kAlignment, kAlignment>;
|
||||
|
||||
public:
|
||||
using Iterator = cutlass::transform::threadblock::PredicatedTileIterator<
|
||||
MatrixShape<kRows, kColumns>, ElementT, layout::RowMajor, 0,
|
||||
IteratorThreadMap, kAlignment>;
|
||||
using SmemIterator = Iterator;
|
||||
};
|
||||
|
||||
template <typename ThreadblockShape, int GroupSize>
|
||||
struct DefaultQuantParamsIterators<ThreadblockShape, uint4b_t, GroupSize> {
|
||||
private:
|
||||
static constexpr int kAlignment = 32 / sizeof_bits<uint4b_t>::value;
|
||||
static_assert((ThreadblockShape::kN % kAlignment) == 0, "");
|
||||
|
||||
static constexpr int kRows =
|
||||
(GroupSize == -1) ? 1 : (ThreadblockShape::kK + 2 * GroupSize - 1) / (2 * GroupSize);
|
||||
static constexpr int kColumns =
|
||||
(GroupSize == -1) ? ThreadblockShape::kN : ThreadblockShape::kN * 2;
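// Example layout (illustrative values, assuming ThreadblockShape::kK = 64,
// ThreadblockShape::kN = 128, GroupSize = 64): two uint4b_t values share one
// byte, so kRows = (64 + 2 * 64 - 1) / (2 * 64) = 1 and kColumns = 128 * 2 = 256.
static_assert((64 + 2 * 64 - 1) / (2 * 64) == 1 && 128 * 2 == 256, "illustrative arithmetic only");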
|
||||
|
||||
using IteratorThreadMap = transform::PitchLinearStripminedThreadMap<
|
||||
layout::PitchLinearShape<kColumns, kRows>,
|
||||
kColumns / kAlignment, kAlignment>;
|
||||
|
||||
public:
|
||||
using AccessType = cutlass::Array<uint4b_t, kAlignment>;
|
||||
using Iterator = cutlass::transform::threadblock::PredicatedTileAccessIterator<
|
||||
MatrixShape<kRows, kColumns>, uint4b_t, layout::RowMajor,
|
||||
0, IteratorThreadMap, AccessType>;
|
||||
|
||||
using SmemIterator = Iterator;
|
||||
};
|
||||
|
||||
template <
|
||||
/// Element type for A matrix operand
|
||||
typename ElementA_,
|
||||
/// Layout type for A matrix operand
|
||||
typename LayoutA_,
|
||||
/// Access granularity of A matrix in units of elements
|
||||
int kAlignmentA,
|
||||
/// Element type for B matrix operand
|
||||
typename ElementB_,
|
||||
/// Layout type for B matrix operand
|
||||
typename LayoutB_,
|
||||
/// Access granularity of B matrix in units of elements
|
||||
int kAlignmentB,
|
||||
/// Element type for internal accumulation
|
||||
typename ElementAccumulator_,
|
||||
/// Layout type for C and D matrix operands
|
||||
typename LayoutC_,
|
||||
/// Operator class tag
|
||||
typename OperatorClass_,
|
||||
/// Tag indicating architecture to tune for
|
||||
typename ArchTag_,
|
||||
/// Threadblock-level tile size (concept: GemmShape)
|
||||
typename ThreadblockShape_,
|
||||
/// Warp-level tile size (concept: GemmShape)
|
||||
typename WarpShape_,
|
||||
/// Instruction-level tile size (concept: GemmShape)
|
||||
typename InstructionShape_,
|
||||
/// Number of stages used in the pipelined mainloop
|
||||
int Stages,
|
||||
/// Operation performed by GEMM
|
||||
typename Operator_,
|
||||
/// Use zfill or predicate for out-of-bound cp.async
|
||||
SharedMemoryClearOption SharedMemoryClear = SharedMemoryClearOption::kNone>
|
||||
struct DefaultWint2xMma;
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template <
|
||||
/// Type for element A
|
||||
typename ElementA,
|
||||
/// Layout type for A matrix operand
|
||||
typename LayoutA,
|
||||
/// Access granularity of A matrix in units of elements
|
||||
int kAlignmentA,
|
||||
/// Type for element B
|
||||
typename ElementB,
|
||||
/// Layout type for B matrix operand
|
||||
typename LayoutB,
|
||||
/// Access granularity of B matrix in units of elements
|
||||
int kAlignmentB,
|
||||
/// Element type for internal accumulation
|
||||
typename ElementAccumulator,
|
||||
/// Operator class tag
|
||||
typename OperatorClass,
|
||||
/// Tag indicating architecture to tune for
|
||||
typename ArchTag,
|
||||
/// Threadblock-level tile size (concept: GemmShape)
|
||||
typename ThreadblockShape,
|
||||
/// Warp-level tile size (concept: GemmShape)
|
||||
typename WarpShape,
|
||||
/// Instruction-level tile size (concept: GemmShape)
|
||||
typename InstructionShape,
|
||||
/// Stages in GEMM
|
||||
int kStages,
|
||||
/// Operator performed by GEMM
|
||||
typename Operator,
|
||||
/// Use zfill or predicate for out-of-bound cp.async
|
||||
SharedMemoryClearOption SharedMemoryClear>
|
||||
struct DefaultWint2xMma<ElementA, LayoutA, kAlignmentA, ElementB, LayoutB, kAlignmentB, ElementAccumulator,
|
||||
layout::RowMajor, OperatorClass, ArchTag, ThreadblockShape, WarpShape, InstructionShape,
|
||||
kStages, Operator, SharedMemoryClear>
|
||||
{
|
||||
public:
|
||||
static_assert(platform::is_same<ElementA, half_t>::value || platform::is_same<ElementA, bfloat16_t>::value,
|
||||
"Element A must be fp16 or bf16");
|
||||
|
||||
static_assert(platform::is_same<ElementB, uint2b_t>::value,
|
||||
"Element B must be uint2b_t");
|
||||
|
||||
static_assert(platform::is_same<Operator, arch::OpMultiplyAddDequantizeInterleavedBToA>::value,
|
||||
"Mma multistage must dequantize after ldsm");
|
||||
|
||||
using ElementSuperScale = ElementA;
|
||||
using ElementLocalScale = uint4b_t;
|
||||
using ElementCodeScaleZp = float;
|
||||
|
||||
static constexpr int kGroupSize = 64;
|
||||
|
||||
static cutlass::arch::CacheOperation::Kind const CacheOpA = ((sizeof_bits<ElementA>::value * kAlignmentA) == 128)
|
||||
? cutlass::arch::CacheOperation::Global
|
||||
: cutlass::arch::CacheOperation::Always;
|
||||
|
||||
static cutlass::arch::CacheOperation::Kind const CacheOpB = ((sizeof_bits<ElementB>::value * kAlignmentB) == 128)
|
||||
? cutlass::arch::CacheOperation::Global
|
||||
: cutlass::arch::CacheOperation::Always;
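// Illustrative check (assumed values for the example, not taken from the diff):
// a half_t A operand with kAlignmentA = 8 gives 16 * 8 = 128 bits per access,
// so CacheOpA resolves to CacheOperation::Global; a narrower alignment falls
// back to CacheOperation::Always. The same rule selects CacheOpB.
static_assert(16 * 8 == 128, "example: half_t with 8-element alignment is a 128-bit access");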
|
||||
|
||||
// Define the MmaCore components
|
||||
// Mma core does not depend on stages, so pass in at least 3 here so that the mma multistage pieces are created
|
||||
using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore<ThreadblockShape, WarpShape, InstructionShape,
|
||||
ElementA, LayoutA, ElementB, layout::ColumnMajor, ElementAccumulator, layout::RowMajor, OperatorClass,
|
||||
std::max(kStages, 3), Operator, false, CacheOpA, CacheOpB>;
|
||||
|
||||
// Define iterators over tiles from the A operand
|
||||
using ThreadMapA = typename MmaCore::IteratorThreadMapA;
|
||||
using AccessTypeA = cutlass::Array<ElementA, kAlignmentA>;
|
||||
using IteratorA = cutlass::transform::threadblock::PredicatedTileAccessIterator<
|
||||
cutlass::MatrixShape<ThreadblockShape::kM, ThreadblockShape::kK>, ElementA, LayoutA, 1, ThreadMapA,
|
||||
AccessTypeA>;
|
||||
|
||||
private:
|
||||
static constexpr int kColumnsInterleaved = LayoutB::kColumnsInterleaved;
|
||||
static constexpr int kRowsPerTile = LayoutB::kRowsPerTile;
|
||||
static_assert(!(MmaCore::Shape::kN % kColumnsInterleaved), "ThreadblockShape must be divisible by kColumnsInterleaved");
|
||||
static_assert(kRowsPerTile == MmaCore::Shape::kK, "");
|
||||
|
||||
using ThreadMapB = typename MmaCore::IteratorThreadMapB;
|
||||
using WarpArrangement = typename ThreadMapB::Detail::WarpThreadArrangement;
|
||||
static_assert(!(WarpArrangement::kStrided % kColumnsInterleaved), "");
|
||||
|
||||
using IteratorShapeB = MatrixShape<
|
||||
MmaCore::Shape::kK * kColumnsInterleaved, MmaCore::Shape::kN / kColumnsInterleaved>;
|
||||
using InterleavedThreadMapB = transform::PitchLinearWarpRakedThreadMap<
|
||||
layout::PitchLinearShape<IteratorShapeB::kRow, IteratorShapeB::kColumn>,
|
||||
ThreadMapB::kThreads,
|
||||
layout::PitchLinearShape<WarpArrangement::kContiguous * kColumnsInterleaved,
|
||||
WarpArrangement::kStrided / kColumnsInterleaved>,
|
||||
MmaCore::kAccessSizeInBits / sizeof_bits<ElementB>::value>;
|
||||
|
||||
public:
|
||||
// Define iterators over tiles from the B operand
|
||||
using AccessTypeB = cutlass::Array<ElementB, kAlignmentB>;
|
||||
using IteratorB = cutlass::transform::threadblock::PredicatedTileAccessIterator<
|
||||
IteratorShapeB, ElementB, layout::ColumnMajor, 0, InterleavedThreadMapB,
|
||||
AccessTypeB>;
|
||||
|
||||
private:
|
||||
// Define iterators over tiles from extra quant params for B operand
|
||||
using IteratorSuperScale = typename DefaultQuantParamsIterators<
|
||||
ThreadblockShape, ElementSuperScale, -1>::Iterator;
|
||||
using SmemIteratorSuperScale = typename DefaultQuantParamsIterators<
|
||||
ThreadblockShape, ElementSuperScale, -1>::SmemIterator;
|
||||
|
||||
using IteratorLocalScale = typename DefaultQuantParamsIterators<
|
||||
ThreadblockShape, ElementLocalScale, kGroupSize>::Iterator;
|
||||
using SmemIteratorLocalScale = typename DefaultQuantParamsIterators<
|
||||
ThreadblockShape, ElementLocalScale, kGroupSize>::SmemIterator;
|
||||
|
||||
using IteratorCodeScaleZp = typename DefaultQuantParamsIterators<
|
||||
ThreadblockShape, ElementCodeScaleZp, -1>::Iterator;
|
||||
using SmemIteratorCodeScaleZp = typename DefaultQuantParamsIterators<
|
||||
ThreadblockShape, ElementCodeScaleZp, -1>::Iterator;
|
||||
|
||||
public:
|
||||
using QuantParamsAccessor = Wint2ParamsAccessor<
|
||||
ElementA, ThreadblockShape, IteratorSuperScale, SmemIteratorSuperScale,
|
||||
IteratorLocalScale, SmemIteratorLocalScale,
|
||||
IteratorCodeScaleZp, SmemIteratorCodeScaleZp, kStages, kGroupSize>;
|
||||
|
||||
// Define the threadblock-scoped multistage matrix multiply
|
||||
using ThreadblockMma = cutlass::gemm::threadblock::Wint2xMmaMultistage<
|
||||
typename MmaCore::Shape,
|
||||
IteratorA, typename MmaCore::SmemIteratorA, MmaCore::kCacheOpA,
|
||||
IteratorB, typename MmaCore::SmemIteratorB, MmaCore::kCacheOpB,
|
||||
ElementAccumulator, layout::RowMajor, typename MmaCore::MmaPolicy,
|
||||
kStages, QuantParamsAccessor, SharedMemoryClear>;
|
||||
};
|
||||
|
||||
} // namespace threadblock
|
||||
} // namespace gemm
|
||||
} // namespace cutlass
|
@@ -63,8 +63,8 @@ template <
|
||||
typename Policy_,
|
||||
/// Number of stages,
|
||||
int Stages,
|
||||
/// Size of extra quantized params
|
||||
typename QuantParamsShape>
|
||||
/// Used for partial specialization
|
||||
typename Enable = bool>
|
||||
class Wint2xMmaBase {
|
||||
public:
|
||||
///< Size of the Gemm problem - concept: gemm::GemmShape<>
|
||||
@@ -89,18 +89,10 @@ public:
|
||||
GemmShape<Shape::kM / WarpGemm::kM, Shape::kN / WarpGemm::kN,
|
||||
Shape::kK / WarpGemm::kK>;
|
||||
|
||||
/// Number of warp-level GEMM operations
|
||||
/// Number of warp-level GEMM operations
|
||||
static int const kWarpGemmIterations =
|
||||
(WarpGemm::kK / Operator::Policy::MmaShape::kK);
|
||||
|
||||
/// Number of warp-level GEMM operations per load for B
|
||||
static constexpr int kWarpGemmIterationsPerLoadForB =
|
||||
Operator::IteratorB::InstructionShape::kRow / Operator::InstructionShape::kK;
|
||||
static_assert(!(kWarpGemmIterations % kWarpGemmIterationsPerLoadForB), "");
|
||||
|
||||
static constexpr int kWarpLoadIterationsForB =
|
||||
kWarpGemmIterations / kWarpGemmIterationsPerLoadForB;
|
||||
|
||||
/// Number of stages
|
||||
static int const kStages = Stages;
|
||||
|
||||
@@ -112,6 +104,8 @@ public:
|
||||
using TensorRefB =
|
||||
TensorRef<typename Operator::ElementB, typename Operator::LayoutB>;
|
||||
|
||||
// using TensorRefZippedB = TensorRef<uint8_t, typename Operator::LayoutB>;
|
||||
|
||||
static_assert(kWarpGemmIterations > 1,
|
||||
"The pipelined structure requires at least two warp-level "
|
||||
"GEMM operations.");
|
||||
@@ -136,11 +130,20 @@ public:
|
||||
Shape::kK * kStages + Policy::SmemPaddingA::kColumn>;
|
||||
|
||||
/// Shape of the B matrix operand in shared memory
|
||||
using ShapeB = MatrixShape<Shape::kK * kStages + Policy::SmemPaddingB::kRow,
|
||||
using ShapeB = MatrixShape<Shape::kK + Policy::SmemPaddingB::kRow,
|
||||
Shape::kN + Policy::SmemPaddingB::kColumn>;
|
||||
|
||||
/// Shape of all quant params in shared memory
|
||||
using QuantParamsShapeB = QuantParamsShape;
|
||||
// w uint8; local_scale uint8;
|
||||
constexpr static int kZippedRowsPerStages =
|
||||
Shape::kK / 4 + (Shape::kK + 127) / 128;
|
||||
|
||||
// code_scale float; code_zp float; super_scale ElementB
|
||||
constexpr static int kColumnWiseParamsRows = 2 * sizeof(float) +
|
||||
sizeof_bits<typename Operator::ElementB>::value / 8;
|
||||
|
||||
using ZippedShapeB = MatrixShape<kColumnWiseParamsRows + kZippedRowsPerStages * kStages, Shape::kN>;
|
||||
|
||||
using NopaddingShapeB = MatrixShape<Shape::kK, Shape::kN>;
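// Worked sizes (illustrative, assuming Shape::kK = 64 and a 16-bit
// Operator::ElementB): kZippedRowsPerStages = 64 / 4 + (64 + 127) / 128 = 17
// rows of packed weight plus local_scale per stage, and
// kColumnWiseParamsRows = 2 * sizeof(float) + 16 / 8 = 10 rows for the
// channel-wise code_scale, code_zp and super_scale.
static_assert(64 / 4 + (64 + 127) / 128 == 17, "illustrative arithmetic only");
static_assert(2 * 4 + 16 / 8 == 10, "illustrative arithmetic only");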
|
||||
|
||||
public:
|
||||
//
|
||||
@@ -153,8 +156,12 @@ public:
|
||||
/// Buffer for B operand
|
||||
AlignedBuffer<typename Operator::ElementB, ShapeB::kCount> operand_B;
|
||||
|
||||
/// Buffer for extra quant params of B operand
|
||||
AlignedBuffer<uint8_t, QuantParamsShapeB::kCount> operand_quant_params_B;
|
||||
/// Buffer for quantized B operand
|
||||
AlignedBuffer<uint8_t, ZippedShapeB::kCount> operand_zipped_B;
|
||||
|
||||
/// Buffer for unzipped B operand
|
||||
AlignedBuffer<typename Operator::ElementB, NopaddingShapeB::kCount>
|
||||
operand_unzip_B;
|
||||
|
||||
public:
|
||||
//
|
||||
@@ -184,6 +191,14 @@ public:
|
||||
TensorRefB operand_B_ref() {
|
||||
return TensorRefB{operand_B.data(), LayoutB()};
|
||||
}
|
||||
|
||||
CUTLASS_HOST_DEVICE
|
||||
uint8_t *operand_zipped_B_ptr() { return operand_zipped_B.data(); }
|
||||
|
||||
CUTLASS_HOST_DEVICE
|
||||
typename Operator::ElementB *operand_unzip_B_ptr() {
|
||||
return operand_unzip_B.data();
|
||||
}
|
||||
};
|
||||
|
||||
protected:
|
||||
|
@@ -45,8 +45,7 @@
|
||||
|
||||
#include "cutlass_extensions/arch/memory_copy_sm80.h"
|
||||
#include "cutlass_extensions/gemm/threadblock/wint2x_mma_base.h"
|
||||
#include "cutlass_extensions/gemm/threadblock/wint2x_params_accessor.h"
|
||||
#include "cutlass_extensions/gemm/warp/mma_tensorop_wint2x_dequantizer.h"
|
||||
#include "cutlass_extensions/gemm/threadblock/wint2x_tile_dequanter.h"
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
@@ -87,15 +86,15 @@ template <
|
||||
typename Policy_,
|
||||
/// Number of stages,
|
||||
int Stages,
|
||||
/// Accessor for extra quantized params
|
||||
typename QuantParamsAccessor_,
|
||||
/// Use zfill or predicate for out-of-bound cp.async
|
||||
SharedMemoryClearOption SharedMemoryClear = SharedMemoryClearOption::kNone>
|
||||
SharedMemoryClearOption SharedMemoryClear = SharedMemoryClearOption::kNone,
|
||||
/// Used for partial specialization
|
||||
typename Enable = bool>
|
||||
class Wint2xMmaMultistage :
|
||||
public Wint2xMmaBase<Shape_, Policy_, Stages, typename QuantParamsAccessor_::QuantParamsShape> {
|
||||
public Wint2xMmaBase<Shape_, Policy_, Stages> {
|
||||
public:
|
||||
///< Base class
|
||||
using Base = Wint2xMmaBase<Shape_, Policy_, Stages, typename QuantParamsAccessor_::QuantParamsShape>;
|
||||
using Base = Wint2xMmaBase<Shape_, Policy_, Stages>;
|
||||
///< Size of the Gemm problem - concept: gemm::GemmShape<>
|
||||
using Shape = Shape_;
|
||||
///< Iterates over tiles of A operand in global memory
|
||||
@@ -108,11 +107,8 @@ public:
|
||||
using LayoutC = LayoutC_;
|
||||
///< Policy describing tuning details
|
||||
using Policy = Policy_;
|
||||
/// Accessor for extra quantized params
|
||||
using QuantParamsAccessor = QuantParamsAccessor_;
|
||||
using QuantArguments = typename QuantParamsAccessor::Arguments;
|
||||
|
||||
static constexpr int kInterleave = IteratorB::Shape::kRow / Shape::kK;
|
||||
using ZippedShapeB = typename Base::SharedStorage::ZippedShapeB;
|
||||
|
||||
using SmemIteratorA = SmemIteratorA_;
|
||||
using SmemIteratorB = SmemIteratorB_;
|
||||
@@ -133,18 +129,6 @@ public:
|
||||
/// Minimum architecture is Sm80 to support cp.async
|
||||
using ArchTag = arch::Sm80;
|
||||
|
||||
//using LayoutScale = typename QuantParamsAccessor::IteratorSuperScale::Layout;
|
||||
using LayoutScale = layout::RowMajor;
|
||||
using WarpTransformedFragmentB = typename Operator::TransformedFragmentB;
|
||||
using WarpDequantizer =
|
||||
warp::MmaTensorOpWin2xDequantizer<Operator,
|
||||
typename Base::WarpGemm,
|
||||
Operand::kB,
|
||||
typename WarpTransformedFragmentB::Element,
|
||||
LayoutScale,
|
||||
QuantParamsAccessor::kGroupSize>;
|
||||
static_assert(sizeof(WarpDequantizer) > 0, "WarpDequantizer template instantiation failed");
|
||||
|
||||
/// Complex transform on A operand
|
||||
static ComplexTransform const kTransformA = Operator::kTransformA;
|
||||
|
||||
@@ -190,37 +174,18 @@ public:
|
||||
using WarpTransformedFragmentA = typename Operator::TransformedFragmentA;
|
||||
using WarpTransformedFragmentB = typename Operator::TransformedFragmentB;
|
||||
|
||||
using FragmentSuperScale = typename WarpDequantizer::FragmentSuperScale;
|
||||
using FragmentCodeScaleZp = typename WarpDequantizer::FragmentCodeScaleZp;
|
||||
using FragmentLocalScale = typename WarpDequantizer::FragmentLocalScale;
|
||||
|
||||
/// Temporary accumulator to facilitate staged-accumulation
|
||||
FragmentC tmp_accum_;
|
||||
|
||||
/// Pair of A fragments used to overlap shared memory loads and math instructions
|
||||
WarpTransformedFragmentA warp_frag_A_[2];
|
||||
WarpLoadedFragmentA warp_loaded_frag_A_[2];
|
||||
WarpTransformedFragmentA warp_transformed_frag_A_[2];
|
||||
|
||||
/// Pair of B fragments used to overlap shared memory loads and math instructions
|
||||
WarpLoadedFragmentB warp_loaded_frag_B_;
|
||||
WarpTransformedFragmentB warp_frag_B_[2];
|
||||
|
||||
/// channel-wise quant params
|
||||
FragmentCodeScaleZp warp_frag_code_scale_;
|
||||
FragmentCodeScaleZp warp_frag_code_zp_;
|
||||
FragmentSuperScale warp_frag_super_scale_;
|
||||
|
||||
/// group-wise quant params
|
||||
FragmentLocalScale warp_frag_local_scale_;
|
||||
WarpLoadedFragmentB warp_loaded_frag_B_[2];
|
||||
WarpTransformedFragmentB warp_transformed_frag_B_[2];
|
||||
};
|
||||
|
||||
using ElementA = typename IteratorA::Element;
|
||||
using ElementB = typename IteratorB::Element;
|
||||
using LayoutDetailsForB = kernel::LayoutDetailsB<ElementA, ElementB, ArchTag>;
|
||||
|
||||
static constexpr bool IsTileInterleaveLayout =
|
||||
layout::IsColumnMajorTileInterleave<typename LayoutDetailsForB::Layout>::value;
|
||||
static_assert(!IsTileInterleaveLayout || (IsTileInterleaveLayout && (Shape::kK == LayoutDetailsForB::ThreadblockK)),
|
||||
"Layout K must match threadblockK");
|
||||
|
||||
private:
|
||||
|
||||
@@ -237,18 +202,17 @@ public:
|
||||
/// Iterator to write threadblock-scoped tile of B operand to shared memory
|
||||
SmemIteratorB smem_iterator_B_;
|
||||
|
||||
/// Accessor for extra quant params for B
|
||||
QuantParamsAccessor quant_params_accessor_B_;
|
||||
|
||||
// Wint2 unzip operator
|
||||
WarpDequantizer warp_dequantizer_;
|
||||
|
||||
/// Shared memory write stage index
|
||||
int smem_write_stage_idx_;
|
||||
|
||||
/// Shared memory read stage index
|
||||
int smem_read_stage_idx_;
|
||||
|
||||
uint8_t* column_wise_smem_ptr_B_;
|
||||
|
||||
uint8_t* smem_zipped_ptr_B_;
|
||||
int smem_zipped_bytes_per_stage_B_;
|
||||
|
||||
public:
|
||||
|
||||
/// Construct from tensor references
|
||||
@@ -262,15 +226,10 @@ public:
|
||||
int warp_idx,
|
||||
///< ID of each thread within a warp
|
||||
int lane_idx
|
||||
) : Base(shared_storage, thread_idx, warp_idx, lane_idx),
|
||||
):
|
||||
Base(shared_storage, thread_idx, warp_idx, lane_idx),
|
||||
smem_iterator_A_(shared_storage.operand_A_ref(), thread_idx),
|
||||
smem_iterator_B_(shared_storage.operand_B_ref(), thread_idx),
|
||||
quant_params_accessor_B_(shared_storage.operand_quant_params_B.data(), thread_idx, warp_idx, lane_idx),
|
||||
warp_dequantizer_(quant_params_accessor_B_.super_scale_ref(),
|
||||
quant_params_accessor_B_.local_scale_ref(),
|
||||
quant_params_accessor_B_.code_scale_ref(),
|
||||
quant_params_accessor_B_.code_zp_ref(),
|
||||
(warp_idx % (Base::WarpCount::kM * Base::WarpCount::kN)) / Base::WarpCount::kM, lane_idx),
|
||||
smem_write_stage_idx_(0),
|
||||
smem_read_stage_idx_(0)
|
||||
{
|
||||
@@ -291,6 +250,11 @@ public:
|
||||
{warp_idx_m, Base::kWarpGemmIterations * warp_idx_k});
|
||||
this->warp_tile_iterator_B_.add_tile_offset(
|
||||
{Base::kWarpGemmIterations * warp_idx_k, warp_idx_n});
|
||||
|
||||
column_wise_smem_ptr_B_ = shared_storage.operand_zipped_B_ptr();
|
||||
|
||||
smem_zipped_ptr_B_ = column_wise_smem_ptr_B_ + Base::SharedStorage::kColumnWiseParamsRows * ZippedShapeB::kColumn;
|
||||
smem_zipped_bytes_per_stage_B_ = Base::SharedStorage::kZippedRowsPerStages * ZippedShapeB::kColumn;
|
||||
}
|
||||
|
||||
/// Advance shared memory read-iterators to the next stage
|
||||
@@ -302,22 +266,28 @@ public:
|
||||
if (smem_read_stage_idx_ == Base::kStages) {
|
||||
// Wrap back around to the 'start' of the circular buffer in shared memory
|
||||
this->warp_tile_iterator_A_.add_tile_offset({0, -Base::kStages * Policy::kPartitionsK * Base::kWarpGemmIterations});
|
||||
this->warp_tile_iterator_B_.add_tile_offset({-Base::kStages * Policy::kPartitionsK * Base::kWarpLoadIterationsForB, 0});
|
||||
// this->warp_tile_iterator_B_.add_tile_offset({-Base::kStages * Policy::kPartitionsK * Base::kWarpGemmIterations, 0});
|
||||
smem_read_stage_idx_ = 0;
|
||||
}
|
||||
this->warp_tile_iterator_B_.add_tile_offset({-Policy::kPartitionsK * Base::kWarpGemmIterations, 0});
|
||||
}
|
||||
|
||||
/// Advance global memory read-iterators and shared memory write-iterators to the stage
|
||||
template <typename TileDequanterB>
|
||||
CUTLASS_DEVICE
|
||||
void advance_smem_write_stage(IteratorA &iterator_A, IteratorB &iterator_B)
|
||||
void advance_smem_write_stage(
|
||||
IteratorA &iterator_A,
|
||||
IteratorB &iterator_B,
|
||||
TileDequanterB &tile_dequanter_B)
|
||||
{
|
||||
// Advance global iterators
|
||||
iterator_A.add_tile_offset({0, 1});
|
||||
iterator_B.add_tile_offset({1, 0});
|
||||
//iterator_B.add_tile_offset({1, 0});
|
||||
tile_dequanter_B.AddTileOffset({1, 0});
|
||||
|
||||
// Advance shared iterators
|
||||
smem_iterator_A_.add_tile_offset({0, 1});
|
||||
smem_iterator_B_.add_tile_offset({1, 0});
|
||||
//smem_iterator_B_.add_tile_offset({1, 0});
|
||||
|
||||
// Increment shared memory write stage index
|
||||
++smem_write_stage_idx_;
|
||||
@@ -325,7 +295,7 @@ public:
|
||||
if (smem_write_stage_idx_ == Base::kStages) {
|
||||
// Wrap back around to the 'start' of the circular buffer in shared memory
|
||||
smem_iterator_A_.add_tile_offset({0, -Base::kStages});
|
||||
smem_iterator_B_.add_tile_offset({-Base::kStages, 0});
|
||||
//smem_iterator_B_.add_tile_offset({-Base::kStages, 0});
|
||||
smem_write_stage_idx_ = 0;
|
||||
}
|
||||
}
|
||||
@@ -368,14 +338,9 @@ public:
|
||||
}
|
||||
}
|
||||
|
||||
template <bool GlobalToSharedB>
|
||||
CUTLASS_DEVICE
|
||||
void copy_tiles_and_advance_B(IteratorB &iterator_B, int group_start_B = 0) {
|
||||
if constexpr (SharedMemoryClear == SharedMemoryClearOption::kZfill) {
|
||||
if (threadIdx.x >= IteratorB::ThreadMap::kThreads) {
|
||||
return;
|
||||
}
|
||||
}
|
||||
|
||||
iterator_B.set_iteration_index(group_start_B *
|
||||
IteratorB::kAccessesPerVector);
|
||||
this->smem_iterator_B_.set_iteration_index(group_start_B);
|
||||
@@ -395,14 +360,13 @@ public:
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int v = 0; v < IteratorB::kAccessesPerVector; ++v) {
|
||||
auto gmem_ptr = iterator_B.get();
|
||||
bool is_valid = (threadIdx.x < IteratorB::ThreadMap::kThreads) ? iterator_B.valid() : false;
|
||||
|
||||
if (SharedMemoryClear == SharedMemoryClearOption::kZfill) {
|
||||
cutlass::arch::cp_async_zfill<kSrcBytes, kCacheOpB>(
|
||||
dst_ptr + v, gmem_ptr, is_valid);
|
||||
cutlass::arch::copy_zfill<kSrcBytes, kCacheOpB, GlobalToSharedB>(
|
||||
dst_ptr + v, gmem_ptr, iterator_B.valid());
|
||||
} else {
|
||||
cutlass::arch::cp_async<kSrcBytes, kCacheOpB>(
|
||||
dst_ptr + v, gmem_ptr, is_valid);
|
||||
cutlass::arch::copy<kSrcBytes, kCacheOpB, GlobalToSharedB>(
|
||||
dst_ptr + v, gmem_ptr, iterator_B.valid());
|
||||
}
|
||||
|
||||
++iterator_B;
|
||||
@@ -411,6 +375,7 @@ public:
|
||||
++this->smem_iterator_B_;
|
||||
}
|
||||
}
|
||||
__syncthreads();
|
||||
}
|
||||
|
||||
CUTLASS_DEVICE
|
||||
@@ -434,6 +399,8 @@ public:
|
||||
IteratorA::ThreadMap::kElementsPerAccess /
|
||||
IteratorA::kAccessesPerVector / 8;
|
||||
|
||||
int src_bytes = (iterator_A.valid() ? kSrcBytes : 0);
|
||||
|
||||
cutlass::arch::cp_async_zfill<kSrcBytes, kCacheOpA>(
|
||||
dst_ptr + v, iterator_A.get(), iterator_A.valid());
|
||||
|
||||
@@ -444,12 +411,9 @@ public:
|
||||
}
|
||||
}
|
||||
|
||||
template <bool GlobalToSharedB, bool InitStage>
|
||||
CUTLASS_DEVICE
|
||||
void copy_tiles_and_advance_per_stage_B(IteratorB &iterator_B) {
|
||||
if (threadIdx.x >= IteratorB::ThreadMap::kThreads) {
|
||||
return;
|
||||
}
|
||||
|
||||
iterator_B.set_iteration_index(0);
|
||||
this->smem_iterator_B_.set_iteration_index(0);
|
||||
|
||||
@@ -469,23 +433,35 @@ public:
|
||||
IteratorB::ThreadMap::kElementsPerAccess /
|
||||
IteratorB::kAccessesPerVector / 8;
|
||||
|
||||
cutlass::arch::cp_async_zfill<kSrcBytes, kCacheOpB>(
|
||||
dst_ptr + v, iterator_B.get(), iterator_B.valid());
|
||||
if (InitStage) {
|
||||
cutlass::arch::copy_zfill<kSrcBytes, kCacheOpB, GlobalToSharedB>(
|
||||
dst_ptr + v, iterator_B.get(), iterator_B.valid());
|
||||
} else {
|
||||
if (SharedMemoryClear == SharedMemoryClearOption::kZfill) {
|
||||
cutlass::arch::copy_zfill<kSrcBytes, kCacheOpB, GlobalToSharedB>(
|
||||
dst_ptr + v, gmem_ptr, iterator_B.valid());
|
||||
} else {
|
||||
cutlass::arch::copy<kSrcBytes, kCacheOpB, GlobalToSharedB>(
|
||||
dst_ptr + v, gmem_ptr, iterator_B.valid());
|
||||
}
|
||||
}
|
||||
|
||||
++iterator_B;
|
||||
}
|
||||
|
||||
++this->smem_iterator_B_;
|
||||
}
|
||||
__syncthreads();
|
||||
}
|
||||
|
||||
/// GEMM prologue. Bootstrap the global->shared memory pipeline by fetching
|
||||
/// the global fragments needed by the first kStages-1 threadblock mainloop iterations
|
||||
template <typename TileDequanterB>
|
||||
CUTLASS_DEVICE
|
||||
void prologue(
|
||||
IteratorA &iterator_A, ///< [in|out] iterator over A operand in global memory
|
||||
IteratorB &iterator_B, ///< [in|out] iterator over B operand in global memory
|
||||
QuantArguments &mma_quant_args, ///< iterators for extra quant params for B
|
||||
TileDequanterB &tile_dequanter_B,
|
||||
int &gemm_k_iterations) ///< [in|out] number of threadblock mainloop iterations remaining
|
||||
{
|
||||
// Issue several complete stages
|
||||
@@ -500,18 +476,11 @@ public:
|
||||
copy_tiles_and_advance_per_stage_A(iterator_A);
|
||||
|
||||
// Async copy zipped B to shared memory.
|
||||
copy_tiles_and_advance_per_stage_B(iterator_B);
|
||||
|
||||
// Async copy the other quantized params (local_scale, code_scale, code_zp, super_scale) to shared memory.
|
||||
if (stage == 0) {
|
||||
quant_params_accessor_B_.copy_tiles_and_advance_per_stage<true>(mma_quant_args, stage);
|
||||
} else {
|
||||
quant_params_accessor_B_.copy_tiles_and_advance_per_stage<false>(mma_quant_args, stage);
|
||||
}
|
||||
tile_dequanter_B.Load(smem_zipped_ptr_B_ + (stage % Base::kStages) * smem_zipped_bytes_per_stage_B_,
|
||||
column_wise_smem_ptr_B_, stage);
|
||||
|
||||
// Move to the next write stage
|
||||
advance_smem_write_stage(iterator_A, iterator_B);
|
||||
quant_params_accessor_B_.advance_smem_write_stage(mma_quant_args);
|
||||
advance_smem_write_stage(iterator_A, iterator_B, tile_dequanter_B);
|
||||
|
||||
// Defines the boundary of a stage of cp.async.
|
||||
cutlass::arch::cp_async_fence();
|
||||
@@ -541,10 +510,6 @@ public:
|
||||
++last_smem_iterator_A;
|
||||
}
|
||||
|
||||
if (threadIdx.x >= IteratorB::ThreadMap::kThreads) {
|
||||
return;
|
||||
}
|
||||
|
||||
/// Iterator to write threadblock-scoped tile of B operand to shared memory
|
||||
SmemIteratorB last_smem_iterator_B(this->smem_iterator_B_);
|
||||
typename IteratorB::AccessType zero_B;
|
||||
@@ -577,57 +542,57 @@ public:
|
||||
}
|
||||
|
||||
/// Perform a threadblock mainloop iteration of matrix multiply-accumulate
|
||||
template <typename TileDequanterB>
|
||||
CUTLASS_DEVICE
|
||||
void mac_loop_iter(
|
||||
PipeState &pipe_state, ///< [in|out] loop-carried pipeline state
|
||||
FragmentC &accum, ///< [in|out] destination accumulator tile
|
||||
IteratorA &iterator_A, ///< [in|out] iterator over A operand in global memory
|
||||
IteratorB &iterator_B, ///< [in|out] iterator over B operand in global memory
|
||||
QuantArguments &mma_quant_args, ///< iterators for extra quant params for B
|
||||
int &gemm_k_iterations, ///< [in|out] number of threadblock mainloop iterations remaining
|
||||
TileDequanterB &tile_dequanter_B, ///< [in|out] tile dequantizer for B operand
|
||||
int &gemm_k_iterations, ///< [in|out] number of threadblock mainloop iterations remaining
|
||||
int stage)
|
||||
{
|
||||
const int mma_stage = stage - Base::kStages + 1;
|
||||
|
||||
// Unroll the warp-level MMA tiles of a threadblock's mainloop iteration
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int warp_mma_k = 0; warp_mma_k < Base::kWarpGemmIterations; ++warp_mma_k) {
|
||||
|
||||
int warp_k_compute_offset_B = warp_mma_k % Base::kWarpGemmIterationsPerLoadForB;
|
||||
|
||||
if (warp_k_compute_offset_B == Base::kWarpGemmIterationsPerLoadForB - 1) {
|
||||
// Load the next warp-tile's B fragment from shared memory
|
||||
this->warp_tile_iterator_B_.set_kgroup_index(((warp_mma_k + 1) % Base::kWarpGemmIterations) / Base::kWarpLoadIterationsForB);
|
||||
this->warp_tile_iterator_B_.load(pipe_state.warp_loaded_frag_B_);
|
||||
++this->warp_tile_iterator_B_;
|
||||
}
|
||||
|
||||
// load next-tile of group-wise local_scale from shared memory
|
||||
if (warp_mma_k == Base::kWarpGemmIterations - 1) {
|
||||
warp_dequantizer_.load(pipe_state.warp_frag_local_scale_);
|
||||
}
|
||||
// CUTLASS_TRACE_DEVICE(" [MMa] stage=%d, warp_mma_k=%d", stage, warp_mma_k);
|
||||
|
||||
// Load the next warp-tile's A fragment from shared memory
|
||||
this->warp_tile_iterator_A_.set_kgroup_index((warp_mma_k + 1) % Base::kWarpGemmIterations);
|
||||
this->warp_tile_iterator_A_.load(pipe_state.warp_frag_A_[(warp_mma_k + 1) % 2]);
|
||||
this->warp_tile_iterator_A_.load(pipe_state.warp_loaded_frag_A_[(warp_mma_k + 1) % 2]);
|
||||
++this->warp_tile_iterator_A_;
|
||||
|
||||
// dequantizes next warp-tile
|
||||
warp_dequantizer_.dequantize(pipe_state.warp_frag_local_scale_,
|
||||
pipe_state.warp_frag_code_scale_,
|
||||
pipe_state.warp_frag_code_zp_,
|
||||
pipe_state.warp_frag_super_scale_,
|
||||
pipe_state.warp_loaded_frag_B_,
|
||||
pipe_state.warp_frag_B_[(warp_mma_k + 1) % 2],
|
||||
((warp_mma_k == Base::kWarpGemmIterations - 1) ? (mma_stage + 1) : mma_stage) * Shape::kK,
|
||||
(warp_mma_k + 1) % Base::kWarpGemmIterationsPerLoadForB);
|
||||
if (warp_mma_k + 1 == Base::kWarpGemmIterations) {
|
||||
// Unpack and dequant the first stage of B.
|
||||
int unpack_stage = stage - Base::kStages + 2;
|
||||
tile_dequanter_B.UnpackAndDequant(smem_zipped_ptr_B_ + (unpack_stage % Base::kStages) * smem_zipped_bytes_per_stage_B_,
|
||||
column_wise_smem_ptr_B_, unpack_stage);
|
||||
|
||||
// Copy dequantized data to shared memory used by mma core.
|
||||
copy_tiles_and_advance_per_stage_B<false, false>(iterator_B);
|
||||
}
|
||||
|
||||
// Load the next warp-tile's B fragment from shared memory
|
||||
this->warp_tile_iterator_B_.set_kgroup_index((warp_mma_k + 1) % Base::kWarpGemmIterations);
|
||||
this->warp_tile_iterator_B_.load(pipe_state.warp_loaded_frag_B_[(warp_mma_k + 1) % 2]);
|
||||
++this->warp_tile_iterator_B_;
|
||||
|
||||
// Except for the first warp-tile, all warp-tiles convert their incoming shared memory fragments as necessary
|
||||
if (warp_mma_k > 0) {
|
||||
warp_mma_.transform(
|
||||
pipe_state.warp_transformed_frag_A_[warp_mma_k % 2],
|
||||
pipe_state.warp_transformed_frag_B_[warp_mma_k % 2],
|
||||
pipe_state.warp_loaded_frag_A_[warp_mma_k % 2],
|
||||
pipe_state.warp_loaded_frag_B_[warp_mma_k % 2]);
|
||||
}
|
||||
|
||||
// Execute the current warp-tile of MMA operations
|
||||
if constexpr (Detail::kStagedAccumulation) {
|
||||
if (Detail::kStagedAccumulation) {
|
||||
warp_mma_(
|
||||
pipe_state.tmp_accum_,
|
||||
pipe_state.warp_frag_A_[warp_mma_k % 2],
|
||||
pipe_state.warp_frag_B_[warp_mma_k % 2],
|
||||
pipe_state.warp_transformed_frag_A_[warp_mma_k % 2],
|
||||
pipe_state.warp_transformed_frag_B_[warp_mma_k % 2],
|
||||
pipe_state.tmp_accum_
|
||||
);
|
||||
|
||||
@@ -639,22 +604,22 @@ public:
|
||||
} else {
|
||||
warp_mma_(
|
||||
accum,
|
||||
pipe_state.warp_frag_A_[warp_mma_k % 2],
|
||||
pipe_state.warp_frag_B_[warp_mma_k % 2],
|
||||
accum);
|
||||
pipe_state.warp_transformed_frag_A_[warp_mma_k % 2],
|
||||
pipe_state.warp_transformed_frag_B_[warp_mma_k % 2],
|
||||
accum
|
||||
);
|
||||
}
|
||||
|
||||
// Except for the last warp-tile, all warp-tiles issue their share of
|
||||
// global->shared fragment copies
|
||||
if (warp_mma_k < Base::kWarpGemmIterations - 1) {
|
||||
int group_start_iteration_A = warp_mma_k * Detail::kAccessesPerGroupA;
|
||||
int group_start_iteration_B = warp_mma_k * Detail::kAccessesPerGroupB;
|
||||
|
||||
copy_tiles_and_advance_A(iterator_A, group_start_iteration_A);
|
||||
copy_tiles_and_advance_B(iterator_B, group_start_iteration_B);
|
||||
|
||||
if (warp_mma_k == 0) {
|
||||
quant_params_accessor_B_.copy_tiles_and_advance_per_stage<false>(mma_quant_args, stage);
|
||||
tile_dequanter_B.Load(smem_zipped_ptr_B_ + (stage % Base::kStages) * smem_zipped_bytes_per_stage_B_,
|
||||
column_wise_smem_ptr_B_, stage);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -663,15 +628,9 @@ public:
|
||||
// - moves to the next global fetch stage
|
||||
if (warp_mma_k + 2 == Base::kWarpGemmIterations) {
|
||||
// Performs the last warp-tile's share of global->shared fragment copies
|
||||
if constexpr (Detail::AsyncCopyIterationsPerStageA >= Base::kWarpGemmIterations) {
|
||||
int group_start_iteration_A = (warp_mma_k + 1) * Detail::kAccessesPerGroupA;
|
||||
copy_tiles_and_advance_A(iterator_A, group_start_iteration_A);
|
||||
}
|
||||
int group_start_iteration_A = (warp_mma_k + 1) * Detail::kAccessesPerGroupA;
|
||||
|
||||
if constexpr (Detail::AsyncCopyIterationsPerStageB >= Base::kWarpGemmIterations) {
|
||||
int group_start_iteration_B = (warp_mma_k + 1) * Detail::kAccessesPerGroupB;
|
||||
copy_tiles_and_advance_B(iterator_B, group_start_iteration_B);
|
||||
}
|
||||
copy_tiles_and_advance_A(iterator_A, group_start_iteration_A);
|
||||
|
||||
// Inserts a memory fence between stages of cp.async instructions.
|
||||
cutlass::arch::cp_async_fence();
|
||||
@@ -680,66 +639,69 @@ public:
|
||||
gmem_wait();
|
||||
|
||||
// Move to the next global fetch stage
|
||||
advance_smem_write_stage(iterator_A, iterator_B);
|
||||
quant_params_accessor_B_.advance_smem_write_stage(mma_quant_args);
|
||||
|
||||
advance_smem_write_stage(iterator_A, iterator_B, tile_dequanter_B);
|
||||
advance_smem_read_stage();
|
||||
int byte_offset = quant_params_accessor_B_.advance_smem_read_stage();
|
||||
warp_dequantizer_.add_pointer_offset(byte_offset);
|
||||
|
||||
// Disable global fetching when done with global fetch iterations
|
||||
--gemm_k_iterations;
|
||||
iterator_A.clear_mask(gemm_k_iterations == 0);
|
||||
iterator_B.clear_mask(gemm_k_iterations == 0);
|
||||
quant_params_accessor_B_.clear_mask(mma_quant_args, gemm_k_iterations == 0);
|
||||
iterator_B.clear_mask(gemm_k_iterations == (-Base::kStages + 1));
|
||||
}
|
||||
|
||||
// The last warp-tile also converts the shared memory fragments used by
|
||||
// the first warp-tile of the next iteration, if necessary (so we can
|
||||
// immediately start issuing MMA instructions at the top of the loop)
|
||||
if (warp_mma_k + 1 == Base::kWarpGemmIterations) {
|
||||
warp_mma_.transform(
|
||||
pipe_state.warp_transformed_frag_A_[(warp_mma_k + 1) % 2],
|
||||
pipe_state.warp_transformed_frag_B_[(warp_mma_k + 1) % 2],
|
||||
pipe_state.warp_loaded_frag_A_[(warp_mma_k + 1) % 2],
|
||||
pipe_state.warp_loaded_frag_B_[(warp_mma_k + 1) % 2]);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Perform the specified number of threadblock mainloop iterations of matrix
|
||||
/// multiply-accumulate. Assumes prologue has been initiated.
|
||||
template <typename TileDequanterB>
|
||||
CUTLASS_DEVICE
|
||||
void gemm_iters(
|
||||
int gemm_k_iterations, ///< number of threadblock mainloop iterations
|
||||
FragmentC &accum, ///< [in|out] accumulator tile
|
||||
IteratorA &iterator_A, ///< [in|out] iterator over A operand in global memory
|
||||
IteratorB &iterator_B, ///< [in|out] iterator over B operand in global memory
|
||||
QuantArguments &mma_quant_args)
|
||||
IteratorB &iterator_B,
|
||||
TileDequanterB &tile_dequanter_B) ///< [in|out] iterator over B operand in global memory
|
||||
{
|
||||
PipeState pipe_state;
|
||||
|
||||
// Unpack and dequant the first stage of B.
|
||||
tile_dequanter_B.UnpackAndDequant(smem_zipped_ptr_B_, column_wise_smem_ptr_B_, 0);
|
||||
|
||||
// Disable global fetching if done with global fetch iterations
|
||||
iterator_A.clear_mask(gemm_k_iterations == 0);
|
||||
iterator_B.clear_mask(gemm_k_iterations == 0);
|
||||
quant_params_accessor_B_.clear_mask(mma_quant_args, gemm_k_iterations == 0);
|
||||
|
||||
// Load first warp-tile's B fragment from shared memory
|
||||
this->warp_tile_iterator_B_.set_kgroup_index(0);
|
||||
this->warp_tile_iterator_B_.load(pipe_state.warp_loaded_frag_B_);
|
||||
++this->warp_tile_iterator_B_;
|
||||
|
||||
warp_dequantizer_.load(pipe_state.warp_frag_code_scale_,
|
||||
pipe_state.warp_frag_code_zp_,
|
||||
pipe_state.warp_frag_super_scale_);
|
||||
|
||||
warp_dequantizer_.load(pipe_state.warp_frag_local_scale_);
|
||||
iterator_B.clear_mask(gemm_k_iterations == (-Base::kStages + 1));
|
||||
|
||||
// Load first warp-tile's A fragment from shared memory
|
||||
this->warp_tile_iterator_A_.set_kgroup_index(0);
|
||||
this->warp_tile_iterator_A_.load(pipe_state.warp_frag_A_[0]);
|
||||
this->warp_tile_iterator_A_.load(pipe_state.warp_loaded_frag_A_[0]);
|
||||
++this->warp_tile_iterator_A_;
|
||||
|
||||
// Dequantize B into registers
|
||||
warp_dequantizer_.dequantize(pipe_state.warp_frag_local_scale_,
|
||||
pipe_state.warp_frag_code_scale_,
|
||||
pipe_state.warp_frag_code_zp_,
|
||||
pipe_state.warp_frag_super_scale_,
|
||||
pipe_state.warp_loaded_frag_B_,
|
||||
pipe_state.warp_frag_B_[0],
|
||||
0,
|
||||
0);
|
||||
// Copy dequantized data to shared memory used by mma core.
|
||||
copy_tiles_and_advance_per_stage_B<false, true>(iterator_B);
|
||||
|
||||
if constexpr (Detail::kStagedAccumulation) {
|
||||
// Load first warp-tile's B fragment from shared memory
|
||||
this->warp_tile_iterator_B_.set_kgroup_index(0);
|
||||
this->warp_tile_iterator_B_.load(pipe_state.warp_loaded_frag_B_[0]);
|
||||
++this->warp_tile_iterator_B_;
|
||||
|
||||
// Transform, if necessary, the first warp-tile's shared memory fragments
|
||||
warp_mma_.transform(
|
||||
pipe_state.warp_transformed_frag_A_[0],
|
||||
pipe_state.warp_transformed_frag_B_[0],
|
||||
pipe_state.warp_loaded_frag_A_[0],
|
||||
pipe_state.warp_loaded_frag_B_[0]);
|
||||
|
||||
if (Detail::kStagedAccumulation) {
|
||||
pipe_state.tmp_accum_.clear();
|
||||
}
|
||||
|
||||
@@ -753,13 +715,13 @@ public:
|
||||
accum,
|
||||
iterator_A,
|
||||
iterator_B,
|
||||
mma_quant_args,
|
||||
tile_dequanter_B,
|
||||
gemm_k_iterations,
|
||||
stage);
|
||||
stage += 1;
|
||||
}
|
||||
|
||||
if constexpr (Detail::kStagedAccumulation) {
|
||||
if (Detail::kStagedAccumulation) {
|
||||
plus<FragmentC> plus_accum;
|
||||
accum = plus_accum(accum, pipe_state.tmp_accum_);
|
||||
}
|
||||
@@ -799,12 +761,14 @@ public:
|
||||
else
|
||||
{
|
||||
this->warp_tile_iterator_A_.add_tile_offset({0, ((Base::kStages - 2) * kStageIters)});
|
||||
this->warp_tile_iterator_B_.add_tile_offset({((Base::kStages - 2) * kStageIters), 0});
|
||||
//this->warp_tile_iterator_B_.add_tile_offset({((Base::kStages - 2) * kStageIters), 0});
|
||||
this->warp_tile_iterator_B_.add_tile_offset({(-2 * kStageIters), 0});
|
||||
}
|
||||
smem_read_stage_idx_ = smem_write_stage_idx_;
|
||||
}
|
||||
|
||||
/// Perform a threadblock-scoped matrix multiply-accumulate, pre-load B to shared memory.
|
||||
template <typename TileDequanterB>
|
||||
CUTLASS_DEVICE
|
||||
void operator()(
|
||||
///< problem size of GEMM
|
||||
@@ -815,13 +779,13 @@ public:
|
||||
IteratorA iterator_A,
|
||||
///< iterator over B operand in global memory
|
||||
IteratorB iterator_B,
|
||||
///< iterators for extra quant params for B
|
||||
QuantArguments mma_quant_args,
|
||||
///< pre-load and dequantize B to shared memory
|
||||
TileDequanterB tile_dequanter_B,
|
||||
///< initial value of accumulator
|
||||
FragmentC const &src_accum) {
|
||||
|
||||
// Prologue (start fetching iterations of global fragments into shared memory)
|
||||
prologue(iterator_A, iterator_B, mma_quant_args, gemm_k_iterations);
|
||||
prologue(iterator_A, iterator_B, tile_dequanter_B, gemm_k_iterations);
|
||||
|
||||
// Wait until we have at least one completed global fetch stage
|
||||
gmem_wait();
|
||||
@@ -830,7 +794,7 @@ public:
|
||||
accum = src_accum;
|
||||
|
||||
// Perform the MAC-iterations
|
||||
gemm_iters(gemm_k_iterations, accum, iterator_A, iterator_B, mma_quant_args);
|
||||
gemm_iters(gemm_k_iterations, accum, iterator_A, iterator_B, tile_dequanter_B);
|
||||
}
|
||||
};
|
||||
|
||||
|
@@ -1,315 +0,0 @@
|
||||
// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "cutlass/arch/memory_sm80.h"
|
||||
#include "cutlass/cutlass.h"
|
||||
#include "cutlass/gemm/gemm.h"
|
||||
#include "cutlass/matrix_shape.h"
|
||||
#include "cutlass/trace.h"
|
||||
|
||||
namespace cutlass {
|
||||
namespace gemm {
|
||||
namespace threadblock {
|
||||
|
||||
template <
|
||||
/// Original data type
|
||||
typename T,
|
||||
/// Size of the Gemm problem - concept: gemm::GemmShape<>
|
||||
typename Shape_,
|
||||
/// Iterators over super scales in global memory
|
||||
typename IteratorSuperScale_,
|
||||
/// Iterators over super scales in shared memory
|
||||
typename SmemIteratorSuperScale_,
|
||||
/// Iterators over local scales in global memory
|
||||
typename IteratorLocalScale_,
|
||||
/// Iterators over local scales in shared memory
|
||||
typename SmemIteratorLocalScale_,
|
||||
/// Iterators over code scales and zps in global memory
|
||||
typename IteratorCodeScaleZp_,
|
||||
/// Iterators over code scales and zps in shared memory
|
||||
typename SmemIteratorCodeScaleZp_,
|
||||
/// Number of stages,
|
||||
int Stages_,
|
||||
/// Group size for quantization
|
||||
int GroupSize_>
|
||||
class Wint2ParamsAccessor {
|
||||
public:
|
||||
static_assert(platform::is_same<T, half_t>::value || platform::is_same<T, bfloat16_t>::value,
|
||||
"T must be fp16 or bf16");
|
||||
|
||||
using ElementType = T;
|
||||
using Shape = Shape_;
|
||||
|
||||
using IteratorSuperScale = IteratorSuperScale_;
|
||||
using SmemIteratorSuperScale = SmemIteratorSuperScale_;
|
||||
|
||||
using IteratorLocalScale = IteratorLocalScale_;
|
||||
using SmemIteratorLocalScale = SmemIteratorLocalScale_;
|
||||
|
||||
using IteratorCodeScaleZp = IteratorCodeScaleZp_;
|
||||
using SmemIteratorCodeScaleZp = SmemIteratorCodeScaleZp_;
|
||||
|
||||
constexpr static int kStages = Stages_;
|
||||
constexpr static int kGroupSize = GroupSize_;
|
||||
|
||||
using ElementSuperScale = typename IteratorSuperScale::Element;
|
||||
using LayoutSuperScale = typename IteratorSuperScale::Layout;
|
||||
|
||||
/// local_scale uint4 and group-wise
|
||||
using ElementLocalScale = typename IteratorLocalScale::Element;
|
||||
using LayoutLocalScale = typename IteratorLocalScale::Layout;
|
||||
static_assert(platform::is_same<ElementLocalScale, uint4b_t>::value,
|
||||
"local_scale's type must be uint4b_t.");
|
||||
|
||||
using ElementCodeScaleZp = typename IteratorCodeScaleZp::Element;
|
||||
using LayoutCodeScaleZp = typename IteratorCodeScaleZp::Layout;
|
||||
|
||||
/// 2 uint4b_t values are stored in a single uint8_t
|
||||
constexpr static int kStagesPerLocalScaleLoad = 2 * kGroupSize / Shape::kK;
|
||||
constexpr static int kLocalScaleRows =
|
||||
IteratorLocalScale::Shape::kRow * IteratorLocalScale::Shape::kColumn * sizeof_bits<ElementLocalScale>::value / 8 / Shape::kN;
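// Illustrative value (assuming kGroupSize = 64 and Shape::kK = 64, as used by
// DefaultWint2xMma above): kStagesPerLocalScaleLoad = 2 * 64 / 64 = 2, i.e. one
// packed uint8 row of local_scale covers two mainloop stages.
static_assert(2 * 64 / 64 == 2, "illustrative arithmetic only");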
|
||||
|
||||
using SmemElement = uint8_t;
|
||||
constexpr static int kSmemRows =
|
||||
kLocalScaleRows * kStages + sizeof(ElementSuperScale) + sizeof(ElementCodeScaleZp) * 2;
|
||||
constexpr static int kSmemColumns = Shape::kN;
|
||||
|
||||
using QuantParamsShape = MatrixShape<kSmemRows, kSmemColumns>;
|
||||
|
||||
constexpr static int kSuperScaleSmemOffset = 0;
|
||||
constexpr static int kCodeScaleSmemOffset = kSmemColumns * sizeof(ElementSuperScale);
|
||||
constexpr static int kCodeZpSmemOffset = kCodeScaleSmemOffset + kSmemColumns * sizeof(ElementCodeScaleZp);
|
||||
constexpr static int kLocalScaleSmemOffset = kCodeZpSmemOffset + kSmemColumns * sizeof(ElementCodeScaleZp);
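// Example shared-memory byte offsets (illustrative, assuming Shape::kN = 128,
// ElementSuperScale = half_t and ElementCodeScaleZp = float):
//   kCodeScaleSmemOffset  = 128 * 2       = 256
//   kCodeZpSmemOffset     = 256 + 128 * 4 = 768
//   kLocalScaleSmemOffset = 768 + 128 * 4 = 1280
static_assert(128 * 2 == 256 && 256 + 128 * 4 == 768 && 768 + 128 * 4 == 1280, "illustrative arithmetic only");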
|
||||
|
||||
/// TensorRef type for loading element from a tensor
|
||||
using SuperTensorRef = cutlass::TensorRef<ElementSuperScale, LayoutSuperScale>;
|
||||
using LocalTensorRef = cutlass::TensorRef<ElementLocalScale, LayoutLocalScale>;
|
||||
using CodeTensorRef = cutlass::TensorRef<ElementCodeScaleZp, LayoutCodeScaleZp>;
|
||||
|
||||
struct Arguments {
|
||||
IteratorSuperScale iterator_super_scale;
|
||||
IteratorLocalScale iterator_local_scale;
|
||||
IteratorCodeScaleZp iterator_code_scale;
|
||||
IteratorCodeScaleZp iterator_code_zp;
|
||||
|
||||
int local_scale_pointer_offset;
|
||||
|
||||
CUTLASS_DEVICE
|
||||
Arguments(IteratorSuperScale iterator_super_scale,
|
||||
IteratorLocalScale iterator_local_scale,
|
||||
IteratorCodeScaleZp iterator_code_scale,
|
||||
IteratorCodeScaleZp iterator_code_zp,
|
||||
int local_scale_pointer_offset)
|
||||
: iterator_super_scale(iterator_super_scale),
|
||||
iterator_local_scale(iterator_local_scale),
|
||||
iterator_code_scale(iterator_code_scale),
|
||||
iterator_code_zp(iterator_code_zp),
|
||||
local_scale_pointer_offset(local_scale_pointer_offset) {}
|
||||
};
|
||||
|
||||
private:
|
||||
//
|
||||
// Data members
|
||||
//
|
||||
|
||||
/// Begin address of shared memory
|
||||
uint8_t* smem_pointer_;
|
||||
|
||||
/// Iterator to write threadblock-scoped tile of super scale operand to shared memory
|
||||
SmemIteratorSuperScale smem_iterator_super_scale_;
|
||||
/// Iterator to write threadblock-scoped tile of local scale operand to shared memory
|
||||
SmemIteratorLocalScale smem_iterator_local_scale_;
|
||||
/// Iterator to write threadblock-scoped tile of code scale operand to shared memory
|
||||
SmemIteratorCodeScaleZp smem_iterator_code_scale_;
|
||||
/// Iterator to write threadblock-scoped tile of code zp operand to shared memory
|
||||
SmemIteratorCodeScaleZp smem_iterator_code_zp_;
|
||||
|
||||
/// Shared memory write stage index
|
||||
int smem_write_stage_idx_;
|
||||
|
||||
/// Shared memory read stage index
|
||||
int smem_read_stage_idx_;
|
||||
|
||||
CUTLASS_DEVICE
|
||||
ElementSuperScale* get_super_scale_smem_ptr() {
|
||||
return reinterpret_cast<ElementSuperScale*>(smem_pointer_ + kSuperScaleSmemOffset);
|
||||
}
|
||||
|
||||
CUTLASS_DEVICE
|
||||
ElementLocalScale* get_local_scale_smem_ptr() {
|
||||
return reinterpret_cast<ElementLocalScale*>(smem_pointer_ + kLocalScaleSmemOffset);
|
||||
}
|
||||
|
||||
CUTLASS_DEVICE
|
||||
ElementCodeScaleZp* get_code_scale_smem_ptr() {
|
||||
return reinterpret_cast<ElementCodeScaleZp*>(smem_pointer_ + kCodeScaleSmemOffset);
|
||||
}
|
||||
|
||||
CUTLASS_DEVICE
|
||||
ElementCodeScaleZp* get_code_zp_smem_ptr() {
|
||||
return reinterpret_cast<ElementCodeScaleZp*>(smem_pointer_ + kCodeZpSmemOffset);
|
||||
}
|
||||
|
||||
public:
|
||||
/// Construct from tensor references
|
||||
CUTLASS_DEVICE
|
||||
Wint2ParamsAccessor(
|
||||
///< pointer to shared memory
|
||||
uint8_t* smem_pointer,
|
||||
///< ID within the threadblock
|
||||
int thread_idx,
|
||||
///< ID of warp
|
||||
int warp_idx,
|
||||
///< ID of each thread within a warp
|
||||
int lane_idx)
|
||||
: smem_pointer_(smem_pointer),
|
||||
smem_iterator_super_scale_(LayoutSuperScale(IteratorSuperScale::Shape::kColumn),
|
||||
get_super_scale_smem_ptr(), {1, IteratorSuperScale::Shape::kColumn}, thread_idx),
|
||||
smem_iterator_local_scale_(LayoutLocalScale(IteratorLocalScale::Shape::kColumn),
|
||||
get_local_scale_smem_ptr(), {1, IteratorLocalScale::Shape::kColumn}, thread_idx),
|
||||
smem_iterator_code_scale_(LayoutCodeScaleZp(IteratorCodeScaleZp::Shape::kColumn),
|
||||
get_code_scale_smem_ptr(), {1, IteratorCodeScaleZp::Shape::kColumn}, thread_idx),
|
||||
smem_iterator_code_zp_(LayoutCodeScaleZp(IteratorCodeScaleZp::Shape::kColumn),
|
||||
get_code_zp_smem_ptr(), {1, IteratorCodeScaleZp::Shape::kColumn}, thread_idx),
|
||||
smem_write_stage_idx_(0),
|
||||
smem_read_stage_idx_(0) {}
|
||||
|
||||
CUTLASS_DEVICE
|
||||
SuperTensorRef super_scale_ref() {
|
||||
return {get_super_scale_smem_ptr(), LayoutSuperScale(IteratorSuperScale::Shape::kColumn)};
|
||||
}
|
||||
|
||||
CUTLASS_DEVICE
|
||||
LocalTensorRef local_scale_ref() {
|
||||
return {get_local_scale_smem_ptr(), LayoutLocalScale(IteratorLocalScale::Shape::kColumn)};
|
||||
}
|
||||
|
||||
CUTLASS_DEVICE
|
||||
CodeTensorRef code_scale_ref() {
|
||||
return {get_code_scale_smem_ptr(), LayoutCodeScaleZp(IteratorCodeScaleZp::Shape::kColumn)};
|
||||
}
|
||||
|
||||
CUTLASS_DEVICE
|
||||
CodeTensorRef code_zp_ref() {
|
||||
return {get_code_zp_smem_ptr(), LayoutCodeScaleZp(IteratorCodeScaleZp::Shape::kColumn)};
|
||||
}
|
||||
|
||||
template <bool IsFirstStage>
|
||||
CUTLASS_DEVICE
|
||||
void copy_tiles_and_advance_per_stage(Arguments &quant_args, int stage) {
|
||||
if constexpr (IsFirstStage) {
|
||||
// Load channel-wise super_scale to shared memory, which only needs to be done once.
|
||||
typename IteratorSuperScale::Fragment tb_frag_super_scale;
|
||||
tb_frag_super_scale.clear();
|
||||
quant_args.iterator_super_scale.load(tb_frag_super_scale);
|
||||
this->smem_iterator_super_scale_.store(tb_frag_super_scale);
|
||||
|
||||
// Load channel-wise code_scale to shared memory, which only needs to be done once.
|
||||
typename IteratorCodeScaleZp::Fragment tb_frag_code_scale;
|
||||
tb_frag_code_scale.clear();
|
||||
quant_args.iterator_code_scale.load(tb_frag_code_scale);
|
||||
this->smem_iterator_code_scale_.store(tb_frag_code_scale);
|
||||
|
||||
// Load channel-wise code_zp to shared memory, which only needs to be done once.
|
||||
typename IteratorCodeScaleZp::Fragment tb_frag_code_zp;
|
||||
tb_frag_code_zp.clear();
|
||||
quant_args.iterator_code_zp.load(tb_frag_code_zp);
|
||||
this->smem_iterator_code_zp_.store(tb_frag_code_zp);
|
||||
}
|
||||
|
||||
if ((stage % kStagesPerLocalScaleLoad) == 0) {
|
||||
// Load group-wise local_scale to shared memory, which needs to be refreshed as the stages advance.
|
||||
// Since 2 uint4b_t values of local_scale are saved in a single uint8_t, local_scale needs to be loaded once every two stages.
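// Illustrative sketch (not part of the original kernel) of how this packing is consumed later:
//   uint8_t packed = local_scale_byte;              // holds two adjacent 4-bit group scales
//   int shift = use_high_nibble ? 4 : 0;            // hypothetical selector; see dequantize()
//   uint4b_t ls = uint4b_t((packed >> shift) & 0xF);
// so one shared-memory local_scale tile covers kStagesPerLocalScaleLoad mainloop stages.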
|
||||
using AccessType = typename IteratorLocalScale::AccessType;
|
||||
cutlass::arch::CacheOperation::Kind const kCacheOp = (sizeof_bits<AccessType>::value == 128)
|
||||
? cutlass::arch::CacheOperation::Global : cutlass::arch::CacheOperation::Always;
|
||||
|
||||
quant_args.iterator_local_scale.set_iteration_index(0);
|
||||
this->smem_iterator_local_scale_.set_iteration_index(0);
|
||||
|
||||
// Async Copy for local_scale
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int j = 0; j < IteratorLocalScale::ThreadMap::Iterations::kCount; ++j) {
|
||||
AccessType *dst_ptr =
|
||||
reinterpret_cast<AccessType *>(this->smem_iterator_local_scale_.get());
|
||||
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int v = 0; v < IteratorLocalScale::kAccessesPerVector; ++v) {
|
||||
auto gmem_ptr = quant_args.iterator_local_scale.get();
|
||||
|
||||
int const kSrcBytes =
|
||||
sizeof_bits<typename IteratorLocalScale::Element>::value *
|
||||
IteratorLocalScale::ThreadMap::kElementsPerAccess /
|
||||
IteratorLocalScale::kAccessesPerVector / 8;
|
||||
|
||||
cutlass::arch::cp_async<kSrcBytes, kCacheOp>(
|
||||
dst_ptr + v, gmem_ptr, quant_args.iterator_local_scale.valid());
|
||||
}
|
||||
++quant_args.iterator_local_scale;
|
||||
}
|
||||
++this->smem_iterator_local_scale_;
|
||||
}
|
||||
}
|
||||
|
||||
CUTLASS_DEVICE
|
||||
void advance_smem_write_stage(Arguments &quant_args) {
|
||||
if (smem_write_stage_idx_ % kStagesPerLocalScaleLoad == 0) {
|
||||
// Advance global iterators
|
||||
quant_args.iterator_local_scale.add_pointer_offset(quant_args.local_scale_pointer_offset);
|
||||
|
||||
// Advance shared iterators
|
||||
int smem_pointer_offset = IteratorLocalScale::Shape::kRow * IteratorLocalScale::Shape::kColumn;
|
||||
smem_iterator_local_scale_.add_pointer_offset(smem_pointer_offset);
|
||||
}
|
||||
|
||||
// Increment shared memory write stage index
|
||||
++smem_write_stage_idx_;
|
||||
|
||||
if (smem_write_stage_idx_ == kStagesPerLocalScaleLoad * kStages) {
|
||||
// Wrap back around to the 'start' of the circular buffer in shared memory
|
||||
int pointer_offset = - kStages * IteratorLocalScale::Shape::kRow * IteratorLocalScale::Shape::kColumn;
|
||||
smem_iterator_local_scale_.add_pointer_offset(pointer_offset);
|
||||
smem_write_stage_idx_ = 0;
|
||||
}
|
||||
}
|
||||
|
||||
CUTLASS_DEVICE
|
||||
int advance_smem_read_stage() {
|
||||
int byte_offset = 0;
|
||||
|
||||
++smem_read_stage_idx_;
|
||||
|
||||
if (smem_read_stage_idx_ % kStagesPerLocalScaleLoad == 0) {
|
||||
byte_offset = kLocalScaleRows * kSmemColumns;
|
||||
}
|
||||
|
||||
if (smem_read_stage_idx_ == kStagesPerLocalScaleLoad * kStages) {
|
||||
smem_read_stage_idx_ = 0;
|
||||
byte_offset = - (kStages - 1) * kLocalScaleRows * kSmemColumns;
|
||||
}
|
||||
|
||||
return byte_offset;
|
||||
}
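// Sketch of the circular-buffer bookkeeping in the two advance_* functions above
// (an illustration, not extra kernel code): each stage counter effectively advances as
//   idx = (idx + 1) % (kStagesPerLocalScaleLoad * kStages);
// while the shared-memory pointer/offset only moves on every kStagesPerLocalScaleLoad-th
// call, since a single local_scale tile covers that many mainloop stages.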
|
||||
|
||||
CUTLASS_DEVICE
|
||||
void clear_mask(Arguments &quant_args, bool cond) {
|
||||
quant_args.iterator_local_scale.clear_mask(cond);
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace threadblock
|
||||
} // namespace gemm
|
||||
} // namespace cutlass
|
@@ -0,0 +1,130 @@
|
||||
// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "cutlass/gemm_coord.h"
|
||||
#include "cutlass/trace.h"
|
||||
|
||||
#include "cutlass_extensions/gemm/threadblock/wint2x_unzip.h"
|
||||
|
||||
namespace cutlass {
|
||||
namespace gemm {
|
||||
namespace threadblock {
|
||||
|
||||
template <typename ElementT, typename ScaleElementT, int Rows, int Columns,
|
||||
int Stages, int NumThreads, WintQuantMethod Method>
|
||||
struct TileDequanter {
|
||||
using WeightQuantTraits = WintQuantTraits<ElementT, Method>;
|
||||
using MmaElementT = typename WeightQuantTraits::MmaWeightType;
|
||||
using QuantArguments = typename WeightQuantTraits::Arguments;
|
||||
|
||||
using UnzipAndDequantFunctor =
|
||||
UnzipAndDequantFunctor<MmaElementT, Method, Rows, Columns, NumThreads>;
|
||||
|
||||
static constexpr bool kUseSharedMemory = true;
|
||||
|
||||
static constexpr int kRows = Rows;
|
||||
static constexpr int kColumns = Columns;
|
||||
static constexpr int kStages = Stages;
|
||||
|
||||
MmaElementT *out_smem_ptr{nullptr};
|
||||
|
||||
char *pointer{nullptr};
|
||||
int64_t ldm{0};
|
||||
cutlass::MatrixCoord tb_offset;
|
||||
cutlass::MatrixCoord extent;
|
||||
|
||||
ScaleElementT *super_scale_ptr{nullptr};
|
||||
cutlass::MatrixCoord tb_offset_scale;
|
||||
|
||||
QuantArguments quant_args;
|
||||
|
||||
int64_t block_start_rows[kStages];
|
||||
bool need_preload{true};
|
||||
UnzipAndDequantFunctor unzip_functor;
|
||||
|
||||
CUTLASS_DEVICE
|
||||
TileDequanter(MmaElementT *out_smem_ptr, char *pointer, int64_t ldm,
|
||||
const cutlass::MatrixCoord &extent,
|
||||
const cutlass::MatrixCoord &tb_offset,
|
||||
ScaleElementT *super_scale_ptr,
|
||||
const cutlass::MatrixCoord &tb_offset_scale,
|
||||
const QuantArguments &quant_args)
|
||||
: out_smem_ptr(out_smem_ptr), pointer(pointer), ldm(ldm), extent(extent),
|
||||
tb_offset(tb_offset), super_scale_ptr(super_scale_ptr),
|
||||
tb_offset_scale(tb_offset_scale), quant_args(quant_args) {}
|
||||
|
||||
CUTLASS_DEVICE
|
||||
MmaElementT *GetOutPtr() { return out_smem_ptr; }
|
||||
|
||||
CUTLASS_DEVICE
|
||||
void AddTileOffset(const cutlass::MatrixCoord &tile_offset) {
|
||||
tb_offset.row() += tile_offset.row() * kRows;
|
||||
tb_offset.column() += tile_offset.column() * kColumns;
|
||||
tb_offset_scale.column() += tile_offset.column() * kColumns;
|
||||
}
|
||||
|
||||
CUTLASS_DEVICE
|
||||
void Load(uint8_t *zipped_smem_ptr, uint8_t *column_wise_smem_ptr, int stage) {
|
||||
int zipped_row = WeightQuantTraits::CaclPackedDim(tb_offset.row());
|
||||
if (tb_offset.row() >= extent.row() ||
|
||||
tb_offset.column() >= extent.column()) {
|
||||
return;
|
||||
}
|
||||
|
||||
block_start_rows[stage % kStages] = tb_offset.row();
|
||||
|
||||
using ZippedT = typename WeightQuantTraits::WeightType;
|
||||
ZippedT *in_ptr = reinterpret_cast<ZippedT *>(pointer) + zipped_row * ldm +
|
||||
tb_offset.column();
|
||||
ScaleElementT *scale_ptr = super_scale_ptr + tb_offset_scale.column();
|
||||
|
||||
if constexpr (Method == WintQuantMethod::kWeightOnlyInt2) {
|
||||
const uint8_t *local_scale_ptr = quant_args.local_scale_ptr +
|
||||
(tb_offset.row() / 128) * ldm +
|
||||
tb_offset_scale.column();
|
||||
const float *code_scale_ptr =
|
||||
quant_args.code_scale_ptr + tb_offset_scale.column();
|
||||
const float *code_zp_ptr =
|
||||
quant_args.code_zp_ptr + tb_offset_scale.column();
|
||||
|
||||
typename UnzipAndDequantFunctor::Arguments args(zipped_smem_ptr, column_wise_smem_ptr);
|
||||
unzip_functor.LoadAsync(in_ptr, local_scale_ptr, code_scale_ptr, code_zp_ptr,
|
||||
scale_ptr, &args, ldm, need_preload);
|
||||
need_preload = false;
|
||||
} else {
|
||||
// CUTLASS_TRACE_DEVICE("Not Supported!");
|
||||
}
|
||||
}
|
||||
|
||||
CUTLASS_DEVICE
|
||||
void UnpackAndDequant(uint8_t *zipped_smem_ptr, uint8_t *column_wise_smem_ptr, int stage) {
|
||||
int64_t block_start_row = block_start_rows[stage % kStages];
|
||||
if (block_start_row >= extent.row()) {
|
||||
return;
|
||||
}
|
||||
|
||||
if constexpr (Method == WintQuantMethod::kWeightOnlyInt2) {
|
||||
typename UnzipAndDequantFunctor::Arguments args(zipped_smem_ptr, column_wise_smem_ptr);
|
||||
unzip_functor.ComputeVectorized(args, out_smem_ptr, block_start_row);
|
||||
} else {
|
||||
// CUTLASS_TRACE_DEVICE("Not Supported!");
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace threadblock
|
||||
} // namespace gemm
|
||||
} // namespace cutlass
|
@@ -41,9 +41,12 @@
|
||||
#include "cutlass_extensions/arch/mma.h"
|
||||
#include "cutlass_extensions/gemm/warp/mma_tensorop_compute_B_with_f16.h"
|
||||
|
||||
namespace cutlass {
|
||||
namespace gemm {
|
||||
namespace warp {
|
||||
namespace cutlass
|
||||
{
|
||||
namespace gemm
|
||||
{
|
||||
namespace warp
|
||||
{
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
@@ -78,7 +81,7 @@ private:
|
||||
// Shape for computing the FP16s
|
||||
using ComputeInstructionShape = InstructionShape_;
|
||||
|
||||
// Chosen so we get K=16 for int8, K=32 for int4, K=64 for int2.
|
||||
// Chosen so we get K=16 for int8 and K=32 for int4.
|
||||
static constexpr int LoadInstructionK = 128 / sizeof_bits<ElementB>::value;
|
||||
|
||||
// Shape for loading the narrow data type from shared memory
|
||||
|
@@ -58,12 +58,15 @@
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
namespace cutlass {
|
||||
namespace gemm {
|
||||
namespace warp {
|
||||
namespace cutlass
|
||||
{
|
||||
namespace gemm
|
||||
{
|
||||
namespace warp
|
||||
{
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
/// Structure to compute the matrix product targeting Tensor Cores, for the case when A is floating point and B is quantized integer.
|
||||
/// Structure to compute the matrix product targeting CUDA cores and SIMT math instructions.
|
||||
template <
|
||||
/// Size of the Gemm problem - concept: gemm::GemmShape<>
|
||||
typename Shape_,
|
||||
@@ -294,235 +297,6 @@ public:
|
||||
}
|
||||
};
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
/// Structure to compute the matrix product targeting Tensor Cores, for the case when A is floating point and B is quantized integer.
|
||||
/// Specialization for B of uint2b_t.
|
||||
template <
|
||||
/// Size of the Gemm problem - concept: gemm::GemmShape<>
|
||||
typename Shape_,
|
||||
/// Data type of A elements
|
||||
typename ElementA_,
|
||||
/// Layout of A matrix (concept: MatrixLayout)
|
||||
typename LayoutA_,
|
||||
/// Layout of B matrix (concept: MatrixLayout)
|
||||
typename LayoutB_,
|
||||
/// Element type of C matrix
|
||||
typename ElementC_,
|
||||
/// Layout of C matrix (concept: MatrixLayout)
|
||||
typename LayoutC_,
|
||||
/// Policy describing warp-level MmaTensorOp (concept: MmaTensorOp policy)
|
||||
typename Policy_,
|
||||
/// Instruction shape to override shared memory iterators with
|
||||
typename SharedMemoryInstructionShape_,
|
||||
/// Number of partitions along K dimension
|
||||
int PartitionsK_,
|
||||
/// Store the accumulators in row major or column major. Row major is used
|
||||
/// when output layout is interleaved.
|
||||
bool AccumulatorsInRowMajor>
|
||||
class MmaTensorOpComputeBWithF16<
|
||||
Shape_,
|
||||
ElementA_,
|
||||
LayoutA_,
|
||||
uint2b_t,
|
||||
LayoutB_,
|
||||
ElementC_,
|
||||
LayoutC_,
|
||||
Policy_,
|
||||
SharedMemoryInstructionShape_,
|
||||
PartitionsK_,
|
||||
AccumulatorsInRowMajor>
|
||||
{
|
||||
public:
|
||||
/// Shape of warp-level matrix operation (concept: GemmShape)
|
||||
using Shape = Shape_;
|
||||
|
||||
/// Data type of multiplicand A
|
||||
using ElementA = ElementA_;
|
||||
|
||||
/// Layout of multiplicand A
|
||||
using LayoutA = LayoutA_;
|
||||
|
||||
/// Data type of multiplicand B
|
||||
using ElementB = uint2b_t;
|
||||
|
||||
/// Layout of multiplicand B
|
||||
using LayoutB = LayoutB_;
|
||||
|
||||
/// Data type of accumulator matrix C
|
||||
using ElementC = ElementC_;
|
||||
|
||||
/// Layout of accumulator matrix C
|
||||
using LayoutC = LayoutC_;
|
||||
|
||||
/// Shape of the warp in units of thread (concept: MmaLanePolicySimt)
|
||||
using Policy = Policy_;
|
||||
|
||||
/// Underlying matrix multiply operator (concept: arch::Mma)
|
||||
using ArchMmaOperator = typename Policy::Operator;
|
||||
|
||||
/// Indicates math operator
|
||||
using MathOperator = typename ArchMmaOperator::Operator;
|
||||
|
||||
/// Architecture tag from underlying instruction
|
||||
using ArchTag = typename ArchMmaOperator::ArchTag;
|
||||
static_assert((platform::is_same<typename ArchMmaOperator::ElementA, half_t>::value
|
||||
&& platform::is_same<typename ArchMmaOperator::ElementB, half_t>::value)
|
||||
|| (platform::is_same<typename ArchMmaOperator::ElementA, bfloat16_t>::value
|
||||
&& platform::is_same<typename ArchMmaOperator::ElementB, bfloat16_t>::value
|
||||
&& ArchTag::kMinComputeCapability >= 80),
|
||||
"MmaTensorOpCvtBToA only supports underlying HMMA/QMMA");
|
||||
|
||||
static_assert(platform::is_same<ElementA, half_t>::value
|
||||
|| (platform::is_same<ElementA, bfloat16_t>::value && ArchTag::kMinComputeCapability >= 80),
|
||||
"MmaTensorOpCvtBToA only supports Fp16 A or Bf16 A on Ampere+");
|
||||
|
||||
/// Indicates class of matrix operator
|
||||
using OperatorClass = arch::OpClassTensorOp;
|
||||
|
||||
/// Shape of underlying instruction
|
||||
using InstructionShape = typename ArchMmaOperator::Shape;
|
||||
|
||||
/// Instruction shape to override shared memory iterators with
|
||||
using SharedMemoryInstructionShape = SharedMemoryInstructionShape_;
|
||||
|
||||
static_assert(
|
||||
SharedMemoryInstructionShape::kM == InstructionShape::kM, "M dimension of compute instruction must match load");
|
||||
static_assert(
|
||||
SharedMemoryInstructionShape::kN == InstructionShape::kN, "N dimension of compute instruction must match load");
|
||||
|
||||
static constexpr int kExpansionFactor = SharedMemoryInstructionShape::kK / InstructionShape::kK;
|
||||
|
||||
static_assert(!(Shape::kK % SharedMemoryInstructionShape::kK), "");
|
||||
|
||||
/// Complex transform on A operand
|
||||
static ComplexTransform const kTransformA = ComplexTransform::kNone;
|
||||
|
||||
/// Complex transform on B operand
|
||||
static ComplexTransform const kTransformB = ComplexTransform::kNone;
|
||||
|
||||
/// Number of threads participating in warp-level matrix product
|
||||
static int const kThreadCount = 32;
|
||||
|
||||
/// Number of partitions along K dimension
|
||||
static int const kPartitionsK = PartitionsK_;
|
||||
|
||||
public:
|
||||
/// Iterates over the A operand in memory
|
||||
using IteratorA
|
||||
= MmaTensorOpMultiplicandTileIterator<MatrixShape<Shape::kM, Shape::kK>, Operand::kA, ElementA, LayoutA,
|
||||
MatrixShape<InstructionShape::kM, InstructionShape::kK>, Policy::OpDelta::kRow, kThreadCount, kPartitionsK>;
|
||||
|
||||
/// Storage for A tile
|
||||
using FragmentA = typename IteratorA::Fragment;
|
||||
|
||||
/// Storage for transformed A tile
|
||||
using TransformedFragmentA = Array<typename ArchMmaOperator::ElementA, FragmentA::kElements>;
|
||||
|
||||
/// Iterates over the B operand in memory
|
||||
using IteratorB = MmaTensorOpMultiplicandTileIterator<MatrixShape<Shape::kK, Shape::kN>, Operand::kB, ElementB,
|
||||
LayoutB, MatrixShape<SharedMemoryInstructionShape::kK, InstructionShape::kN>, Policy::OpDelta::kRow,
|
||||
kThreadCount, kPartitionsK>;
|
||||
|
||||
/// Storage for B tile
|
||||
using FragmentB = typename IteratorB::Fragment;
|
||||
|
||||
/// Storage for transformed B tile
|
||||
using TransformedFragmentB =
|
||||
Array<typename ArchMmaOperator::ElementB, FragmentB::kElements / kExpansionFactor>;
|
||||
|
||||
/// Iterates over the C operand in memory
|
||||
using IteratorC = MmaTensorOpAccumulatorTileIterator<MatrixShape<Shape::kM, Shape::kN>, ElementC, LayoutC,
|
||||
typename ArchMmaOperator::Shape, typename Policy::OpDelta>;
|
||||
|
||||
/// Storage for C tile
|
||||
using FragmentC = typename IteratorC::Fragment;
|
||||
|
||||
/// Number of mma operations performed
|
||||
using MmaIterations = MatrixShape<(Shape::kM + ArchMmaOperator::Shape::kM - 1) / ArchMmaOperator::Shape::kM,
|
||||
(Shape::kN + ArchMmaOperator::Shape::kN - 1) / ArchMmaOperator::Shape::kN>;
|
||||
|
||||
public:
|
||||
/// Underlying matrix multiply operator (concept: arch::Mma)
|
||||
ArchMmaOperator mma;
|
||||
|
||||
public:
|
||||
//
|
||||
// Methods
|
||||
//
|
||||
|
||||
/// Ctor
|
||||
CUTLASS_DEVICE
|
||||
MmaTensorOpComputeBWithF16() {}
|
||||
|
||||
/// Performs a warp-level matrix multiply-accumulate operation
|
||||
CUTLASS_DEVICE
|
||||
void operator()(FragmentC& D, TransformedFragmentA const& A, TransformedFragmentB const& B, FragmentC const& C) const
|
||||
{
|
||||
|
||||
using MmaOperandA = typename ArchMmaOperator::FragmentA;
|
||||
using MmaOperandB = typename ArchMmaOperator::FragmentB;
|
||||
using MmaOperandC = typename ArchMmaOperator::FragmentC;
|
||||
|
||||
D = C;
|
||||
|
||||
MmaOperandA const* ptr_A = reinterpret_cast<MmaOperandA const*>(&A);
|
||||
MmaOperandB const* ptr_B = reinterpret_cast<MmaOperandB const*>(&B);
|
||||
MmaOperandC* ptr_D = reinterpret_cast<MmaOperandC*>(&D);
|
||||
|
||||
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ < 800)
|
||||
// Serpentine visitation order maximizing reuse of Rb
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int n = 0; n < MmaIterations::kColumn; ++n)
|
||||
{
|
||||
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int m = 0; m < MmaIterations::kRow; ++m)
|
||||
{
|
||||
|
||||
int m_serpentine = ((n % 2) ? (MmaIterations::kRow - 1 - m) : m);
|
||||
|
||||
if (AccumulatorsInRowMajor)
|
||||
{ // matrix B is reordered
|
||||
mma(ptr_D[n + m_serpentine * MmaIterations::kColumn], ptr_A[m_serpentine], ptr_B[n],
|
||||
ptr_D[n + m_serpentine * MmaIterations::kColumn]);
|
||||
}
|
||||
else
|
||||
{
|
||||
mma(ptr_D[m_serpentine + n * MmaIterations::kRow], ptr_A[m_serpentine], ptr_B[n],
|
||||
ptr_D[m_serpentine + n * MmaIterations::kRow]);
|
||||
}
|
||||
}
|
||||
}
|
||||
#elif defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800)
|
||||
// Serpentine visitation order maximizing reuse of Ra
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int m = 0; m < MmaIterations::kRow; ++m)
|
||||
{
|
||||
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int n = 0; n < MmaIterations::kColumn; ++n)
|
||||
{
|
||||
|
||||
int n_serpentine = ((m % 2) ? (MmaIterations::kColumn - 1 - n) : n);
|
||||
|
||||
if (AccumulatorsInRowMajor)
|
||||
{ // matrix B is reordered
|
||||
mma(ptr_D[n_serpentine + m * MmaIterations::kColumn], ptr_A[m], ptr_B[n_serpentine],
|
||||
ptr_D[n_serpentine + m * MmaIterations::kColumn]);
|
||||
}
|
||||
else
|
||||
{
|
||||
mma(ptr_D[m + n_serpentine * MmaIterations::kRow], ptr_A[m], ptr_B[n_serpentine],
|
||||
ptr_D[m + n_serpentine * MmaIterations::kRow]);
|
||||
}
|
||||
}
|
||||
}
|
||||
#else
|
||||
assert(0);
|
||||
#endif
|
||||
}
|
||||
};
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
} // namespace warp
|
||||
|
@@ -1,442 +0,0 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights
|
||||
*reserved. SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
*
|
||||
* 1. Redistributions of source code must retain the above copyright notice,
|
||||
*this list of conditions and the following disclaimer.
|
||||
*
|
||||
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
* this list of conditions and the following disclaimer in the documentation
|
||||
* and/or other materials provided with the distribution.
|
||||
*
|
||||
* 3. Neither the name of the copyright holder nor the names of its
|
||||
* contributors may be used to endorse or promote products derived from
|
||||
* this software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
|
||||
*ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
|
||||
*LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
|
||||
*CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
|
||||
*SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
|
||||
*INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
|
||||
*CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
|
||||
*ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
|
||||
*POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
/*! \file
|
||||
\brief Defines iterators used by warp-level matrix multiply operations
|
||||
targeting Tensor Cores.
|
||||
*/
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "cutlass/cutlass.h"
|
||||
|
||||
#include "cutlass/array.h"
|
||||
#include "cutlass/matrix_shape.h"
|
||||
#include "cutlass/numeric_types.h"
|
||||
#include "cutlass/tensor_ref.h"
|
||||
#include "cutlass/arch/arch.h"
|
||||
#include "cutlass/arch/memory_sm75.h"
|
||||
#include "cutlass/gemm/gemm.h"
|
||||
#include "cutlass/layout/matrix.h"
|
||||
#include "cutlass/layout/pitch_linear.h"
|
||||
#include "cutlass/layout/tensor.h"
|
||||
#include "cutlass/functional.h"
|
||||
#include "cutlass/platform/platform.h"
|
||||
|
||||
#include "cutlass_extensions/interleaved_numeric_conversion.h"
|
||||
|
||||
namespace cutlass {
|
||||
namespace gemm {
|
||||
namespace warp {
|
||||
|
||||
namespace detail {
|
||||
|
||||
template <typename T>
|
||||
struct DataTypeTraits;
|
||||
|
||||
template <>
|
||||
struct DataTypeTraits<bfloat16_t> {
|
||||
using Type = __nv_bfloat16;
|
||||
using DualType = __nv_bfloat162;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct DataTypeTraits<half_t> {
|
||||
using Type = __half;
|
||||
using DualType = __half2;
|
||||
};
|
||||
|
||||
template <typename T, int N, typename Enable = void>
|
||||
struct LocalScaleConverter {
|
||||
using FragmentSource = Array<uint8_t, N>;
|
||||
using FragmentResult = Array<T, N>;
|
||||
|
||||
CUTLASS_DEVICE
|
||||
static void Apply(FragmentSource const& local_scale_frag,
|
||||
FragmentResult const& super_scale_frag,
|
||||
FragmentResult& scale_frag,
|
||||
int shift_bit) {
|
||||
constexpr uint32_t kLocalScaleMask = 0xf;
|
||||
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int i = 0; i < N; ++i) {
|
||||
int32_t shifted_value = (static_cast<int32_t>(local_scale_frag[i]) >> shift_bit) & kLocalScaleMask;
|
||||
scale_frag[i] = static_cast<T>(shifted_value) * super_scale_frag[i];
|
||||
}
|
||||
}
|
||||
};
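// The specializations below compute the same result as the generic loop above, but operate on
// packed 16-bit lanes using byte_perm/lop3 and fused __half2 / __nv_bfloat162 multiply-adds.
// Scalar reference of the intended math (a sketch mirroring the generic Apply):
//   scale[i] = T((local_scale[i] >> shift_bit) & 0xF) * super_scale[i];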
|
||||
|
||||
template <int N>
|
||||
struct LocalScaleConverter<half_t, N, typename platform::enable_if<N % 4 == 0>::type> {
|
||||
using FragmentSource = Array<uint8_t, N>;
|
||||
using FragmentResult = Array<half_t, N>;
|
||||
|
||||
CUTLASS_DEVICE
|
||||
static void Apply(FragmentSource const& local_scale_frag,
|
||||
FragmentResult const& super_scale_frag,
|
||||
FragmentResult& scale_frag,
|
||||
int shift_bit) {
|
||||
constexpr uint32_t immLut = (0xf0 & 0xcc) | 0xaa;
|
||||
constexpr uint32_t MASK = 0x000f000f;
|
||||
// 2^10 = 1024
|
||||
constexpr uint32_t I4s_TO_FP16s_MAGIC_NUM = 0x64006400;
|
||||
|
||||
// -2^10 = -1024
|
||||
constexpr uint32_t FP16_BIAS = 0xE400E400;
|
||||
// 1.0
|
||||
constexpr uint32_t FP16_ONE = 0x3C003C00;
|
||||
|
||||
__half2* scale_ptr = reinterpret_cast<__half2 *>(&scale_frag);
|
||||
__half2 const* super_scale_ptr = reinterpret_cast<__half2 const*>(&super_scale_frag);
|
||||
|
||||
uint32_t const* local_scale_ptr = reinterpret_cast<uint32_t const*>(&local_scale_frag);
|
||||
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int i = 0; i < N / 4; ++i) {
|
||||
int i4s = local_scale_ptr[i] >> shift_bit;
|
||||
|
||||
// unpack: 0, 1
|
||||
int32_t low = __byte_perm(i4s, i4s, 0xF1F0);
|
||||
int32_t unpack0 = lop3<immLut>(low, MASK, I4s_TO_FP16s_MAGIC_NUM);
|
||||
// unpack: 2, 3
|
||||
int32_t high = __byte_perm(i4s, i4s, 0xF3F2);
|
||||
int32_t unpack1 = lop3<immLut>(high, MASK, I4s_TO_FP16s_MAGIC_NUM);
|
||||
|
||||
__half2 scale0 = __hfma2(*reinterpret_cast<__half2*>(&unpack0),
|
||||
*reinterpret_cast<const __half2*>(&FP16_ONE),
|
||||
*reinterpret_cast<const __half2*>(&FP16_BIAS));
|
||||
__half2 scale1 = __hfma2(*reinterpret_cast<__half2*>(&unpack1),
|
||||
*reinterpret_cast<const __half2*>(&FP16_ONE),
|
||||
*reinterpret_cast<const __half2*>(&FP16_BIAS));
|
||||
|
||||
scale_ptr[2 * i] = __hmul2(scale0, super_scale_ptr[2 * i]);
|
||||
scale_ptr[2 * i + 1] = __hmul2(scale1, super_scale_ptr[2 * i + 1]);
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
template <int N>
|
||||
struct LocalScaleConverter<bfloat16_t, N, typename platform::enable_if<N % 4 == 0>::type> {
|
||||
using FragmentSource = Array<uint8_t, N>;
|
||||
using FragmentResult = Array<bfloat16_t, N>;
|
||||
|
||||
CUTLASS_DEVICE
|
||||
static void Apply(FragmentSource const& local_scale_frag,
|
||||
FragmentResult const& super_scale_frag,
|
||||
FragmentResult& scale_frag,
|
||||
int shift_bit) {
|
||||
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800) && defined(ENABLE_BF16))
|
||||
constexpr uint32_t immLut = (0xF0 & 0xCC) | 0xAA;
|
||||
constexpr uint32_t MASK = 0x000F000F;
|
||||
constexpr uint32_t I4s_TO_BF16s_MAGIC_NUM = 0x43004300;
|
||||
|
||||
constexpr uint32_t BF16_BIAS = 0xC300C300;
|
||||
constexpr uint32_t BF16_ONE = 0x3F803F80;
|
||||
|
||||
__nv_bfloat162* scale_ptr = reinterpret_cast<__nv_bfloat162 *>(&scale_frag);
|
||||
__nv_bfloat162 const* super_scale_ptr = reinterpret_cast<__nv_bfloat162 const*>(&super_scale_frag);
|
||||
|
||||
uint32_t const* local_scale_ptr = reinterpret_cast<uint32_t const*>(&local_scale_frag);
|
||||
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int i = 0; i < N / 4; ++i) {
|
||||
int i4s = local_scale_ptr[i] >> shift_bit;
|
||||
|
||||
// unpack: 0, 1
|
||||
int32_t low = __byte_perm(i4s, i4s, 0xF1F0);
|
||||
int32_t unpack0 = lop3<immLut>(low, MASK, I4s_TO_BF16s_MAGIC_NUM);
|
||||
// unpack: 2, 3
|
||||
int32_t high = __byte_perm(i4s, i4s, 0xF3F2);
|
||||
int32_t unpack1 = lop3<immLut>(high, MASK, I4s_TO_BF16s_MAGIC_NUM);
|
||||
|
||||
nv_bfloat162 scale0 = __hfma2(*reinterpret_cast<nv_bfloat162*>(&unpack0),
|
||||
*reinterpret_cast<const nv_bfloat162*>(&BF16_ONE),
|
||||
*reinterpret_cast<const nv_bfloat162*>(&BF16_BIAS));
|
||||
nv_bfloat162 scale1 = __hfma2(*reinterpret_cast<nv_bfloat162*>(&unpack1),
|
||||
*reinterpret_cast<const nv_bfloat162*>(&BF16_ONE),
|
||||
*reinterpret_cast<const nv_bfloat162*>(&BF16_BIAS));
|
||||
|
||||
scale_ptr[2 * i] = __hmul2(scale0, super_scale_ptr[2 * i]);
|
||||
scale_ptr[2 * i + 1] = __hmul2(scale1, super_scale_ptr[2 * i + 1]);
|
||||
}
|
||||
#else
|
||||
// Slow path not implemented here on purpose. If we need to do HMMA on older arch, scale conversion should
|
||||
// happen before scales are stored to shared memory and we should use the fp16 dequantizer. This will avoid
|
||||
// numerous conversion instructions in GEMM main loop.
|
||||
arch::device_breakpoint();
|
||||
#endif
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace detail
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template <
|
||||
/// Matrix multiply operator
|
||||
typename MmaOperator_,
|
||||
/// Size of the matrix to load (concept: MatrixShape)
|
||||
typename Shape_,
|
||||
/// Operand identity
|
||||
Operand Operand,
|
||||
/// Data type of Scale elements
|
||||
typename ElementOperand_,
|
||||
/// Layout of operand
|
||||
typename Layout_,
|
||||
/// Group size for quantization
|
||||
int GroupSize_,
|
||||
///
|
||||
typename Enable = void>
|
||||
class MmaTensorOpWin2xDequantizer {
|
||||
//static_assert(false, "Not Supported!");
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
// Bfloat specialization for Ampere
|
||||
template <
|
||||
/// Underlying matrix multiply operator (concept: MmaTensorOp)
|
||||
typename MmaOperator_,
|
||||
/// Shape of the warp level matrix multiply (concept: GemmShape)
|
||||
typename Shape_,
|
||||
/// Data type of Scale elements
|
||||
typename ElementOperand_,
|
||||
/// Group size for quantization
|
||||
int GroupSize_>
|
||||
class MmaTensorOpWin2xDequantizer<
|
||||
MmaOperator_,
|
||||
Shape_,
|
||||
Operand::kB,
|
||||
ElementOperand_,
|
||||
layout::RowMajor,
|
||||
GroupSize_>
|
||||
//typename platform::enable_if<MmaOperator_::ArchTag::kMinComputeCapability >= 80
|
||||
// && platform::is_same<typename MmaOperator_::ArchMmaOperator::LayoutB, layout::ColumnMajor>::value>::type>
|
||||
{
|
||||
public:
|
||||
static_assert(platform::is_same<ElementOperand_, half_t>::value || platform::is_same<ElementOperand_, bfloat16_t>::value,
|
||||
"T must be fp16 or bf16");
|
||||
|
||||
/// Mma Operator
|
||||
using MmaOperator = MmaOperator_;
|
||||
|
||||
// The architecture-specific mma operator being used
|
||||
using ArchMmaOperator = typename MmaOperator::ArchMmaOperator;
|
||||
|
||||
// Mma Instruction Shape
|
||||
using InstructionShape = typename ArchMmaOperator::Shape;
|
||||
|
||||
/// Warp mma shape
|
||||
using Shape = Shape_;
|
||||
|
||||
/// Type of mma operand
|
||||
using ElementOperand = ElementOperand_;
|
||||
|
||||
/// Layout of the scales in shared memory
|
||||
using Layout = layout::RowMajor;
|
||||
|
||||
/// Group size for quantization
|
||||
static constexpr int kGroupSize = GroupSize_;
|
||||
|
||||
/// Type of input
|
||||
using ElementB = typename MmaOperator::FragmentB::Element;
|
||||
static_assert(platform::is_same<ElementB, uint2b_t>::value, "ElementB must be uint2b_t");
|
||||
|
||||
/// Type of the scales
|
||||
using ElementLocalScale = uint4b_t;
|
||||
using ElementSuperScale = ElementOperand;
|
||||
using ElementCodeScaleZp = float;
|
||||
|
||||
// Fragment to hold scale data to apply to B before mma
|
||||
// We need 1 fp16 per matrix iteration in the N dimension
|
||||
static constexpr int kWarpIterationsAlongN = MmaOperator::MmaIterations::kColumn;
|
||||
|
||||
// use uint8_t to store two 4-bit local scales
|
||||
using FragmentLocalScale = Array<uint8_t, kWarpIterationsAlongN>;
|
||||
using FragmentSuperScale = Array<ElementSuperScale, kWarpIterationsAlongN>;
|
||||
using FragmentCodeScaleZp = Array<ElementCodeScaleZp, kWarpIterationsAlongN>;
|
||||
|
||||
/// Fragment to hold B data before Mma
|
||||
using FragmentInput = Array<ElementB, MmaOperator::FragmentB::kElements>;
|
||||
|
||||
// This is the ratio of the load instruction vs the compute instruction.
|
||||
static constexpr int kExpansionFactor = MmaOperator::IteratorB::InstructionShape::kRow / InstructionShape::kK;
|
||||
|
||||
static constexpr int kNumPacks = sizeof_bits<uint8_t>::value / sizeof_bits<ElementB>::value;
|
||||
static constexpr int kUnpackFactor = MmaOperator::FragmentB::kElements / (kWarpIterationsAlongN * kNumPacks);
|
||||
static constexpr int kUnpackInterval = kExpansionFactor / kUnpackFactor;
|
||||
|
||||
/// Unpack 4 uint2b_t values compressed in a uint8_t to floating-point values.
|
||||
using Uint2Converter = FastInterleavedAndBiasedNumericArrayConverter<
|
||||
ElementOperand, ElementB, MmaOperator::FragmentB::kElements / kUnpackFactor>;
|
||||
using FragmentInputUnpack = typename Uint2Converter::result_type;
|
||||
|
||||
/// Fragment to hold internal scales before Mma
|
||||
using FragmentScale = Array<ElementOperand, FragmentLocalScale::kElements>;
|
||||
|
||||
/// Fragment of dequantized B
|
||||
using FragmentOutput = Array<ElementOperand, MmaOperator::FragmentB::kElements / kExpansionFactor>;
|
||||
|
||||
/// TensorRef type for loading element from a tensor
|
||||
using SuperTensorRef = cutlass::TensorRef<ElementSuperScale, Layout>;
|
||||
using LocalTensorRef = cutlass::TensorRef<ElementLocalScale, Layout>;
|
||||
using CodeTensorRef = cutlass::TensorRef<ElementCodeScaleZp, Layout>;
|
||||
|
||||
private:
|
||||
//
|
||||
// Data members
|
||||
//
|
||||
|
||||
uint8_t* pointer_local_scale_;
|
||||
ElementCodeScaleZp* pointer_code_scale_;
|
||||
ElementCodeScaleZp* pointer_code_zp_;
|
||||
ElementSuperScale* pointer_super_scale_;
|
||||
|
||||
//FragmentInputUnpack unpacked_frag_;
|
||||
FragmentScale scale_frag_;
|
||||
|
||||
public:
|
||||
CUTLASS_DEVICE
|
||||
MmaTensorOpWin2xDequantizer(SuperTensorRef smem_super_scale,
|
||||
LocalTensorRef smem_local_scale,
|
||||
CodeTensorRef smem_code_scale,
|
||||
CodeTensorRef smem_code_zp,
|
||||
int warp_idx_n,
|
||||
int lane_idx) {
|
||||
int warp_offset = warp_idx_n * Shape::kN;
|
||||
int quad = lane_idx / 4;
|
||||
int thread_offset = warp_offset + quad;
|
||||
pointer_super_scale_ = smem_super_scale.data() + thread_offset;
|
||||
pointer_code_scale_ = smem_code_scale.data() + thread_offset;
|
||||
pointer_code_zp_ = smem_code_zp.data() + thread_offset;
|
||||
pointer_local_scale_ = reinterpret_cast<uint8_t *>(smem_local_scale.data()) + thread_offset;
|
||||
}
|
||||
|
||||
/// Channel-wise params, which only need to be loaded once
|
||||
CUTLASS_DEVICE
|
||||
void load(FragmentCodeScaleZp& code_scale_frag,
|
||||
FragmentCodeScaleZp& code_zp_frag,
|
||||
FragmentSuperScale& super_scale_frag) {
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int mma_n_iter = 0; mma_n_iter < kWarpIterationsAlongN; ++mma_n_iter) {
|
||||
super_scale_frag[mma_n_iter] = pointer_super_scale_[mma_n_iter * InstructionShape::kN]; // bank conflict
|
||||
code_scale_frag[mma_n_iter] = pointer_code_scale_[mma_n_iter * InstructionShape::kN];
|
||||
code_zp_frag[mma_n_iter] = pointer_code_zp_[mma_n_iter * InstructionShape::kN];
|
||||
}
|
||||
}
|
||||
|
||||
/// Group-wise params, which need to be loaded multiple times
|
||||
CUTLASS_DEVICE
|
||||
void load(FragmentLocalScale& local_scale_frag) {
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int mma_n_iter = 0; mma_n_iter < kWarpIterationsAlongN; ++mma_n_iter) {
|
||||
local_scale_frag[mma_n_iter] = pointer_local_scale_[mma_n_iter * InstructionShape::kN]; // bank conflict
|
||||
}
|
||||
}
|
||||
|
||||
CUTLASS_DEVICE
|
||||
void dequantize(const FragmentLocalScale& local_scale_frag,
|
||||
const FragmentCodeScaleZp& code_scale_frag,
|
||||
const FragmentCodeScaleZp& code_zp_frag,
|
||||
const FragmentSuperScale& super_scale_frag,
|
||||
const FragmentInput& input_frag,
|
||||
FragmentOutput& output_frag,
|
||||
int tb_offset_k,
|
||||
int warp_k_compute_offset) {
|
||||
if constexpr (kUnpackInterval != 1) {
|
||||
// unsupported for now
|
||||
arch::device_breakpoint();
|
||||
}
|
||||
|
||||
typename Uint2Converter::source_type source_frag;
|
||||
|
||||
int in_offset = warp_k_compute_offset * kUnpackInterval;
|
||||
|
||||
uint8_t const* ptr_input = reinterpret_cast<uint8_t const*>(&input_frag);
|
||||
uint8_t* ptr_source = reinterpret_cast<uint8_t *>(&source_frag);
|
||||
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int mma_n_iter = 0; mma_n_iter < kWarpIterationsAlongN; ++mma_n_iter) {
|
||||
ptr_source[mma_n_iter] = ptr_input[mma_n_iter * kUnpackFactor + in_offset];
|
||||
}
|
||||
FragmentInputUnpack unpacked_frag = Uint2Converter::convert(source_frag, code_scale_frag, code_zp_frag);
|
||||
|
||||
// dequantize local_scale
|
||||
if (warp_k_compute_offset == 0) {
|
||||
using LocalScaleConverter = detail::LocalScaleConverter<ElementOperand, FragmentLocalScale::kElements>;
|
||||
|
||||
// special for TileRows = 64
|
||||
int local_scale_shift = (((tb_offset_k / kGroupSize) + 1) & 1) * 4;
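// Sketch of the shift selection above: groups alternate between the two packed nibbles,
// i.e. local_scale_shift is 4 when (tb_offset_k / kGroupSize) is even and 0 when it is odd.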
|
||||
LocalScaleConverter::Apply(local_scale_frag, super_scale_frag, scale_frag_, local_scale_shift);
|
||||
}
|
||||
|
||||
// unscale
|
||||
// After applying LOP3 optimizations for performance, the B operand requires data rearrangement.
|
||||
// reorder: [0, 4, 1, 5, 2, 6, 3, 7, 8, 12, 9, 13, 10, 14, 11, 15]
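// Illustrative reading of the table above (not extra kernel code): it lists the permutation of
// the 16 unpacked values produced by the LOP3-based converter; the loop below then walks output
// columns two at a time (mma_n_iter += 2) so each __half2 / __nv_bfloat162 pair carries the two
// values belonging to adjacent N iterations.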
|
||||
const int kWarpIterationsAlongK = FragmentOutput::kElements / kWarpIterationsAlongN;
|
||||
|
||||
using Type = typename detail::DataTypeTraits<ElementOperand>::Type;
|
||||
using DualType = typename detail::DataTypeTraits<ElementOperand>::DualType;
|
||||
|
||||
Type* output_ptr = reinterpret_cast<Type *>(&output_frag);
|
||||
DualType const* unpacked_ptr = reinterpret_cast<DualType const*>(&unpacked_frag);
|
||||
DualType const* scale_ptr = reinterpret_cast<DualType const*>(&scale_frag_);
|
||||
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int mma_n_iter = 0; mma_n_iter < kWarpIterationsAlongN; mma_n_iter += 2) {
|
||||
int mapped_idx_base = (mma_n_iter / 2) * kWarpIterationsAlongK;
|
||||
|
||||
DualType scalex2 = scale_ptr[mma_n_iter / 2];
|
||||
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int mma_k_iter = 0; mma_k_iter < kWarpIterationsAlongK; ++mma_k_iter) {
|
||||
DualType unpacked_valuex2 = unpacked_ptr[mapped_idx_base + mma_k_iter];
|
||||
DualType scaled_value = __hmul2(unpacked_valuex2, scalex2);
|
||||
output_ptr[mma_n_iter * kWarpIterationsAlongK + mma_k_iter] = scaled_value.x;
|
||||
output_ptr[(mma_n_iter + 1) * kWarpIterationsAlongK + mma_k_iter] = scaled_value.y;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Add an offset to pointer in units of elements.
|
||||
/// Only group-wise params need this.
|
||||
CUTLASS_DEVICE
|
||||
void add_pointer_offset(int64_t const& offset) {
|
||||
pointer_local_scale_ += offset;
|
||||
}
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
} // namespace warp
|
||||
} // namespace gemm
|
||||
} // namespace cutlass
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
@@ -39,25 +39,18 @@
|
||||
#include "cutlass/array.h"
|
||||
#include "cutlass/half.h"
|
||||
#include "cutlass/numeric_types.h"
|
||||
#include "cutlass/trace.h"
|
||||
|
||||
namespace cutlass {
|
||||
|
||||
template <int lut>
|
||||
__device__ inline int lop3(int a, int b, int c) {
|
||||
int res;
|
||||
asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n"
|
||||
: "=r"(res)
|
||||
: "r"(a), "r"(b), "r"(c), "n"(lut));
|
||||
return res;
|
||||
}
|
||||
namespace cutlass
|
||||
{
|
||||
|
||||
// This converter is meant to be used with data interleaved in a 32-bit register where the even elements are in the low
|
||||
// bits and the odd elements are in the high bits of the register. In addition, it assumes elements were originally
|
||||
// signed and had a bias of 2**(b-1) added (where b is the number of bits in the type) to make all numbers unsigned.
|
||||
// This converter will uninterleave the data and subtract the bias while converting to the result type.
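// Sketch of the layout described above, for hypothetical 8-bit elements e0..e3 in one 32-bit
// register: bytes from low to high hold [e0, e2, e1, e3], i.e. even elements occupy the low
// half-words; with a bias of 2^(b-1) = 128, a stored byte of 130 decodes back to 130 - 128 = 2.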
|
||||
template <typename T, typename S, int N>
|
||||
struct FastInterleavedAndBiasedNumericArrayConverter;
|
||||
struct FastInterleavedAndBiasedNumericArrayConverter
|
||||
{
|
||||
};
|
||||
|
||||
template <>
|
||||
struct FastInterleavedAndBiasedNumericArrayConverter<half_t, uint8_t, 4>
|
||||
@@ -447,329 +440,6 @@ struct FastInterleavedAndBiasedNumericArrayConverter<bfloat16_t, uint4b_t, N>
|
||||
}
|
||||
};
|
||||
|
||||
template <>
|
||||
struct FastInterleavedAndBiasedNumericArrayConverter<half_t, uint2b_t, 16>
|
||||
{
|
||||
using result_type = Array<half_t, 16>;
|
||||
using source_type = Array<uint2b_t, 16>;
|
||||
|
||||
using ScaleComputeT = float;
|
||||
using code_type = Array<ScaleComputeT, 4>;
|
||||
|
||||
CUTLASS_DEVICE
|
||||
static result_type convert(source_type const& source, ScaleComputeT code_scale, ScaleComputeT code_zp)
|
||||
{
|
||||
uint32_t const i8s = reinterpret_cast<uint32_t const&>(source);
|
||||
|
||||
// 2^23 = 8388608
|
||||
static constexpr uint32_t FP32_BASE = 0x4B000000;
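// Sketch of the trick used below: 0x4B000000 is 2^23 in fp32, so __byte_perm splices one source
// byte into the low mantissa bits, giving the float value 2^23 + byte; the subsequent sub.f32
// against FP32_BASE then recovers that byte exactly as a float in fp32_intermediates[i].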
|
||||
|
||||
float fp32_intermediates[4];
|
||||
uint32_t* fp32_intermediates_casted = reinterpret_cast<uint32_t*>(fp32_intermediates);
|
||||
fp32_intermediates_casted[0] = __byte_perm(i8s, FP32_BASE, 0x7650);
|
||||
fp32_intermediates_casted[1] = __byte_perm(i8s, FP32_BASE, 0x7651);
|
||||
fp32_intermediates_casted[2] = __byte_perm(i8s, FP32_BASE, 0x7652);
|
||||
fp32_intermediates_casted[3] = __byte_perm(i8s, FP32_BASE, 0x7653);
|
||||
|
||||
asm volatile("sub.f32 %0, %1, %2;\n" : "=r"(fp32_intermediates_casted[0]) : "r"(fp32_intermediates_casted[0]), "r"(FP32_BASE));
|
||||
asm volatile("sub.f32 %0, %1, %2;\n" : "=r"(fp32_intermediates_casted[1]) : "r"(fp32_intermediates_casted[1]), "r"(FP32_BASE));
|
||||
asm volatile("sub.f32 %0, %1, %2;\n" : "=r"(fp32_intermediates_casted[2]) : "r"(fp32_intermediates_casted[2]), "r"(FP32_BASE));
|
||||
asm volatile("sub.f32 %0, %1, %2;\n" : "=r"(fp32_intermediates_casted[3]) : "r"(fp32_intermediates_casted[3]), "r"(FP32_BASE));
|
||||
|
||||
int32_t decode_value[4];
|
||||
ScaleComputeT new_code_zp = code_zp + 0.5f;
|
||||
|
||||
decode_value[0] = __float2int_rd(fmaf(fp32_intermediates[0], code_scale, new_code_zp));
|
||||
decode_value[1] = __float2int_rd(fmaf(fp32_intermediates[1], code_scale, new_code_zp));
|
||||
decode_value[2] = __float2int_rd(fmaf(fp32_intermediates[2], code_scale, new_code_zp));
|
||||
decode_value[3] = __float2int_rd(fmaf(fp32_intermediates[3], code_scale, new_code_zp));
|
||||
|
||||
return convert_impl(decode_value);
|
||||
}
|
||||
|
||||
CUTLASS_DEVICE
|
||||
static result_type convert(source_type const& source, code_type const& code_scale, code_type const& code_zp)
|
||||
{
|
||||
uint32_t const i8s = reinterpret_cast<uint32_t const&>(source);
|
||||
|
||||
// 2^23 = 8388608
|
||||
static constexpr uint32_t FP32_BASE = 0x4B000000;
|
||||
|
||||
float fp32_intermediates[4];
|
||||
uint32_t* fp32_intermediates_casted = reinterpret_cast<uint32_t*>(fp32_intermediates);
|
||||
fp32_intermediates_casted[0] = __byte_perm(i8s, FP32_BASE, 0x7650);
|
||||
fp32_intermediates_casted[1] = __byte_perm(i8s, FP32_BASE, 0x7651);
|
||||
fp32_intermediates_casted[2] = __byte_perm(i8s, FP32_BASE, 0x7652);
|
||||
fp32_intermediates_casted[3] = __byte_perm(i8s, FP32_BASE, 0x7653);
|
||||
|
||||
asm volatile("sub.f32 %0, %1, %2;\n" : "=r"(fp32_intermediates_casted[0]) : "r"(fp32_intermediates_casted[0]), "r"(FP32_BASE));
|
||||
asm volatile("sub.f32 %0, %1, %2;\n" : "=r"(fp32_intermediates_casted[1]) : "r"(fp32_intermediates_casted[1]), "r"(FP32_BASE));
|
||||
asm volatile("sub.f32 %0, %1, %2;\n" : "=r"(fp32_intermediates_casted[2]) : "r"(fp32_intermediates_casted[2]), "r"(FP32_BASE));
|
||||
asm volatile("sub.f32 %0, %1, %2;\n" : "=r"(fp32_intermediates_casted[3]) : "r"(fp32_intermediates_casted[3]), "r"(FP32_BASE));
|
||||
|
||||
int32_t decode_value[4];
|
||||
|
||||
decode_value[0] = __float2int_rd(fmaf(fp32_intermediates[0], code_scale[0], code_zp[0] + 0.5f));
|
||||
decode_value[1] = __float2int_rd(fmaf(fp32_intermediates[1], code_scale[1], code_zp[1] + 0.5f));
|
||||
decode_value[2] = __float2int_rd(fmaf(fp32_intermediates[2], code_scale[2], code_zp[2] + 0.5f));
|
||||
decode_value[3] = __float2int_rd(fmaf(fp32_intermediates[3], code_scale[3], code_zp[3] + 0.5f));
|
||||
|
||||
return convert_impl(decode_value);
|
||||
}
|
||||
|
||||
CUTLASS_DEVICE
|
||||
static result_type convert_impl(int32_t* decode_value)
|
||||
{
|
||||
result_type result;
|
||||
static constexpr uint32_t immLut = (0xF0 & 0xCC) | 0xAA;
|
||||
|
||||
static constexpr uint32_t MASK = 0x003F003F;
|
||||
// 2^10 = 1024
|
||||
static constexpr uint32_t EX = 0x64006400;
|
||||
|
||||
uint32_t* h = reinterpret_cast<uint32_t*>(&result);
|
||||
|
||||
int32_t q0 = __byte_perm(decode_value[0], decode_value[1], 0x5410);
|
||||
int32_t q1 = __byte_perm(decode_value[2], decode_value[3], 0x5410);
|
||||
|
||||
h[0] = lop3<immLut>(q0 >> 9, MASK, EX);
|
||||
h[1] = lop3<immLut>(q0 >> 6, MASK, EX);
|
||||
h[2] = lop3<immLut>(q0 >> 3, MASK, EX);
|
||||
h[3] = lop3<immLut>(q0, MASK, EX);
|
||||
|
||||
h[4] = lop3<immLut>(q1 >> 9, MASK, EX);
|
||||
h[5] = lop3<immLut>(q1 >> 6, MASK, EX);
|
||||
h[6] = lop3<immLut>(q1 >> 3, MASK, EX);
|
||||
h[7] = lop3<immLut>(q1, MASK, EX);
|
||||
|
||||
// 1024 + 32 = 1056
|
||||
static constexpr uint32_t SUB = 0x64206420;
|
||||
|
||||
asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(h[0]) : "r"(h[0]), "r"(SUB));
|
||||
asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(h[1]) : "r"(h[1]), "r"(SUB));
|
||||
asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(h[2]) : "r"(h[2]), "r"(SUB));
|
||||
asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(h[3]) : "r"(h[3]), "r"(SUB));
|
||||
|
||||
asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(h[4]) : "r"(h[4]), "r"(SUB));
|
||||
asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(h[5]) : "r"(h[5]), "r"(SUB));
|
||||
asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(h[6]) : "r"(h[6]), "r"(SUB));
|
||||
asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(h[7]) : "r"(h[7]), "r"(SUB));
|
||||
|
||||
return result;
|
||||
}
|
||||
|
||||
CUTLASS_DEVICE
|
||||
result_type operator()(source_type const& s, ScaleComputeT code_scale, ScaleComputeT code_zp)
|
||||
{
|
||||
return convert(s, code_scale, code_zp);
|
||||
}
|
||||
};
|
||||
|
||||
template <>
|
||||
struct FastInterleavedAndBiasedNumericArrayConverter<bfloat16_t, uint2b_t, 16>
|
||||
{
|
||||
using result_type = Array<bfloat16_t, 16>;
|
||||
using source_type = Array<uint2b_t, 16>;
|
||||
|
||||
using ScaleComputeT = float;
|
||||
using code_type = Array<ScaleComputeT, 4>;
|
||||
|
||||
CUTLASS_DEVICE
|
||||
static result_type convert(source_type const& source, ScaleComputeT code_scale, ScaleComputeT code_zp)
|
||||
{
|
||||
uint32_t const i8s = reinterpret_cast<uint32_t const&>(source);
|
||||
|
||||
// 2^23 = 8388608
|
||||
static constexpr uint32_t FP32_BASE = 0x4B000000;
|
||||
|
||||
float fp32_intermediates[4];
|
||||
uint32_t* fp32_intermediates_casted = reinterpret_cast<uint32_t*>(fp32_intermediates);
|
||||
fp32_intermediates_casted[0] = __byte_perm(i8s, FP32_BASE, 0x7650);
|
||||
fp32_intermediates_casted[1] = __byte_perm(i8s, FP32_BASE, 0x7651);
|
||||
fp32_intermediates_casted[2] = __byte_perm(i8s, FP32_BASE, 0x7652);
|
||||
fp32_intermediates_casted[3] = __byte_perm(i8s, FP32_BASE, 0x7653);
|
||||
|
||||
asm volatile("sub.f32 %0, %1, %2;\n" : "=r"(fp32_intermediates_casted[0]) : "r"(fp32_intermediates_casted[0]), "r"(FP32_BASE));
|
||||
asm volatile("sub.f32 %0, %1, %2;\n" : "=r"(fp32_intermediates_casted[1]) : "r"(fp32_intermediates_casted[1]), "r"(FP32_BASE));
|
||||
asm volatile("sub.f32 %0, %1, %2;\n" : "=r"(fp32_intermediates_casted[2]) : "r"(fp32_intermediates_casted[2]), "r"(FP32_BASE));
|
||||
asm volatile("sub.f32 %0, %1, %2;\n" : "=r"(fp32_intermediates_casted[3]) : "r"(fp32_intermediates_casted[3]), "r"(FP32_BASE));
|
||||
|
||||
int32_t decode_value[4];
|
||||
ScaleComputeT new_code_zp = code_zp + 0.5f;
|
||||
|
||||
decode_value[0] = __float2int_rd(fmaf(fp32_intermediates[0], code_scale, new_code_zp));
|
||||
decode_value[1] = __float2int_rd(fmaf(fp32_intermediates[1], code_scale, new_code_zp));
|
||||
decode_value[2] = __float2int_rd(fmaf(fp32_intermediates[2], code_scale, new_code_zp));
|
||||
decode_value[3] = __float2int_rd(fmaf(fp32_intermediates[3], code_scale, new_code_zp));
|
||||
|
||||
return convert_impl(decode_value);
|
||||
}
|
||||
|
||||
CUTLASS_DEVICE
|
||||
static result_type convert(source_type const& source, code_type const& code_scale, code_type const& code_zp)
|
||||
{
|
||||
uint32_t const i8s = reinterpret_cast<uint32_t const&>(source);
|
||||
|
||||
// 2^23 = 8388608
|
||||
static constexpr uint32_t FP32_BASE = 0x4B000000;
|
||||
|
||||
float fp32_intermediates[4];
|
||||
uint32_t* fp32_intermediates_casted = reinterpret_cast<uint32_t*>(fp32_intermediates);
|
||||
fp32_intermediates_casted[0] = __byte_perm(i8s, FP32_BASE, 0x7650);
|
||||
fp32_intermediates_casted[1] = __byte_perm(i8s, FP32_BASE, 0x7651);
|
||||
fp32_intermediates_casted[2] = __byte_perm(i8s, FP32_BASE, 0x7652);
|
||||
fp32_intermediates_casted[3] = __byte_perm(i8s, FP32_BASE, 0x7653);
|
||||
|
||||
asm volatile("sub.f32 %0, %1, %2;\n" : "=r"(fp32_intermediates_casted[0]) : "r"(fp32_intermediates_casted[0]), "r"(FP32_BASE));
|
||||
asm volatile("sub.f32 %0, %1, %2;\n" : "=r"(fp32_intermediates_casted[1]) : "r"(fp32_intermediates_casted[1]), "r"(FP32_BASE));
|
||||
asm volatile("sub.f32 %0, %1, %2;\n" : "=r"(fp32_intermediates_casted[2]) : "r"(fp32_intermediates_casted[2]), "r"(FP32_BASE));
|
||||
asm volatile("sub.f32 %0, %1, %2;\n" : "=r"(fp32_intermediates_casted[3]) : "r"(fp32_intermediates_casted[3]), "r"(FP32_BASE));
|
||||
|
||||
int32_t decode_value[4];
|
||||
|
||||
decode_value[0] = __float2int_rd(fmaf(fp32_intermediates[0], code_scale[0], code_zp[0] + 0.5f));
|
||||
decode_value[1] = __float2int_rd(fmaf(fp32_intermediates[1], code_scale[1], code_zp[1] + 0.5f));
|
||||
decode_value[2] = __float2int_rd(fmaf(fp32_intermediates[2], code_scale[2], code_zp[2] + 0.5f));
|
||||
decode_value[3] = __float2int_rd(fmaf(fp32_intermediates[3], code_scale[3], code_zp[3] + 0.5f));
|
||||
|
||||
return convert_impl(decode_value);
|
||||
}
|
||||
|
||||
CUTLASS_DEVICE
|
||||
static result_type convert_impl(int32_t* decode_value)
|
||||
{
|
||||
result_type result;
|
||||
|
||||
static constexpr uint32_t immLut = (0xF0 & 0xCC) | 0xAA;
|
||||
static constexpr uint32_t MASK = 0x003F003F;
|
||||
// 2^7 = 128
|
||||
static constexpr uint32_t EX = 0x43004300;
|
||||
|
||||
uint32_t* h = reinterpret_cast<uint32_t*>(&result);
|
||||
|
||||
int32_t q0 = __byte_perm(decode_value[0], decode_value[1], 0x5410);
|
||||
int32_t q1 = __byte_perm(decode_value[2], decode_value[3], 0x5410);
|
||||
|
||||
h[0] = lop3<immLut>(q0 >> 9, MASK, EX);
|
||||
h[1] = lop3<immLut>(q0 >> 6, MASK, EX);
|
||||
h[2] = lop3<immLut>(q0 >> 3, MASK, EX);
|
||||
h[3] = lop3<immLut>(q0, MASK, EX);
|
||||
|
||||
h[4] = lop3<immLut>(q1 >> 9, MASK, EX);
|
||||
h[5] = lop3<immLut>(q1 >> 6, MASK, EX);
|
||||
h[6] = lop3<immLut>(q1 >> 3, MASK, EX);
|
||||
h[7] = lop3<immLut>(q1, MASK, EX);
|
||||
|
||||
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900) && defined(ENABLE_BF16))
|
||||
// 128 + 32 = 160
|
||||
static constexpr uint32_t SUB = 0x43204320;
|
||||
|
||||
asm volatile("sub.bf16x2 %0, %1, %2;\n" : "=r"(h[0]) : "r"(h[0]), "r"(SUB));
|
||||
asm volatile("sub.bf16x2 %0, %1, %2;\n" : "=r"(h[1]) : "r"(h[1]), "r"(SUB));
|
||||
asm volatile("sub.bf16x2 %0, %1, %2;\n" : "=r"(h[2]) : "r"(h[2]), "r"(SUB));
|
||||
asm volatile("sub.bf16x2 %0, %1, %2;\n" : "=r"(h[3]) : "r"(h[3]), "r"(SUB));
|
||||
|
||||
asm volatile("sub.bf16x2 %0, %1, %2;\n" : "=r"(h[4]) : "r"(h[4]), "r"(SUB));
|
||||
asm volatile("sub.bf16x2 %0, %1, %2;\n" : "=r"(h[5]) : "r"(h[5]), "r"(SUB));
|
||||
asm volatile("sub.bf16x2 %0, %1, %2;\n" : "=r"(h[6]) : "r"(h[6]), "r"(SUB));
|
||||
asm volatile("sub.bf16x2 %0, %1, %2;\n" : "=r"(h[7]) : "r"(h[7]), "r"(SUB));
|
||||
#else
|
||||
// 1.0
|
||||
static constexpr uint32_t MUL = 0x3F803F80;
|
||||
// -160
|
||||
static constexpr uint32_t ADD = 0xC320C320;
|
||||
|
||||
asm volatile("fma.rn.bf16x2 %0, %1, %2, %3;\n" : "=r"(h[0]) : "r"(h[0]), "r"(MUL), "r"(ADD));
|
||||
asm volatile("fma.rn.bf16x2 %0, %1, %2, %3;\n" : "=r"(h[1]) : "r"(h[1]), "r"(MUL), "r"(ADD));
|
||||
asm volatile("fma.rn.bf16x2 %0, %1, %2, %3;\n" : "=r"(h[2]) : "r"(h[2]), "r"(MUL), "r"(ADD));
|
||||
asm volatile("fma.rn.bf16x2 %0, %1, %2, %3;\n" : "=r"(h[3]) : "r"(h[3]), "r"(MUL), "r"(ADD));
|
||||
|
||||
asm volatile("fma.rn.bf16x2 %0, %1, %2, %3;\n" : "=r"(h[4]) : "r"(h[4]), "r"(MUL), "r"(ADD));
|
||||
asm volatile("fma.rn.bf16x2 %0, %1, %2, %3;\n" : "=r"(h[5]) : "r"(h[5]), "r"(MUL), "r"(ADD));
|
||||
asm volatile("fma.rn.bf16x2 %0, %1, %2, %3;\n" : "=r"(h[6]) : "r"(h[6]), "r"(MUL), "r"(ADD));
|
||||
asm volatile("fma.rn.bf16x2 %0, %1, %2, %3;\n" : "=r"(h[7]) : "r"(h[7]), "r"(MUL), "r"(ADD));
|
||||
#endif
|
||||
|
||||
return result;
|
||||
}
|
||||
|
||||
CUTLASS_DEVICE
|
||||
result_type operator()(source_type const& s, ScaleComputeT code_scale, ScaleComputeT code_zp)
|
||||
{
|
||||
return convert(s, code_scale, code_zp);
|
||||
}
|
||||
};
|
||||
|
||||
template <typename T, int N>
|
||||
struct FastInterleavedAndBiasedNumericArrayConverter<T, uint2b_t, N>
|
||||
{
|
||||
static_assert(platform::is_same<T, half_t>::value || platform::is_same<T, bfloat16_t>::value,
|
||||
"T must be fp16 or bf16");
|
||||
|
||||
static constexpr int kVecWidth = 16;
|
||||
static_assert(!(N % kVecWidth), "N must be multiple of 16.");
|
||||
|
||||
using result_type = Array<T, N>;
|
||||
using source_type = Array<uint2b_t, N>;
|
||||
using code_type = Array<float, N / kVecWidth>;
|
||||
|
||||
CUTLASS_DEVICE
|
||||
static result_type convert(source_type const& source, code_type const& code_scale, code_type const& code_zp)
|
||||
{
|
||||
using scalar_result_type = typename result_type::Element;
|
||||
using scalar_source_type = typename source_type::Element;
|
||||
FastInterleavedAndBiasedNumericArrayConverter<scalar_result_type, scalar_source_type, kVecWidth>
|
||||
convert_vector_;
|
||||
|
||||
result_type result;
|
||||
using vec_result = Array<scalar_result_type, kVecWidth>;
|
||||
using vec_source = Array<scalar_source_type, kVecWidth>;
|
||||
|
||||
vec_result* result_ptr = reinterpret_cast<vec_result*>(&result);
|
||||
vec_source const* source_ptr = reinterpret_cast<vec_source const*>(&source);
|
||||
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int i = 0; i < N / kVecWidth; ++i)
|
||||
{
|
||||
result_ptr[i] = convert_vector_(source_ptr[i], code_scale[i], code_zp[i]);
|
||||
}
|
||||
|
||||
return result;
|
||||
}
|
||||
|
||||
CUTLASS_DEVICE
|
||||
static result_type convert(source_type const& source, Array<float, N / 4> const& code_scale, Array<float, N / 4> const& code_zp)
|
||||
{
|
||||
using scalar_result_type = typename result_type::Element;
|
||||
using scalar_source_type = typename source_type::Element;
|
||||
using Converter = FastInterleavedAndBiasedNumericArrayConverter<scalar_result_type, scalar_source_type, kVecWidth>;
|
||||
|
||||
result_type result;
|
||||
using vec_result = typename Converter::result_type;
|
||||
using vec_source = typename Converter::source_type;
|
||||
using vec_code = typename Converter::code_type;
|
||||
|
||||
vec_result* result_ptr = reinterpret_cast<vec_result*>(&result);
|
||||
vec_source const* source_ptr = reinterpret_cast<vec_source const*>(&source);
|
||||
vec_code const* code_scale_ptr = reinterpret_cast<vec_code const*>(&code_scale);
|
||||
vec_code const* code_zp_ptr = reinterpret_cast<vec_code const*>(&code_zp);
|
||||
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int i = 0; i < N / kVecWidth; ++i)
|
||||
{
|
||||
result_ptr[i] = Converter::convert(source_ptr[i], code_scale_ptr[i], code_zp_ptr[i]);
|
||||
}
|
||||
|
||||
return result;
|
||||
}
|
||||
|
||||
CUTLASS_DEVICE
|
||||
result_type operator()(source_type const& s, code_type const& code_scale, code_type const& code_zp)
|
||||
{
|
||||
return convert(s, code_scale, code_zp);
|
||||
}
|
||||
};
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
} // namespace cutlass
|
||||
|
@@ -125,13 +125,10 @@ struct WintQuantTraits<ElementT, WintQuantMethod::kWeightOnlyInt2> {
    static constexpr int32_t kNumPackedValues = 4;
    static constexpr int32_t kPackedSize = 16;

    using LocalScaleType = uint4b_t;
    using CodeScaleZpType = float;

    struct Arguments {
        uint8_t *local_scale_ptr;  // quanted 4-bits
        float *code_scale_ptr;
        float *code_zp_ptr;
        const uint8_t *local_scale_ptr;  // quanted 4-bits
        const float *code_scale_ptr;
        const float *code_zp_ptr;
    };

    CUTLASS_DEVICE
@@ -117,7 +117,7 @@ class LeftGELUAndMul {
  CUTLASS_HOST_DEVICE
  FragmentOutput operator()(FragmentAccumulator const &lhs,
                            FragmentAccumulator const &rhs) const {
    // Convert source to internal compute numeric type
    // Convert source to interal compute numeric type
    NumericArrayConverter<ElementCompute, ElementAccumulator, kCount, Round>
        accumulator_to_compute;
@@ -117,7 +117,7 @@ class LeftSiLUAndMul {
  CUTLASS_HOST_DEVICE
  FragmentOutput operator()(FragmentAccumulator const &lhs,
                            FragmentAccumulator const &rhs) const {
    // Convert source to internal compute numeric type
    // Convert source to interal compute numeric type
    NumericArrayConverter<ElementCompute, ElementAccumulator, kCount, Round>
        accumulator_to_compute;
@@ -92,7 +92,7 @@ class DualMmaBase {
                                 Shape::kN / WarpGemm::kN,
                                 Shape::kK / WarpGemm::kK>;

  /// Number of warp-level GEMM operations
  /// Number of warp-level GEMM oeprations
  static int const kWarpGemmIterations =
      (WarpGemm::kK / Operator0::Policy::MmaShape::kK);
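As a quick worked example of the expression above (the shapes are illustrative assumptions, not values taken from this diff): with WarpGemm::kK = 64 and Operator0::Policy::MmaShape::kK = 16, kWarpGemmIterations = 64 / 16 = 4, i.e. each warp issues four warp-level MMA steps to cover one threadblock slice of K.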
@@ -43,6 +43,7 @@
#include "cutlass/trace.h"

#include "cutlass_extensions/gemm/kernel/gemm_moe_problem_visitor.h"
#include "cutlass_extensions/gemm/threadblock/wint2x_tile_dequanter.h"
#include "cutlass_extensions/tile_interleaved_layout.h"

/////////////////////////////////////////////////////////////////////////////////////////////////
@@ -774,54 +775,17 @@ struct Wint2xMoeFCGemm : public MoeFCGemm<Mma_, Epilogue_, ThreadblockSwizzle_,
  template <WintQuantMethod QuantMethod, typename dummy>
  struct KernelRunner<QuantMethod, true, dummy> {
    using WeightQuantTraits = WintQuantTraits<ElementA, QuantMethod>;
    using MmaQuantArguments = typename Mma::QuantParamsAccessor::Arguments;
    using QuantArguments = typename WeightQuantTraits::Arguments;

    CUTLASS_DEVICE
    static MmaQuantArguments prepare_quant_args(
        Params const& params, cutlass::gemm::GemmCoord const& threadblock_offset,
        int64_t problem_idx, const int32_t gemm_k, const int32_t gemm_n, const int thread_idx) {
      // the begin threadblock_offset of scale, which holds the same column id with C, but with no row id
      cutlass::MatrixCoord tb_offset_scale{0, threadblock_offset.n()};
      cutlass::MatrixCoord tb_offset_local_scale{0, threadblock_offset.n() * 2};

      ElementScale* weight_scale_ptr = params.weight_scales + problem_idx * gemm_n;
      typename Mma::QuantParamsAccessor::IteratorSuperScale iterator_super_scale(
          Mma::QuantParamsAccessor::LayoutSuperScale(gemm_n),
          weight_scale_ptr,
          {1, gemm_n},
          thread_idx,
          tb_offset_scale);

      int local_scale_pointer_offset = ((ThreadblockShape::kK + 127) / 128) * (gemm_n * 2);
      int64_t offset_in_bytes = problem_idx * gemm_k * gemm_n / 128;
      uint4b_t *local_scale_ptr = reinterpret_cast<uint4b_t *>(params.local_scale + offset_in_bytes);

      typename Mma::QuantParamsAccessor::IteratorLocalScale iterator_local_scale(
          Mma::QuantParamsAccessor::LayoutLocalScale(gemm_n * 2),
          local_scale_ptr,
          {(gemm_k + 127) / 128, gemm_n * 2},
          thread_idx,
          tb_offset_local_scale);

      float* code_scale_ptr = params.code_scale + problem_idx * gemm_n;
      typename Mma::QuantParamsAccessor::IteratorCodeScaleZp iterator_code_scale(
          Mma::QuantParamsAccessor::LayoutCodeScaleZp(gemm_n),
          code_scale_ptr,
          {1, gemm_n},
          thread_idx,
          tb_offset_scale);

      float* code_zp_ptr = params.code_zp + problem_idx * gemm_n;
      typename Mma::QuantParamsAccessor::IteratorCodeScaleZp iterator_code_zp(
          Mma::QuantParamsAccessor::LayoutCodeScaleZp(gemm_n),
          code_zp_ptr,
          {1, gemm_n},
          thread_idx,
          tb_offset_scale);

      MmaQuantArguments mma_quant_args(
          iterator_super_scale, iterator_local_scale, iterator_code_scale, iterator_code_zp, local_scale_pointer_offset);
      return mma_quant_args;
    static QuantArguments get_quant_args(Params const& params, int32_t problem_idx, const int64_t gemm_k, const int64_t gemm_n) {
      QuantArguments quant_args;
      if constexpr (QuantMethod == WintQuantMethod::kWeightOnlyInt2) {
        quant_args.local_scale_ptr = params.local_scale + problem_idx * gemm_k * gemm_n / 128;
        quant_args.code_scale_ptr = params.code_scale + problem_idx * gemm_n;
        quant_args.code_zp_ptr = params.code_zp + problem_idx * gemm_n;
      }
      return quant_args;
    }

    CUTLASS_DEVICE
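To make the per-expert indexing in get_quant_args()/prepare_quant_args() easier to follow, here is a minimal host-side sketch of the same offset arithmetic. The struct and function names are hypothetical; only the formulas mirror the diff above.

// Hypothetical helper (not part of this diff): per-expert offsets into the
// packed wint2 quantization buffers, mirroring get_quant_args().
#include <cstdint>

struct Wint2ExpertOffsets {
    int64_t local_scale_bytes;  // matches params.local_scale + problem_idx * gemm_k * gemm_n / 128
    int64_t code_scale_elems;   // matches params.code_scale  + problem_idx * gemm_n
    int64_t code_zp_elems;      // matches params.code_zp     + problem_idx * gemm_n
};

inline Wint2ExpertOffsets wint2_expert_offsets(int64_t problem_idx, int64_t gemm_k, int64_t gemm_n) {
    return {problem_idx * gemm_k * gemm_n / 128,
            problem_idx * gemm_n,
            problem_idx * gemm_n};
}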
@@ -850,6 +814,9 @@ struct Wint2xMoeFCGemm : public MoeFCGemm<Mma_, Epilogue_, ThreadblockSwizzle_,
                      kInterleave >= 1,
                  "B must be row major/col major OR col major interleaved.");

    // LayoutB should be RowMajor
    using TileDequanterB = cutlass::gemm::threadblock::TileDequanter<ElementA, ElementScale, ThreadblockShape::kK, ThreadblockShape::kN, kStages, kThreadCount, QuantMethod>;

    //
    // Problem visitor.
    //
@@ -876,6 +843,12 @@ struct Wint2xMoeFCGemm : public MoeFCGemm<Mma_, Epilogue_, ThreadblockSwizzle_,
          int(cta_idx % grid_shape.n()) * Mma::Shape::kN,  // NOLINT
          0);

      // begin address offset for weight_scale.
      ElementScale* weight_scale_ptr =
          params.weight_scales ? params.weight_scales + problem_idx * problem_size.n() : nullptr;
      // the begin threadblock_offset of scale, which holds the same column id with C, but with no row id
      cutlass::MatrixCoord tb_offset_scale{0, threadblock_offset.n()};

      // Load element pointers. Exchange pointers and strides if working on
      // the transpose
      int64_t rows_to_jump = 0;
@@ -893,20 +866,42 @@ struct Wint2xMoeFCGemm : public MoeFCGemm<Mma_, Epilogue_, ThreadblockSwizzle_,

      // Compute initial location in logical coordinates
      // the begin threadblock_offset of A, which holds the same row id with C
      cutlass::MatrixCoord tb_offset_A{threadblock_offset.m(), 0};
      cutlass::MatrixCoord tb_offset_A{
          threadblock_offset.m(),
          0,
      };

      // begin address offset for B for current problem_idx, totally num_experts problems
      char* byte_ptr_B = ((char*)params.ptr_B) +  // NOLINT
                         problem_idx * bytes_per_expert_matrix;  // NOLINT
      ElementB* ptr_B = reinterpret_cast<ElementB*>(byte_ptr_B);

      typename LayoutB::LongIndex ldm_B =
          platform::is_same<layout::RowMajor, LayoutB>::value
              ? gemm_n
              : gemm_k * kInterleave;
      typename LayoutB::LongIndex ldm_B_shared = TileDequanterB::kColumns;

      // the begin threadblock_offset of B, which holds the same column id with C
      cutlass::MatrixCoord tb_offset_B{0, threadblock_offset.n() / kInterleave};
      cutlass::MatrixCoord tb_offset_B{0,
                                       threadblock_offset.n() / kInterleave};

      cutlass::MatrixCoord extent_B{problem_size.k() * kInterleave, problem_size.n() / kInterleave};
      cutlass::MatrixCoord extent_B_shared{TileDequanterB::kRows, TileDequanterB::kColumns};

      MmaElementB* smem_unzip_B_ptr = nullptr;
      if constexpr (QuantMethod == WintQuantMethod::kWeightOnlyInt2) {
        smem_unzip_B_ptr = shared_storage.main_loop.operand_unzip_B_ptr();
      }
      QuantArguments quant_args = get_quant_args(params, problem_idx, gemm_k, gemm_n);
      TileDequanterB tile_dequanter_B(smem_unzip_B_ptr,
                                      byte_ptr_B,
                                      ldm_B,
                                      extent_B,
                                      tb_offset_B,
                                      weight_scale_ptr,
                                      tb_offset_scale,
                                      quant_args);
      MmaElementB* ptr_B = tile_dequanter_B.GetOutPtr();

      // Compute position within threadblock
      int thread_idx = threadIdx.x;
@@ -919,21 +914,20 @@ struct Wint2xMoeFCGemm : public MoeFCGemm<Mma_, Epilogue_, ThreadblockSwizzle_,
          tb_offset_A);

      typename Mma::IteratorB iterator_B(
          LayoutB(ldm_B),
          LayoutB(TileDequanterB::kUseSharedMemory ? ldm_B_shared : ldm_B),
          ptr_B,
          extent_B,
          TileDequanterB::kUseSharedMemory ? extent_B_shared : extent_B,
          thread_idx,
          tb_offset_B);

      MmaQuantArguments mma_quant_args = prepare_quant_args(
          params, threadblock_offset, problem_idx, gemm_k, gemm_n, thread_idx);
          TileDequanterB::kUseSharedMemory ? cutlass::make_Coord(0, 0) : tb_offset_B);

      typename Mma::FragmentC accumulators;

      accumulators.clear();

      // Broadcast the warp_id computed by lane 0 to ensure dependent code
      // is compiled as warp-uniform.
      int warp_idx = __shfl_sync(0xffffffff, threadIdx.x / 32, 0);

      int lane_idx = threadIdx.x % 32;

      //
@@ -956,7 +950,7 @@ struct Wint2xMoeFCGemm : public MoeFCGemm<Mma_, Epilogue_, ThreadblockSwizzle_,
          accumulators,
          iterator_A,
          iterator_B,
          mma_quant_args,
          tile_dequanter_B,
          accumulators);

      //
@@ -205,7 +205,7 @@ void generic_moe_gemm_kernelLauncher(const T* A,
      threadblock_count,
      epilogue_op,
      reinterpret_cast<const ElementType*>(A),
      reinterpret_cast<const CutlassMmaKernelType*>(B),
      reinterpret_cast<const CutlassMmaWeightType*>(B),
      reinterpret_cast<const ElementType*>(weight_scales),
      reinterpret_cast<const ElementType*>(biases),
      reinterpret_cast<ElementType*>(C),
@@ -219,7 +219,7 @@ class EpilogueVisitorPerRowPerColNf4 {
      iterator_C_.clear_mask();
    }
    // NOTE(wangbojun) Currently, this kernel don't hanve implantention for
    // adding elementwise beta, we keep this here for future usage beta_ =
    // adding elementwise beta, we keep this here for future useage beta_ =
    // (params.elementwise.beta_ptr ? *params.elementwise.beta_ptr :
    // params.elementwise.beta); if (beta_ == ElementAccumulator()) {
    // iterator_C_.clear_mask();
@@ -176,7 +176,7 @@ struct Nf4DefaultIteratorsTensorOp<cutlass::bfloat16_t,
///
/// Satisfies: ReadableTileIterator
///
template <typename ThreadMap_  ///< Thread map (concept: OutputTileThreadMap)
template <typename ThreadMap_  ///< Thread map (conept: OutputTileThreadMap)
          >
class SharedLoadIteratorMixed<ThreadMap_, int32_t, 32, 16, 8, 8> {
 public:
Some files were not shown because too many files have changed in this diff.