mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2025-11-01 20:32:52 +08:00
Compare commits
98 Commits
release/2.
...
release/2.
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
f6f9c12b87 | ||
|
|
8d2aaf3ba4 | ||
|
|
c13e6ae481 | ||
|
|
f660188a85 | ||
|
|
4178c110d2 | ||
|
|
adeee84dd6 | ||
|
|
e0946ae128 | ||
|
|
836ba294fc | ||
|
|
b489943261 | ||
|
|
e42dc8c694 | ||
|
|
63a03ee152 | ||
|
|
9cc2c99539 | ||
|
|
31e32b5821 | ||
|
|
aebe12a58d | ||
|
|
8fdb950e9f | ||
|
|
a460462d2a | ||
|
|
cb8d87b945 | ||
|
|
de4feff147 | ||
|
|
f38b174a75 | ||
|
|
6b47773bd6 | ||
|
|
0358329946 | ||
|
|
01f6934162 | ||
|
|
7bdc6f41e5 | ||
|
|
bba279cf38 | ||
|
|
4f460db556 | ||
|
|
74d7b9151d | ||
|
|
0fa28b1068 | ||
|
|
cffde70949 | ||
|
|
7f9a9b37f3 | ||
|
|
b41988f4bc | ||
|
|
7ccbcc5a62 | ||
|
|
fbb4e0f8d1 | ||
|
|
4e8ba62241 | ||
|
|
7e3148ed81 | ||
|
|
4f8ff478b3 | ||
|
|
c4098d56a0 | ||
|
|
a6b161b007 | ||
|
|
7272afe3dc | ||
|
|
dfc94371ee | ||
|
|
35b8362804 | ||
|
|
d43c2f2577 | ||
|
|
14df2c59da | ||
|
|
934071578a | ||
|
|
36a58f487c | ||
|
|
d40a1046de | ||
|
|
fa2369271d | ||
|
|
8903f937f9 | ||
|
|
1023a67765 | ||
|
|
d43549953c | ||
|
|
c7c1627456 | ||
|
|
d6bf6de5e6 | ||
|
|
38e734e183 | ||
|
|
051e4a881c | ||
|
|
b2bb37d7c0 | ||
|
|
c6e2a37a95 | ||
|
|
8d77c1cb51 | ||
|
|
41cd3e24c9 | ||
|
|
11b18e5ef0 | ||
|
|
e2c764fd5a | ||
|
|
2d975e16b0 | ||
|
|
8915c8411d | ||
|
|
77c1bd0813 | ||
|
|
473cde779f | ||
|
|
335d1c8e8f | ||
|
|
173e4df982 | ||
|
|
199f88ce1e | ||
|
|
55ebe855c0 | ||
|
|
deb7ad205f | ||
|
|
e9f72df918 | ||
|
|
8567ada09e | ||
|
|
afcde19277 | ||
|
|
d40d3a5a4f | ||
|
|
b8d0f1c081 | ||
|
|
8550e19008 | ||
|
|
a0c03510c0 | ||
|
|
fb1e0d6a87 | ||
|
|
fbf0e9d2aa | ||
|
|
8c0e7d6fe9 | ||
|
|
b56b015d85 | ||
|
|
1432e336d7 | ||
|
|
9213a58a06 | ||
|
|
87ef0f5d30 | ||
|
|
abcd2148c0 | ||
|
|
05b6591c23 | ||
|
|
42402c80e9 | ||
|
|
1968c65849 | ||
|
|
37cb37b7f2 | ||
|
|
f975f7de2f | ||
|
|
174510180a | ||
|
|
5cda326ba2 | ||
|
|
a6c8f17431 | ||
|
|
cd09384a14 | ||
|
|
0f42771a84 | ||
|
|
d1d063e4af | ||
|
|
a86b35ab49 | ||
|
|
0cdbc950b5 | ||
|
|
2b0a745d57 | ||
|
|
1953c7c759 |
@@ -16,7 +16,7 @@
|
|||||||
---
|
---
|
||||||
Language: Cpp
|
Language: Cpp
|
||||||
BasedOnStyle: Google
|
BasedOnStyle: Google
|
||||||
IndentWidth: 2
|
IndentWidth: 4
|
||||||
TabWidth: 2
|
TabWidth: 2
|
||||||
ContinuationIndentWidth: 4
|
ContinuationIndentWidth: 4
|
||||||
AccessModifierOffset: -1 # The private/protected/public has no indent in class
|
AccessModifierOffset: -1 # The private/protected/public has no indent in class
|
||||||
|
|||||||
30
.github/actions/rerun-workflow/action.yml
vendored
30
.github/actions/rerun-workflow/action.yml
vendored
@@ -1,30 +0,0 @@
|
|||||||
name: 'Rerun Workflow'
|
|
||||||
description: 'Re-run GitHub Actions workflow for a given Pull Request'
|
|
||||||
inputs:
|
|
||||||
GITHUB_TOKEN:
|
|
||||||
description: 'GitHub token with repo scope'
|
|
||||||
required: true
|
|
||||||
OWNER:
|
|
||||||
description: 'Repository owner'
|
|
||||||
required: true
|
|
||||||
REPO:
|
|
||||||
description: 'Repository name'
|
|
||||||
required: true
|
|
||||||
PR_ID:
|
|
||||||
description: 'Pull Request ID'
|
|
||||||
required: true
|
|
||||||
JOB_NAME:
|
|
||||||
description: 'Job name to rerun'
|
|
||||||
required: true
|
|
||||||
|
|
||||||
runs:
|
|
||||||
using: 'composite'
|
|
||||||
steps:
|
|
||||||
- run: bash ./.github/actions/rerun-workflow/rerun.sh
|
|
||||||
shell: bash
|
|
||||||
env:
|
|
||||||
GITHUB_TOKEN: ${{ inputs.GITHUB_TOKEN }}
|
|
||||||
OWNER: ${{ inputs.OWNER }}
|
|
||||||
REPO: ${{ inputs.REPO }}
|
|
||||||
PR_ID: ${{ inputs.PR_ID }}
|
|
||||||
JOB_NAME: ${{ inputs.JOB_NAME }}
|
|
||||||
77
.github/actions/rerun-workflow/rerun.sh
vendored
77
.github/actions/rerun-workflow/rerun.sh
vendored
@@ -1,77 +0,0 @@
|
|||||||
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
|
|
||||||
#
|
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
# you may not use this file except in compliance with the License.
|
|
||||||
# You may obtain a copy of the License at
|
|
||||||
#
|
|
||||||
# http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
#
|
|
||||||
# Unless required by applicable law or agreed to in writing, software
|
|
||||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
# See the License for the specific language governing permissions and
|
|
||||||
# limitations under the License.
|
|
||||||
|
|
||||||
set -e
|
|
||||||
|
|
||||||
COMMIT_SHA=$(curl -s -H "Authorization: token $GITHUB_TOKEN" \
|
|
||||||
"https://api.github.com/repos/$OWNER/$REPO/pulls/$PR_ID" | jq -r '.head.sha')
|
|
||||||
|
|
||||||
echo "Commit SHA: $COMMIT_SHA"
|
|
||||||
|
|
||||||
response=$(curl -s -H "Authorization: token $GITHUB_TOKEN" \
|
|
||||||
"https://api.github.com/repos/$OWNER/$REPO/actions/runs?head_sha=$COMMIT_SHA&per_page=100")
|
|
||||||
|
|
||||||
echo "Response: $response"
|
|
||||||
|
|
||||||
run_ids=$(echo "$response" | jq -r '.workflow_runs[].id')
|
|
||||||
|
|
||||||
if [ -n "$run_ids" ]; then
|
|
||||||
echo "Found run_ids for commit $COMMIT_SHA: $run_ids"
|
|
||||||
|
|
||||||
for run_id in $run_ids; do
|
|
||||||
if [ "$JOB_NAME" = "all-failed" ]; then
|
|
||||||
echo "Rerunning all failed jobs for run_id: $run_id"
|
|
||||||
|
|
||||||
rerun_response=$(curl -X POST -s -w "%{http_code}" -o /dev/null \
|
|
||||||
-H "Accept: application/vnd.github.v3+json" \
|
|
||||||
-H "Authorization: Bearer $GITHUB_TOKEN" \
|
|
||||||
"https://api.github.com/repos/$OWNER/$REPO/actions/runs/$run_id/rerun-failed-jobs")
|
|
||||||
if [ "$rerun_response" -eq 201 ]; then
|
|
||||||
echo "Successfully requested rerun for all blocked jobs in run_id: $run_id"
|
|
||||||
else
|
|
||||||
echo "Failed to request rerun for run_id: $run_id with status code $rerun_response"
|
|
||||||
fi
|
|
||||||
|
|
||||||
else
|
|
||||||
jobs_response=$(curl -s -H "Authorization: token $GITHUB_TOKEN" \
|
|
||||||
"https://api.github.com/repos/$OWNER/$REPO/actions/runs/$run_id/jobs")
|
|
||||||
|
|
||||||
echo "Jobs Response for run_id $run_id: $jobs_response"
|
|
||||||
|
|
||||||
# if [[ "$JOB_NAME" == *"bypass"* ]]; then
|
|
||||||
block_jobs=$(echo "$jobs_response" | jq -r --arg job_name "$JOB_NAME" \
|
|
||||||
'.jobs[] | select(.name == $job_name) | .id')
|
|
||||||
# else
|
|
||||||
# block_jobs=$(echo "$jobs_response" | jq -r --arg job_name "$JOB_NAME" \
|
|
||||||
# '.jobs[] | select(.name == $job_name and .conclusion != "success") | .id')
|
|
||||||
# fi
|
|
||||||
|
|
||||||
if [ -n "$block_jobs" ]; then
|
|
||||||
echo "Found block jobs for run_id $run_id: $block_jobs"
|
|
||||||
|
|
||||||
for job_id in $block_jobs; do
|
|
||||||
echo "Rerunning job_id: $job_id"
|
|
||||||
curl -X POST -H "Accept: application/vnd.github.v3+json" \
|
|
||||||
-H "Authorization: token $GITHUB_TOKEN" \
|
|
||||||
"https://api.github.com/repos/$OWNER/$REPO/actions/jobs/$job_id/rerun"
|
|
||||||
done
|
|
||||||
else
|
|
||||||
echo "No block jobs found for run_id $run_id with name $JOB_NAME."
|
|
||||||
fi
|
|
||||||
fi
|
|
||||||
done
|
|
||||||
else
|
|
||||||
echo "No matching workflow runs found for commit $COMMIT_SHA."
|
|
||||||
exit 1
|
|
||||||
fi
|
|
||||||
30
.github/pull_request_template.md
vendored
30
.github/pull_request_template.md
vendored
@@ -1,30 +0,0 @@
|
|||||||
<!-- TemplateReference: https://github.com/PaddlePaddle/FastDeploy/blob/develop/.github/pull_request_template.md -->
|
|
||||||
|
|
||||||
<!-- Thank you for your contribution! Please follow these guidelines to enhance your pull request. If anything is unclear, submit your PR and reach out to maintainers for assistance. -->
|
|
||||||
|
|
||||||
## Motivation
|
|
||||||
|
|
||||||
<!-- Describe the purpose and goals of this pull request. -->
|
|
||||||
|
|
||||||
## Modifications
|
|
||||||
|
|
||||||
<!-- Detail the changes made in this pull request. -->
|
|
||||||
|
|
||||||
## Usage or Command
|
|
||||||
|
|
||||||
<!-- You should provide the usage if this pr is about the new function. -->
|
|
||||||
<!-- You should provide the command to run if this pr is about the performance optimization or fixing bug. -->
|
|
||||||
|
|
||||||
## Accuracy Tests
|
|
||||||
|
|
||||||
<!-- If this pull request affects model outputs (e.g., changes to the kernel or model forward code), provide accuracy test results. -->
|
|
||||||
|
|
||||||
## Checklist
|
|
||||||
|
|
||||||
- [ ] Add at least a tag in the PR title.
|
|
||||||
- Tag list: [`[FDConfig]`,`[APIServer]`,`[Engine]`, `[Scheduler]`, `[PD Disaggregation]`, `[Executor]`, `[Graph Optimization]`, `[Speculative Decoding]`, `[RL]`, `[Models]`, `[Quantization]`, `[Loader]`, `[OP]`, `[KVCache]`, `[DataProcessor]`, `[BugFix]`, `[Docs]`, `[CI]`, `[Optimization]`, `[Feature]`, `[Benchmark]`, `[Others]`, `[XPU]`, `[HPU]`, `[GCU]`, `[DCU]`, `[Iluvatar]`, `[Metax]`]
|
|
||||||
- You can add new tags based on the PR content, but the semantics must be clear.
|
|
||||||
- [ ] Format your code, run `pre-commit` before commit.
|
|
||||||
- [ ] Add unit tests. Please write the reason in this PR if no unit tests.
|
|
||||||
- [ ] Provide accuracy results.
|
|
||||||
- [ ] If the current PR is submitting to the `release` branch, make sure the PR has been submitted to the `develop` branch, then cherry-pick it to the `release` branch with the `[Cherry-Pick]` PR tag.
|
|
||||||
10
.github/workflows/_accuracy_test.yml
vendored
10
.github/workflows/_accuracy_test.yml
vendored
@@ -44,7 +44,7 @@ jobs:
|
|||||||
FULL_REPO="${{ github.repository }}"
|
FULL_REPO="${{ github.repository }}"
|
||||||
REPO_NAME="${FULL_REPO##*/}"
|
REPO_NAME="${FULL_REPO##*/}"
|
||||||
BASE_BRANCH="${{ github.base_ref }}"
|
BASE_BRANCH="${{ github.base_ref }}"
|
||||||
docker pull ${docker_image}
|
|
||||||
# Clean the repository directory before starting
|
# Clean the repository directory before starting
|
||||||
docker run --rm --net=host -v $(pwd):/workspace -w /workspace \
|
docker run --rm --net=host -v $(pwd):/workspace -w /workspace \
|
||||||
-e "REPO_NAME=${REPO_NAME}" \
|
-e "REPO_NAME=${REPO_NAME}" \
|
||||||
@@ -55,7 +55,7 @@ jobs:
|
|||||||
fi
|
fi
|
||||||
'
|
'
|
||||||
|
|
||||||
wget -q --no-proxy ${fd_archive_url}
|
wget -q ${fd_archive_url}
|
||||||
tar -xf FastDeploy.tar.gz
|
tar -xf FastDeploy.tar.gz
|
||||||
rm -rf FastDeploy.tar.gz
|
rm -rf FastDeploy.tar.gz
|
||||||
cd FastDeploy
|
cd FastDeploy
|
||||||
@@ -143,7 +143,7 @@ jobs:
|
|||||||
-v "${CACHE_DIR}/ConfigDir:/root/.config" \
|
-v "${CACHE_DIR}/ConfigDir:/root/.config" \
|
||||||
-e TZ="Asia/Shanghai" \
|
-e TZ="Asia/Shanghai" \
|
||||||
--gpus '"device='"${DEVICES}"'"' ${docker_image} /bin/bash -xc '
|
--gpus '"device='"${DEVICES}"'"' ${docker_image} /bin/bash -xc '
|
||||||
python -m pip install --pre paddlepaddle-gpu -i https://www.paddlepaddle.org.cn/packages/nightly/cu126/
|
python -m pip install paddlepaddle-gpu==3.2.0 -i https://www.paddlepaddle.org.cn/packages/stable/cu126/
|
||||||
|
|
||||||
pip config set global.index-url https://mirrors.tuna.tsinghua.edu.cn/pypi/web/simple
|
pip config set global.index-url https://mirrors.tuna.tsinghua.edu.cn/pypi/web/simple
|
||||||
|
|
||||||
@@ -155,13 +155,11 @@ jobs:
|
|||||||
./llm-deploy-linux-amd64 -python python3.10 \
|
./llm-deploy-linux-amd64 -python python3.10 \
|
||||||
-model_name ERNIE-4.5-0.3B-Paddle \
|
-model_name ERNIE-4.5-0.3B-Paddle \
|
||||||
-model_path /MODELDATA \
|
-model_path /MODELDATA \
|
||||||
--skip install,model
|
--skip install
|
||||||
|
|
||||||
git config --global --add safe.directory /workspace/FastDeploy
|
git config --global --add safe.directory /workspace/FastDeploy
|
||||||
cd FastDeploy
|
cd FastDeploy
|
||||||
pushd tests/ce/deploy
|
pushd tests/ce/deploy
|
||||||
ps -ef | grep "${FD_CACHE_QUEUE_PORT}" | grep -v grep | awk "{print \$2}" | xargs -r kill -9
|
|
||||||
ps -ef | grep "${FD_ENGINE_QUEUE_PORT}" | grep -v grep | awk "{print \$2}" | xargs -r kill -9
|
|
||||||
python3.10 deploy.py > dd.log 2>&1 &
|
python3.10 deploy.py > dd.log 2>&1 &
|
||||||
sleep 3
|
sleep 3
|
||||||
curl -X POST http://0.0.0.0:${FLASK_PORT}/start \
|
curl -X POST http://0.0.0.0:${FLASK_PORT}/start \
|
||||||
|
|||||||
10
.github/workflows/_base_test.yml
vendored
10
.github/workflows/_base_test.yml
vendored
@@ -44,7 +44,7 @@ jobs:
|
|||||||
FULL_REPO="${{ github.repository }}"
|
FULL_REPO="${{ github.repository }}"
|
||||||
REPO_NAME="${FULL_REPO##*/}"
|
REPO_NAME="${FULL_REPO##*/}"
|
||||||
BASE_BRANCH="${{ github.base_ref }}"
|
BASE_BRANCH="${{ github.base_ref }}"
|
||||||
docker pull ${docker_image}
|
|
||||||
# Clean the repository directory before starting
|
# Clean the repository directory before starting
|
||||||
docker run --rm --net=host -v $(pwd):/workspace -w /workspace \
|
docker run --rm --net=host -v $(pwd):/workspace -w /workspace \
|
||||||
-e "REPO_NAME=${REPO_NAME}" \
|
-e "REPO_NAME=${REPO_NAME}" \
|
||||||
@@ -55,7 +55,7 @@ jobs:
|
|||||||
fi
|
fi
|
||||||
'
|
'
|
||||||
|
|
||||||
wget -q --no-proxy ${fd_archive_url}
|
wget -q ${fd_archive_url}
|
||||||
tar -xf FastDeploy.tar.gz
|
tar -xf FastDeploy.tar.gz
|
||||||
rm -rf FastDeploy.tar.gz
|
rm -rf FastDeploy.tar.gz
|
||||||
cd FastDeploy
|
cd FastDeploy
|
||||||
@@ -143,7 +143,7 @@ jobs:
|
|||||||
-v "${CACHE_DIR}/ConfigDir:/root/.config" \
|
-v "${CACHE_DIR}/ConfigDir:/root/.config" \
|
||||||
-e TZ="Asia/Shanghai" \
|
-e TZ="Asia/Shanghai" \
|
||||||
--gpus '"device='"${DEVICES}"'"' ${docker_image} /bin/bash -xc '
|
--gpus '"device='"${DEVICES}"'"' ${docker_image} /bin/bash -xc '
|
||||||
python -m pip install --pre paddlepaddle-gpu -i https://www.paddlepaddle.org.cn/packages/nightly/cu126/
|
python -m pip install paddlepaddle-gpu==3.2.0 -i https://www.paddlepaddle.org.cn/packages/stable/cu126/
|
||||||
|
|
||||||
pip config set global.index-url https://mirrors.tuna.tsinghua.edu.cn/pypi/web/simple
|
pip config set global.index-url https://mirrors.tuna.tsinghua.edu.cn/pypi/web/simple
|
||||||
|
|
||||||
@@ -155,13 +155,11 @@ jobs:
|
|||||||
./llm-deploy-linux-amd64 -python python3.10 \
|
./llm-deploy-linux-amd64 -python python3.10 \
|
||||||
-model_name ERNIE-4.5-0.3B-Paddle \
|
-model_name ERNIE-4.5-0.3B-Paddle \
|
||||||
-model_path /MODELDATA \
|
-model_path /MODELDATA \
|
||||||
--skip install,model
|
--skip install
|
||||||
|
|
||||||
git config --global --add safe.directory /workspace/FastDeploy
|
git config --global --add safe.directory /workspace/FastDeploy
|
||||||
cd FastDeploy
|
cd FastDeploy
|
||||||
pushd tests/ce/deploy
|
pushd tests/ce/deploy
|
||||||
ps -ef | grep "${FD_CACHE_QUEUE_PORT}" | grep -v grep | awk "{print \$2}" | xargs -r kill -9
|
|
||||||
ps -ef | grep "${FD_ENGINE_QUEUE_PORT}" | grep -v grep | awk "{print \$2}" | xargs -r kill -9
|
|
||||||
python3.10 deploy.py > dd.log 2>&1 &
|
python3.10 deploy.py > dd.log 2>&1 &
|
||||||
sleep 3
|
sleep 3
|
||||||
curl -X POST http://0.0.0.0:${FLASK_PORT}/start \
|
curl -X POST http://0.0.0.0:${FLASK_PORT}/start \
|
||||||
|
|||||||
14
.github/workflows/_build_linux.yml
vendored
14
.github/workflows/_build_linux.yml
vendored
@@ -55,7 +55,7 @@ on:
|
|||||||
jobs:
|
jobs:
|
||||||
fd-build:
|
fd-build:
|
||||||
runs-on: [self-hosted, GPU-Build]
|
runs-on: [self-hosted, GPU-Build]
|
||||||
timeout-minutes: 360
|
timeout-minutes: 240
|
||||||
outputs:
|
outputs:
|
||||||
wheel_path: ${{ steps.set_output.outputs.wheel_path }}
|
wheel_path: ${{ steps.set_output.outputs.wheel_path }}
|
||||||
steps:
|
steps:
|
||||||
@@ -82,7 +82,7 @@ jobs:
|
|||||||
fi
|
fi
|
||||||
'
|
'
|
||||||
|
|
||||||
wget -q --no-proxy ${fd_archive_url}
|
wget -q ${fd_archive_url}
|
||||||
tar -xf FastDeploy.tar.gz
|
tar -xf FastDeploy.tar.gz
|
||||||
rm -rf FastDeploy.tar.gz
|
rm -rf FastDeploy.tar.gz
|
||||||
cd FastDeploy
|
cd FastDeploy
|
||||||
@@ -106,12 +106,7 @@ jobs:
|
|||||||
CARD_ID=$(echo "${runner_name}" | awk -F'-' '{print $NF}')
|
CARD_ID=$(echo "${runner_name}" | awk -F'-' '{print $NF}')
|
||||||
gpu_id=$(echo "$CARD_ID" | fold -w1 | paste -sd,)
|
gpu_id=$(echo "$CARD_ID" | fold -w1 | paste -sd,)
|
||||||
|
|
||||||
IFS='/' read -ra parts <<< "${GITHUB_WORKSPACE}"
|
CACHE_DIR="${CACHE_DIR:-$(dirname "$(dirname "${{ github.workspace }}")")}"
|
||||||
len=${#parts[@]}
|
|
||||||
CCACHE_DEFAULT_DIR="/$(IFS=/; echo "${parts[*]:1:$((len-5))}")"
|
|
||||||
echo "$CCACHE_DEFAULT_DIR"
|
|
||||||
|
|
||||||
CACHE_DIR="${CACHE_DIR:-$CCACHE_DEFAULT_DIR}"
|
|
||||||
echo "CACHE_DIR is set to ${CACHE_DIR}"
|
echo "CACHE_DIR is set to ${CACHE_DIR}"
|
||||||
if [ ! -f "${CACHE_DIR}/gitconfig" ]; then
|
if [ ! -f "${CACHE_DIR}/gitconfig" ]; then
|
||||||
touch "${CACHE_DIR}/gitconfig"
|
touch "${CACHE_DIR}/gitconfig"
|
||||||
@@ -132,7 +127,6 @@ jobs:
|
|||||||
-e "PADDLEVERSION=${PADDLEVERSION}" \
|
-e "PADDLEVERSION=${PADDLEVERSION}" \
|
||||||
-e "PADDLE_WHL_URL=${PADDLE_WHL_URL}" \
|
-e "PADDLE_WHL_URL=${PADDLE_WHL_URL}" \
|
||||||
-e "BRANCH_REF=${BRANCH_REF}" \
|
-e "BRANCH_REF=${BRANCH_REF}" \
|
||||||
-e "CCACHE_MAXSIZE=50G" \
|
|
||||||
--gpus "\"device=${gpu_id}\"" ${docker_image} /bin/bash -c '
|
--gpus "\"device=${gpu_id}\"" ${docker_image} /bin/bash -c '
|
||||||
if [[ -n "${FD_VERSION}" ]]; then
|
if [[ -n "${FD_VERSION}" ]]; then
|
||||||
export FASTDEPLOY_VERSION=${FD_VERSION}
|
export FASTDEPLOY_VERSION=${FD_VERSION}
|
||||||
@@ -155,7 +149,7 @@ jobs:
|
|||||||
elif [[ "${PADDLEVERSION}" != "" ]];then
|
elif [[ "${PADDLEVERSION}" != "" ]];then
|
||||||
python -m pip install paddlepaddle-gpu==${PADDLEVERSION} -i https://www.paddlepaddle.org.cn/packages/stable/cu126/
|
python -m pip install paddlepaddle-gpu==${PADDLEVERSION} -i https://www.paddlepaddle.org.cn/packages/stable/cu126/
|
||||||
else
|
else
|
||||||
python -m pip install --pre paddlepaddle-gpu -i https://www.paddlepaddle.org.cn/packages/nightly/cu126/
|
python -m pip install paddlepaddle-gpu==3.2.0 -i https://www.paddlepaddle.org.cn/packages/stable/cu126/
|
||||||
fi
|
fi
|
||||||
|
|
||||||
pip config set global.index-url https://mirrors.tuna.tsinghua.edu.cn/pypi/web/simple
|
pip config set global.index-url https://mirrors.tuna.tsinghua.edu.cn/pypi/web/simple
|
||||||
|
|||||||
73
.github/workflows/_ci_image_build.yml
vendored
73
.github/workflows/_ci_image_build.yml
vendored
@@ -1,73 +0,0 @@
|
|||||||
name: Docker Build
|
|
||||||
description: "FastDeploy CI Image Build"
|
|
||||||
|
|
||||||
on:
|
|
||||||
workflow_call:
|
|
||||||
inputs:
|
|
||||||
CI_DOCKER_IMAGE_NAME:
|
|
||||||
description: "Build Images"
|
|
||||||
required: true
|
|
||||||
type: string
|
|
||||||
default: "ccr-2vdh3abv-pub.cnc.bj.baidubce.com/paddlepaddle/paddleqa:cuda126-py310"
|
|
||||||
FASTDEPLOY_ARCHIVE_URL:
|
|
||||||
description: "URL of the compressed FastDeploy code archive."
|
|
||||||
required: true
|
|
||||||
type: string
|
|
||||||
DOCKER_IMAGE_NAME:
|
|
||||||
description: "Build Images"
|
|
||||||
required: false
|
|
||||||
type: string
|
|
||||||
default: "ccr-2vdh3abv-pub.cnc.bj.baidubce.com/paddlepaddle/paddleqa:fastdeploy-ciuse-cuda126-dailyupdate"
|
|
||||||
outputs:
|
|
||||||
docker_name_precheck:
|
|
||||||
description: "Output path of the generated wheel"
|
|
||||||
value: ${{ jobs.docker_build.outputs.docker_name_precheck }}
|
|
||||||
|
|
||||||
jobs:
|
|
||||||
docker_build:
|
|
||||||
runs-on: [self-hosted, Docker-Build]
|
|
||||||
outputs:
|
|
||||||
docker_name_precheck: ${{ steps.docker_build.outputs.docker_name_precheck }}
|
|
||||||
steps:
|
|
||||||
- name: Docker Build
|
|
||||||
id: docker_build
|
|
||||||
shell: bash
|
|
||||||
env:
|
|
||||||
docker_image_name: ${{ inputs.CI_DOCKER_IMAGE_NAME }}
|
|
||||||
docker_image: ${{ inputs.DOCKER_IMAGE_NAME }}
|
|
||||||
fd_archive_url: ${{ inputs.FASTDEPLOY_ARCHIVE_URL }}
|
|
||||||
run: |
|
|
||||||
set -x
|
|
||||||
REPO="https://github.com/${{ github.repository }}.git"
|
|
||||||
FULL_REPO="${{ github.repository }}"
|
|
||||||
REPO_NAME="${FULL_REPO##*/}"
|
|
||||||
BASE_BRANCH="${{ github.base_ref }}"
|
|
||||||
|
|
||||||
# Clean the repository directory before starting
|
|
||||||
docker run --rm --net=host -v $(pwd):/workspace -w /workspace \
|
|
||||||
-e "REPO_NAME=${REPO_NAME}" \
|
|
||||||
${docker_image} /bin/bash -c '
|
|
||||||
if [ -d ${REPO_NAME} ]; then
|
|
||||||
echo "Directory ${REPO_NAME} exists, removing it..."
|
|
||||||
rm -rf ${REPO_NAME}*
|
|
||||||
fi
|
|
||||||
'
|
|
||||||
|
|
||||||
wget -q --no-proxy ${fd_archive_url}
|
|
||||||
tar -xf FastDeploy.tar.gz
|
|
||||||
rm -rf FastDeploy.tar.gz
|
|
||||||
cd FastDeploy
|
|
||||||
git config --global user.name "FastDeployCI"
|
|
||||||
git config --global user.email "fastdeploy_ci@example.com"
|
|
||||||
git log -n 3 --oneline
|
|
||||||
|
|
||||||
# Docker Build
|
|
||||||
cd tools/dockerfile/
|
|
||||||
set -e
|
|
||||||
cp ../../requirements.txt ./
|
|
||||||
cp ../../scripts/unittest_requirement.txt ./
|
|
||||||
docker build -t ${docker_image_name} -f Dockerfile.ci . \
|
|
||||||
--network host \
|
|
||||||
--no-cache
|
|
||||||
docker push ${docker_image_name}
|
|
||||||
echo "docker_name_precheck=${docker_image_name}" >> $GITHUB_OUTPUT
|
|
||||||
11
.github/workflows/_logprob_test_linux.yml
vendored
11
.github/workflows/_logprob_test_linux.yml
vendored
@@ -32,7 +32,6 @@ on:
|
|||||||
jobs:
|
jobs:
|
||||||
run_tests_logprob:
|
run_tests_logprob:
|
||||||
runs-on: [self-hosted, GPU-h20-1Cards]
|
runs-on: [self-hosted, GPU-h20-1Cards]
|
||||||
timeout-minutes: 60
|
|
||||||
steps:
|
steps:
|
||||||
- name: Code Prepare
|
- name: Code Prepare
|
||||||
shell: bash
|
shell: bash
|
||||||
@@ -40,7 +39,6 @@ jobs:
|
|||||||
docker_image: ${{ inputs.DOCKER_IMAGE }}
|
docker_image: ${{ inputs.DOCKER_IMAGE }}
|
||||||
paddletest_archive_url: ${{ inputs.PADDLETEST_ARCHIVE_URL }}
|
paddletest_archive_url: ${{ inputs.PADDLETEST_ARCHIVE_URL }}
|
||||||
run: |
|
run: |
|
||||||
docker pull ${docker_image}
|
|
||||||
# Clean the repository directory before starting
|
# Clean the repository directory before starting
|
||||||
docker run --rm --net=host -v $(pwd):/workspace -w /workspace \
|
docker run --rm --net=host -v $(pwd):/workspace -w /workspace \
|
||||||
-e "REPO_NAME=${REPO_NAME}" \
|
-e "REPO_NAME=${REPO_NAME}" \
|
||||||
@@ -48,7 +46,7 @@ jobs:
|
|||||||
${docker_image} /bin/bash -c '
|
${docker_image} /bin/bash -c '
|
||||||
rm -rf /workspace/*
|
rm -rf /workspace/*
|
||||||
'
|
'
|
||||||
wget -q --no-proxy ${paddletest_archive_url}
|
wget -q ${paddletest_archive_url}
|
||||||
tar -xf PaddleTest.tar.gz
|
tar -xf PaddleTest.tar.gz
|
||||||
rm -rf PaddleTest.tar.gz
|
rm -rf PaddleTest.tar.gz
|
||||||
cd PaddleTest
|
cd PaddleTest
|
||||||
@@ -118,6 +116,7 @@ jobs:
|
|||||||
echo "Removing stale container: ${runner_name}"
|
echo "Removing stale container: ${runner_name}"
|
||||||
docker rm -f ${runner_name} || true
|
docker rm -f ${runner_name} || true
|
||||||
fi
|
fi
|
||||||
|
|
||||||
docker run --rm --ipc=host --pid=host --net=host \
|
docker run --rm --ipc=host --pid=host --net=host \
|
||||||
--name ${runner_name} \
|
--name ${runner_name} \
|
||||||
-v $(pwd):/workspace \
|
-v $(pwd):/workspace \
|
||||||
@@ -134,7 +133,7 @@ jobs:
|
|||||||
-v "${CACHE_DIR}/ConfigDir:/root/.config" \
|
-v "${CACHE_DIR}/ConfigDir:/root/.config" \
|
||||||
-e TZ="Asia/Shanghai" \
|
-e TZ="Asia/Shanghai" \
|
||||||
--gpus '"device='"${DEVICES}"'"' ${docker_image} /bin/bash -xc '
|
--gpus '"device='"${DEVICES}"'"' ${docker_image} /bin/bash -xc '
|
||||||
python -m pip install --pre paddlepaddle-gpu -i https://www.paddlepaddle.org.cn/packages/nightly/cu126/
|
python -m pip install paddlepaddle-gpu==3.2.0 -i https://www.paddlepaddle.org.cn/packages/stable/cu126/
|
||||||
|
|
||||||
pip config set global.index-url https://mirrors.tuna.tsinghua.edu.cn/pypi/web/simple
|
pip config set global.index-url https://mirrors.tuna.tsinghua.edu.cn/pypi/web/simple
|
||||||
|
|
||||||
@@ -145,11 +144,9 @@ jobs:
|
|||||||
./llm-deploy-linux-amd64 -python python3.10 \
|
./llm-deploy-linux-amd64 -python python3.10 \
|
||||||
-model_name ERNIE-4.5-0.3B-Paddle \
|
-model_name ERNIE-4.5-0.3B-Paddle \
|
||||||
-model_path /MODELDATA \
|
-model_path /MODELDATA \
|
||||||
--skip install,model
|
--skip install
|
||||||
|
|
||||||
cd PaddleTest/framework/ServeTest
|
cd PaddleTest/framework/ServeTest
|
||||||
ps -ef | grep "${FD_CACHE_QUEUE_PORT}" | grep -v grep | awk "{print \$2}" | xargs -r kill -9
|
|
||||||
ps -ef | grep "${FD_ENGINE_QUEUE_PORT}" | grep -v grep | awk "{print \$2}" | xargs -r kill -9
|
|
||||||
python3.10 deploy.py > dd.log 2>&1 &
|
python3.10 deploy.py > dd.log 2>&1 &
|
||||||
sleep 3
|
sleep 3
|
||||||
curl -X POST http://0.0.0.0:${FLASK_PORT}/start \
|
curl -X POST http://0.0.0.0:${FLASK_PORT}/start \
|
||||||
|
|||||||
9
.github/workflows/_pre_ce_test.yml
vendored
9
.github/workflows/_pre_ce_test.yml
vendored
@@ -46,7 +46,7 @@ jobs:
|
|||||||
FULL_REPO="${{ github.repository }}"
|
FULL_REPO="${{ github.repository }}"
|
||||||
REPO_NAME="${FULL_REPO##*/}"
|
REPO_NAME="${FULL_REPO##*/}"
|
||||||
BASE_BRANCH="${{ github.base_ref }}"
|
BASE_BRANCH="${{ github.base_ref }}"
|
||||||
docker pull ${docker_image}
|
|
||||||
# Clean the repository directory before starting
|
# Clean the repository directory before starting
|
||||||
docker run --rm --net=host -v $(pwd):/workspace -w /workspace \
|
docker run --rm --net=host -v $(pwd):/workspace -w /workspace \
|
||||||
-e "REPO_NAME=${REPO_NAME}" \
|
-e "REPO_NAME=${REPO_NAME}" \
|
||||||
@@ -57,7 +57,7 @@ jobs:
|
|||||||
fi
|
fi
|
||||||
'
|
'
|
||||||
|
|
||||||
wget -q --no-proxy ${fd_archive_url}
|
wget -q ${fd_archive_url}
|
||||||
tar -xf FastDeploy.tar.gz
|
tar -xf FastDeploy.tar.gz
|
||||||
rm -rf FastDeploy.tar.gz
|
rm -rf FastDeploy.tar.gz
|
||||||
cd FastDeploy
|
cd FastDeploy
|
||||||
@@ -82,9 +82,6 @@ jobs:
|
|||||||
FD_ENGINE_QUEUE_PORT=$((42058 + DEVICE_PORT * 100))
|
FD_ENGINE_QUEUE_PORT=$((42058 + DEVICE_PORT * 100))
|
||||||
FD_METRICS_PORT=$((42078 + DEVICE_PORT * 100))
|
FD_METRICS_PORT=$((42078 + DEVICE_PORT * 100))
|
||||||
FD_CACHE_QUEUE_PORT=$((42098 + DEVICE_PORT * 100))
|
FD_CACHE_QUEUE_PORT=$((42098 + DEVICE_PORT * 100))
|
||||||
FD_ZMQ_RECV_REQUEST_SERVER_PORT=$((42048 + DEVICE_PORT * 100))
|
|
||||||
FD_ZMQ_SEND_RESPONSE_SERVER_PORT=$((42038 + DEVICE_PORT * 100))
|
|
||||||
FD_ZMQ_CONTROL_CMD_SERVER_PORTS=$((42028 + DEVICE_PORT * 100))
|
|
||||||
echo "Test ENV Parameter:"
|
echo "Test ENV Parameter:"
|
||||||
echo "========================================================="
|
echo "========================================================="
|
||||||
echo "FLASK_PORT=${FLASK_PORT}"
|
echo "FLASK_PORT=${FLASK_PORT}"
|
||||||
@@ -145,7 +142,7 @@ jobs:
|
|||||||
--gpus "\"device=${DEVICES}\"" ${docker_image} /bin/bash -c '
|
--gpus "\"device=${DEVICES}\"" ${docker_image} /bin/bash -c '
|
||||||
git config --global --add safe.directory /workspace/FastDeploy
|
git config --global --add safe.directory /workspace/FastDeploy
|
||||||
cd FastDeploy
|
cd FastDeploy
|
||||||
python -m pip install --pre paddlepaddle-gpu -i https://www.paddlepaddle.org.cn/packages/nightly/cu126/
|
python -m pip install paddlepaddle-gpu==3.2.0 -i https://www.paddlepaddle.org.cn/packages/stable/cu126/
|
||||||
python -m pip install ${fd_wheel_url}
|
python -m pip install ${fd_wheel_url}
|
||||||
bash scripts/run_pre_ce.sh
|
bash scripts/run_pre_ce.sh
|
||||||
'
|
'
|
||||||
|
|||||||
6
.github/workflows/_stable_test.yml
vendored
6
.github/workflows/_stable_test.yml
vendored
@@ -44,7 +44,7 @@ jobs:
|
|||||||
FULL_REPO="${{ github.repository }}"
|
FULL_REPO="${{ github.repository }}"
|
||||||
REPO_NAME="${FULL_REPO##*/}"
|
REPO_NAME="${FULL_REPO##*/}"
|
||||||
BASE_BRANCH="${{ github.base_ref }}"
|
BASE_BRANCH="${{ github.base_ref }}"
|
||||||
docker pull ${docker_image}
|
|
||||||
# Clean the repository directory before starting
|
# Clean the repository directory before starting
|
||||||
docker run --rm --net=host -v $(pwd):/workspace -w /workspace \
|
docker run --rm --net=host -v $(pwd):/workspace -w /workspace \
|
||||||
-e "REPO_NAME=${REPO_NAME}" \
|
-e "REPO_NAME=${REPO_NAME}" \
|
||||||
@@ -55,7 +55,7 @@ jobs:
|
|||||||
fi
|
fi
|
||||||
'
|
'
|
||||||
|
|
||||||
wget -q --no-proxy ${fd_archive_url}
|
wget -q ${fd_archive_url}
|
||||||
tar -xf FastDeploy.tar.gz
|
tar -xf FastDeploy.tar.gz
|
||||||
rm -rf FastDeploy.tar.gz
|
rm -rf FastDeploy.tar.gz
|
||||||
cd FastDeploy
|
cd FastDeploy
|
||||||
@@ -146,7 +146,7 @@ jobs:
|
|||||||
-v "${CACHE_DIR}/ConfigDir:/root/.config" \
|
-v "${CACHE_DIR}/ConfigDir:/root/.config" \
|
||||||
-e TZ="Asia/Shanghai" \
|
-e TZ="Asia/Shanghai" \
|
||||||
--gpus '"device='"${DEVICES}"'"' ${docker_image} /bin/bash -xc '
|
--gpus '"device='"${DEVICES}"'"' ${docker_image} /bin/bash -xc '
|
||||||
python -m pip install --pre paddlepaddle-gpu -i https://www.paddlepaddle.org.cn/packages/nightly/cu126/
|
python -m pip install paddlepaddle-gpu==3.2.0 -i https://www.paddlepaddle.org.cn/packages/stable/cu126/
|
||||||
|
|
||||||
pip config set global.index-url https://mirrors.tuna.tsinghua.edu.cn/pypi/web/simple
|
pip config set global.index-url https://mirrors.tuna.tsinghua.edu.cn/pypi/web/simple
|
||||||
|
|
||||||
|
|||||||
115
.github/workflows/_unit_test_coverage.yml
vendored
115
.github/workflows/_unit_test_coverage.yml
vendored
@@ -41,7 +41,7 @@ jobs:
|
|||||||
|
|
||||||
run_tests_with_coverage:
|
run_tests_with_coverage:
|
||||||
runs-on: [self-hosted, GPU-h1z1-2Cards]
|
runs-on: [self-hosted, GPU-h1z1-2Cards]
|
||||||
timeout-minutes: 90
|
timeout-minutes: 60
|
||||||
needs: check_cov_skip
|
needs: check_cov_skip
|
||||||
if: needs.check_cov_skip.outputs.can-skip != 'true'
|
if: needs.check_cov_skip.outputs.can-skip != 'true'
|
||||||
outputs:
|
outputs:
|
||||||
@@ -60,7 +60,7 @@ jobs:
|
|||||||
FULL_REPO="${{ github.repository }}"
|
FULL_REPO="${{ github.repository }}"
|
||||||
REPO_NAME="${FULL_REPO##*/}"
|
REPO_NAME="${FULL_REPO##*/}"
|
||||||
BASE_BRANCH="${{ github.base_ref }}"
|
BASE_BRANCH="${{ github.base_ref }}"
|
||||||
docker pull ${docker_image}
|
|
||||||
# Clean the repository directory before starting
|
# Clean the repository directory before starting
|
||||||
docker run --rm --net=host -v $(pwd):/workspace -w /workspace \
|
docker run --rm --net=host -v $(pwd):/workspace -w /workspace \
|
||||||
-e "REPO_NAME=${REPO_NAME}" \
|
-e "REPO_NAME=${REPO_NAME}" \
|
||||||
@@ -71,7 +71,7 @@ jobs:
|
|||||||
fi
|
fi
|
||||||
'
|
'
|
||||||
|
|
||||||
wget -q --no-proxy ${fd_archive_url}
|
wget -q ${fd_archive_url}
|
||||||
tar -xf FastDeploy.tar.gz
|
tar -xf FastDeploy.tar.gz
|
||||||
rm -rf FastDeploy.tar.gz
|
rm -rf FastDeploy.tar.gz
|
||||||
cd FastDeploy
|
cd FastDeploy
|
||||||
@@ -168,10 +168,13 @@ jobs:
|
|||||||
git config --global --add safe.directory /workspace/FastDeploy
|
git config --global --add safe.directory /workspace/FastDeploy
|
||||||
cd FastDeploy
|
cd FastDeploy
|
||||||
git diff origin/${BASE_REF}..HEAD --unified=0 > diff.txt
|
git diff origin/${BASE_REF}..HEAD --unified=0 > diff.txt
|
||||||
python -m pip install --pre paddlepaddle-gpu -i https://www.paddlepaddle.org.cn/packages/nightly/cu126/
|
python -m pip install paddlepaddle-gpu==3.2.0 -i https://www.paddlepaddle.org.cn/packages/stable/cu126/
|
||||||
pip config set global.extra-index-url https://mirrors.tuna.tsinghua.edu.cn/pypi/web/simple
|
pip config set global.extra-index-url https://mirrors.tuna.tsinghua.edu.cn/pypi/web/simple
|
||||||
|
|
||||||
python -m pip install -r scripts/unittest_requirement.txt
|
python -m pip install coverage
|
||||||
|
python -m pip install diff-cover
|
||||||
|
python -m pip install pytest-cov
|
||||||
|
python -m pip install jsonschema aistudio_sdk==0.3.5
|
||||||
python -m pip install ${fd_wheel_url}
|
python -m pip install ${fd_wheel_url}
|
||||||
rm -rf fastdeploy
|
rm -rf fastdeploy
|
||||||
# coverage subprocess use
|
# coverage subprocess use
|
||||||
@@ -194,104 +197,51 @@ jobs:
|
|||||||
coverage xml -o python_coverage_all.xml
|
coverage xml -o python_coverage_all.xml
|
||||||
COVERAGE_EXIT_CODE=0
|
COVERAGE_EXIT_CODE=0
|
||||||
if [[ "$IS_PR" == "true" ]]; then
|
if [[ "$IS_PR" == "true" ]]; then
|
||||||
echo "Running diff coverage for PR..."
|
|
||||||
diff-cover python_coverage_all.xml --diff-file=diff.txt --fail-under=80 --json-report diff_coverage.json || COVERAGE_EXIT_CODE=9
|
diff-cover python_coverage_all.xml --diff-file=diff.txt --fail-under=80 --json-report diff_coverage.json || COVERAGE_EXIT_CODE=9
|
||||||
python scripts/generate_diff_coverage_xml.py diff.txt python_coverage_all.xml
|
python scripts/generate_diff_coverage_xml.py diff.txt python_coverage_all.xml
|
||||||
else
|
else
|
||||||
echo "Running full coverage"
|
echo "Not a PR, skipping diff-cover"
|
||||||
coverage report -m > full_coverage_report.txt
|
|
||||||
python scripts/generate_full_coverage_csv.py full_coverage_report.txt full_coverage_report.csv
|
|
||||||
fi
|
fi
|
||||||
echo "COVERAGE_EXIT_CODE=${COVERAGE_EXIT_CODE}" >> exit_code.env
|
echo "COVERAGE_EXIT_CODE=${COVERAGE_EXIT_CODE}" >> exit_code.env
|
||||||
'
|
'
|
||||||
if [ -f FastDeploy/exit_code.env ]; then
|
if [ -f FastDeploy/exit_code.env ]; then
|
||||||
cat FastDeploy/exit_code.env >> $GITHUB_ENV
|
cat FastDeploy/exit_code.env >> $GITHUB_ENV
|
||||||
fi
|
fi
|
||||||
- name: Upload coverage and unit test results to BOS
|
- name: Upload unit resule and diff coverage to bos
|
||||||
id: cov_upload
|
id: cov_upload
|
||||||
shell: bash
|
shell: bash
|
||||||
env:
|
|
||||||
IS_PR: ${{ github.event_name == 'pull_request' }}
|
|
||||||
GITHUB_SHA: ${{ github.sha }}
|
|
||||||
BRANCH: ${{ github.ref_name }}
|
|
||||||
PR_COMMIT_SHA: ${{ github.event.pull_request.head.sha }}
|
|
||||||
PR_NUMBER: ${{ github.event.pull_request.number }}
|
|
||||||
run: |
|
run: |
|
||||||
cd FastDeploy
|
cd FastDeploy
|
||||||
python -m pip install -q bce-python-sdk==0.9.29
|
commit_id=${{ github.event.pull_request.head.sha }}
|
||||||
wget -q --no-proxy --no-check-certificate \
|
pr_num=${{ github.event.pull_request.number }}
|
||||||
https://paddle-qa.bj.bcebos.com/CodeSync/develop/PaddlePaddle/PaddleTest/tools/bos_tools.py \
|
target_path=paddle-github-action/PR/FastDeploy/${pr_num}/${commit_id}/SM${compile_arch//,/_}
|
||||||
-O bos_tools.py
|
wget -q --no-proxy --no-check-certificate https://paddle-qa.bj.bcebos.com/CodeSync/develop/PaddlePaddle/PaddleTest/tools/bos_tools.py -O bos_tools.py
|
||||||
push_file=$(realpath bos_tools.py)
|
push_file=$(realpath bos_tools.py)
|
||||||
|
python -m pip install bce-python-sdk==0.9.29
|
||||||
if [[ "$IS_PR" == "true" ]]; then
|
diff_cov_file="diff_coverage.xml"
|
||||||
commit_id=${PR_COMMIT_SHA}
|
if [ -f ${diff_cov_file} ];then
|
||||||
pr_num=${PR_NUMBER}
|
python ${push_file} ${diff_cov_file} ${target_path}/CoverageData
|
||||||
target_path=paddle-github-action/PR/FastDeploy/${pr_num}/${commit_id}/SM${compile_arch//,/_}
|
target_path_stripped="${target_path#paddle-github-action/}"
|
||||||
elif [[ "${{ github.ref_type }}" == "tag" ]]; then
|
DIFF_COV_FILE_URL=https://paddle-github-action.bj.bcebos.com/${target_path_stripped}/CoverageData/${diff_cov_file}
|
||||||
commit_id=${{ github.sha }}
|
echo "diff_cov_file_url=${DIFF_COV_FILE_URL}" >> $GITHUB_OUTPUT
|
||||||
tag_name=${{ github.ref_name }}
|
echo "diff_cov_file_url=${DIFF_COV_FILE_URL}" >> $GITHUB_ENV
|
||||||
target_path=paddle-github-action/TAG/FastDeploy/${tag_name}/${commit_id}/SM${compile_arch//,/_}
|
|
||||||
target_path_latest=paddle-github-action/TAG/FastDeploy/${tag_name}/latest/SM${compile_arch//,/_}
|
|
||||||
target_path_stripped_latest="${target_path_latest#paddle-github-action/}"
|
|
||||||
else
|
|
||||||
commit_id=${{ github.sha }}
|
|
||||||
branch_name=${{ github.ref_name }}
|
|
||||||
target_path=paddle-github-action/BRANCH/FastDeploy/${branch_name}/${commit_id}/SM${compile_arch//,/_}
|
|
||||||
target_path_latest=paddle-github-action/BRANCH/FastDeploy/${branch_name}/latest/SM${compile_arch//,/_}
|
|
||||||
target_path_stripped_latest="${target_path_latest#paddle-github-action/}"
|
|
||||||
fi
|
fi
|
||||||
|
diff_cov_result_json="diff_coverage.json"
|
||||||
target_path_stripped="${target_path#paddle-github-action/}"
|
if [ -f ${diff_cov_result_json} ];then
|
||||||
|
python ${push_file} ${diff_cov_result_json} ${target_path}/CoverageData
|
||||||
if [[ "$IS_PR" == "true" ]]; then
|
target_path_stripped="${target_path#paddle-github-action/}"
|
||||||
diff_cov_file="diff_coverage.xml"
|
DIFF_COV_JSON_URL=https://paddle-github-action.bj.bcebos.com/${target_path_stripped}/CoverageData/${diff_cov_result_json}
|
||||||
if [ -f ${diff_cov_file} ]; then
|
echo "diff_cov_result_json_url=${DIFF_COV_JSON_URL}" >> $GITHUB_OUTPUT
|
||||||
python ${push_file} ${diff_cov_file} ${target_path}/CoverageData
|
echo "diff_cov_result_json_url=${DIFF_COV_JSON_URL}" >> $GITHUB_ENV
|
||||||
DIFF_COV_FILE_URL=https://paddle-github-action.bj.bcebos.com/${target_path_stripped}/CoverageData/${diff_cov_file}
|
|
||||||
echo "diff_cov_file_url=${DIFF_COV_FILE_URL}" >> $GITHUB_OUTPUT
|
|
||||||
echo "diff_cov_file_url=${DIFF_COV_FILE_URL}" >> $GITHUB_ENV
|
|
||||||
fi
|
|
||||||
|
|
||||||
diff_cov_result_json="diff_coverage.json"
|
|
||||||
if [ -f ${diff_cov_result_json} ]; then
|
|
||||||
python ${push_file} ${diff_cov_result_json} ${target_path}/CoverageData
|
|
||||||
DIFF_COV_JSON_URL=https://paddle-github-action.bj.bcebos.com/${target_path_stripped}/CoverageData/${diff_cov_result_json}
|
|
||||||
echo "diff_cov_result_json_url=${DIFF_COV_JSON_URL}" >> $GITHUB_OUTPUT
|
|
||||||
echo "diff_cov_result_json_url=${DIFF_COV_JSON_URL}" >> $GITHUB_ENV
|
|
||||||
fi
|
|
||||||
fi
|
fi
|
||||||
|
|
||||||
HAS_FAILED_TESTS=false
|
|
||||||
unittest_result="failed_tests.log"
|
unittest_result="failed_tests.log"
|
||||||
if [ -s ${unittest_result} ]; then
|
if [ -s ${unittest_result} ];then
|
||||||
HAS_FAILED_TESTS=true
|
|
||||||
python ${push_file} ${unittest_result} ${target_path}/UnitTestResult
|
python ${push_file} ${unittest_result} ${target_path}/UnitTestResult
|
||||||
|
target_path_stripped="${target_path#paddle-github-action/}"
|
||||||
UNIT_TEST_RESULT_URL=https://paddle-github-action.bj.bcebos.com/${target_path_stripped}/UnitTestResult/${unittest_result}
|
UNIT_TEST_RESULT_URL=https://paddle-github-action.bj.bcebos.com/${target_path_stripped}/UnitTestResult/${unittest_result}
|
||||||
echo "unittest_failed_url=${UNIT_TEST_RESULT_URL}" >> $GITHUB_OUTPUT
|
echo "unittest_failed_url=${UNIT_TEST_RESULT_URL}" >> $GITHUB_OUTPUT
|
||||||
echo "unittest_failed_url=${UNIT_TEST_RESULT_URL}" >> $GITHUB_ENV
|
echo "unittest_failed_url=${UNIT_TEST_RESULT_URL}" >> $GITHUB_ENV
|
||||||
fi
|
fi
|
||||||
|
|
||||||
if [[ "$IS_PR" != "true" ]]; then
|
|
||||||
full_cov_file="full_coverage_report.txt"
|
|
||||||
full_cov_csv="full_coverage_report.csv"
|
|
||||||
|
|
||||||
if [ -f ${full_cov_file} ]; then
|
|
||||||
python ${push_file} ${full_cov_file} ${target_path}/CoverageData
|
|
||||||
python ${push_file} ${full_cov_file} ${target_path_latest}/CoverageData
|
|
||||||
FULL_COV_FILE_URL=https://paddle-github-action.bj.bcebos.com/${target_path_stripped}/CoverageData/${full_cov_file}
|
|
||||||
echo "full_coverage_report_url=${FULL_COV_FILE_URL}" >> $GITHUB_OUTPUT
|
|
||||||
echo "full_coverage_report_url=${FULL_COV_FILE_URL}" >> $GITHUB_ENV
|
|
||||||
fi
|
|
||||||
|
|
||||||
if [ "$HAS_FAILED_TESTS" = false ] && [ -f ${full_cov_csv} ]; then
|
|
||||||
python ${push_file} ${full_cov_csv} ${target_path}/CoverageData
|
|
||||||
python ${push_file} ${full_cov_csv} ${target_path_latest}/CoverageData
|
|
||||||
FULL_COV_CSV_URL=https://paddle-github-action.bj.bcebos.com/${target_path_stripped}/CoverageData/${full_cov_csv}
|
|
||||||
echo "full_coverage_csv_url=${FULL_COV_CSV_URL}" >> $GITHUB_OUTPUT
|
|
||||||
echo "full_coverage_csv_url=${FULL_COV_CSV_URL}" >> $GITHUB_ENV
|
|
||||||
fi
|
|
||||||
fi
|
|
||||||
- name: Check Unit Test Success
|
- name: Check Unit Test Success
|
||||||
shell: bash
|
shell: bash
|
||||||
run: |
|
run: |
|
||||||
@@ -344,7 +294,6 @@ jobs:
|
|||||||
needs: run_tests_with_coverage
|
needs: run_tests_with_coverage
|
||||||
if: always()
|
if: always()
|
||||||
runs-on: ubuntu-latest
|
runs-on: ubuntu-latest
|
||||||
timeout-minutes: 15
|
|
||||||
env:
|
env:
|
||||||
fd_archive_url: ${{ inputs.FASTDEPLOY_ARCHIVE_URL }}
|
fd_archive_url: ${{ inputs.FASTDEPLOY_ARCHIVE_URL }}
|
||||||
steps:
|
steps:
|
||||||
@@ -353,7 +302,7 @@ jobs:
|
|||||||
env:
|
env:
|
||||||
diff_cov_file_url: ${{ needs.run_tests_with_coverage.outputs.diff_cov_file_url }}
|
diff_cov_file_url: ${{ needs.run_tests_with_coverage.outputs.diff_cov_file_url }}
|
||||||
run: |
|
run: |
|
||||||
wget -q --no-proxy ${fd_archive_url}
|
wget ${fd_archive_url}
|
||||||
tar -xf FastDeploy.tar.gz
|
tar -xf FastDeploy.tar.gz
|
||||||
cd FastDeploy
|
cd FastDeploy
|
||||||
if [ -z "${diff_cov_file_url}" ]; then
|
if [ -z "${diff_cov_file_url}" ]; then
|
||||||
|
|||||||
20
.github/workflows/ce_job.yml
vendored
20
.github/workflows/ce_job.yml
vendored
@@ -9,7 +9,7 @@ on:
|
|||||||
permissions: read-all
|
permissions: read-all
|
||||||
|
|
||||||
concurrency:
|
concurrency:
|
||||||
group: CE-Job-${{ github.ref }}-${{ github.sha }}
|
group: ${{ github.ref }}-${{ github.sha }}
|
||||||
cancel-in-progress: true
|
cancel-in-progress: true
|
||||||
|
|
||||||
jobs:
|
jobs:
|
||||||
@@ -199,13 +199,13 @@ jobs:
|
|||||||
ls
|
ls
|
||||||
python ${push_file} ${filename} ${target_path}
|
python ${push_file} ${filename} ${target_path}
|
||||||
target_path_stripped="${target_path#paddle-qa/}"
|
target_path_stripped="${target_path#paddle-qa/}"
|
||||||
WHEEL_PATH=https://paddle-qa.bj.bcebos.com/${target_path_stripped}/${filename}
|
WHEEL_PATH=https://paddle-qa.bj.bcebos.com/${target_path_stripped}/${fd_wheel_name}
|
||||||
|
echo "commit wheel url is ${WHEEL_PATH}"
|
||||||
|
|
||||||
target_path_latest=paddle-qa/paddle-pipeline/FastDeploy_ActionCE/SM${COMPILE_ARCH//,/_}/${branch_name}/latest
|
target_path_latest=paddle-qa/paddle-pipeline/FastDeploy_ActionCE/SM${COMPILE_ARCH//,/_}/${branch_name}/latest
|
||||||
python ${push_file} ${filename} ${target_path_latest}
|
python ${push_file} ${filename} ${target_path_latest}
|
||||||
target_path_stripped_latest="${target_path_latest#paddle-qa/}"
|
target_path_stripped_latest="${target_path_latest#paddle-qa/}"
|
||||||
WHEEL_PATH_LATEST=https://paddle-qa.bj.bcebos.com/${target_path_stripped_latest}/${filename}
|
WHEEL_PATH_LATEST=https://paddle-qa.bj.bcebos.com/${target_path_stripped_latest}/${fd_wheel_name}
|
||||||
echo "commit wheel url is ${WHEEL_PATH}"
|
|
||||||
echo "latest wheel url is ${WHEEL_PATH_LATEST}"
|
echo "latest wheel url is ${WHEEL_PATH_LATEST}"
|
||||||
|
|
||||||
ce_upload_sm8689:
|
ce_upload_sm8689:
|
||||||
@@ -224,9 +224,9 @@ jobs:
|
|||||||
python-version: '3.10'
|
python-version: '3.10'
|
||||||
- name: Wheel Info Show and Upload
|
- name: Wheel Info Show and Upload
|
||||||
run: |
|
run: |
|
||||||
echo "The wheel is located at: ${{ needs.build_sm8689.outputs.wheel_path }}"
|
echo "The wheel is located at: ${{ needs.build_sm8090.outputs.wheel_path }}"
|
||||||
wget -q --no-check-certificate ${{ needs.build_sm8689.outputs.wheel_path }}
|
wget -q --no-check-certificate ${{ needs.build_sm8090.outputs.wheel_path }}
|
||||||
filename=$(basename ${{ needs.build_sm8689.outputs.wheel_path }})
|
filename=$(basename ${{ needs.build_sm8090.outputs.wheel_path }})
|
||||||
|
|
||||||
commit_id=${{ github.sha }}
|
commit_id=${{ github.sha }}
|
||||||
branch_name=${{ github.ref_name }}
|
branch_name=${{ github.ref_name }}
|
||||||
@@ -238,11 +238,11 @@ jobs:
|
|||||||
ls
|
ls
|
||||||
python ${push_file} ${filename} ${target_path}
|
python ${push_file} ${filename} ${target_path}
|
||||||
target_path_stripped="${target_path#paddle-qa/}"
|
target_path_stripped="${target_path#paddle-qa/}"
|
||||||
WHEEL_PATH=https://paddle-qa.bj.bcebos.com/${target_path_stripped}/${filename}
|
WHEEL_PATH=https://paddle-qa.bj.bcebos.com/${target_path_stripped}/${fd_wheel_name}
|
||||||
|
echo "commit wheel url is ${WHEEL_PATH}"
|
||||||
|
|
||||||
target_path_latest=paddle-qa/paddle-pipeline/FastDeploy_ActionCE/SM${COMPILE_ARCH//,/_}/${branch_name}/latest
|
target_path_latest=paddle-qa/paddle-pipeline/FastDeploy_ActionCE/SM${COMPILE_ARCH//,/_}/${branch_name}/latest
|
||||||
python ${push_file} ${filename} ${target_path_latest}
|
python ${push_file} ${filename} ${target_path_latest}
|
||||||
target_path_stripped_latest="${target_path_latest#paddle-qa/}"
|
target_path_stripped_latest="${target_path_latest#paddle-qa/}"
|
||||||
WHEEL_PATH_LATEST=https://paddle-qa.bj.bcebos.com/${target_path_stripped_latest}/${filename}
|
WHEEL_PATH_LATEST=https://paddle-qa.bj.bcebos.com/${target_path_stripped_latest}/${fd_wheel_name}
|
||||||
echo "commit wheel url is ${WHEEL_PATH}"
|
|
||||||
echo "latest wheel url is ${WHEEL_PATH_LATEST}"
|
echo "latest wheel url is ${WHEEL_PATH_LATEST}"
|
||||||
|
|||||||
@@ -1,10 +1,10 @@
|
|||||||
name: CI_GCU
|
name: CI_GCU
|
||||||
|
|
||||||
on:
|
on:
|
||||||
#pull_request:
|
pull_request:
|
||||||
#branches:
|
branches:
|
||||||
#- develop
|
- develop
|
||||||
#- 'release/*'
|
- 'release/*'
|
||||||
workflow_dispatch:
|
workflow_dispatch:
|
||||||
|
|
||||||
concurrency:
|
concurrency:
|
||||||
6
.github/workflows/ci_iluvatar.yml
vendored
6
.github/workflows/ci_iluvatar.yml
vendored
@@ -28,22 +28,18 @@ jobs:
|
|||||||
REPO="https://github.com/${{ github.repository }}.git"
|
REPO="https://github.com/${{ github.repository }}.git"
|
||||||
FULL_REPO="${{ github.repository }}"
|
FULL_REPO="${{ github.repository }}"
|
||||||
REPO_NAME="${FULL_REPO##*/}"
|
REPO_NAME="${FULL_REPO##*/}"
|
||||||
BASE_BRANCH="${{ github.base_ref }}"
|
|
||||||
# Clean the repository directory before starting
|
# Clean the repository directory before starting
|
||||||
docker run --rm --net=host -v $(pwd):/workspace -w /workspace \
|
docker run --rm --net=host -v $(pwd):/workspace -w /workspace \
|
||||||
-e "REPO_NAME=${REPO_NAME}" \
|
-e "REPO_NAME=${REPO_NAME}" \
|
||||||
-e "BASE_BRANCH=${BASE_BRANCH}" \
|
|
||||||
${docker_image} /bin/bash -c '
|
${docker_image} /bin/bash -c '
|
||||||
if [ -d ${REPO_NAME} ]; then
|
if [ -d ${REPO_NAME} ]; then
|
||||||
echo "Directory ${REPO_NAME} exists, removing it..."
|
echo "Directory ${REPO_NAME} exists, removing it..."
|
||||||
rm -rf ${REPO_NAME}
|
rm -rf ${REPO_NAME}
|
||||||
fi
|
fi
|
||||||
'
|
'
|
||||||
git config --global http.proxy "http://61.151.249.150:33128"
|
|
||||||
git config --global https.proxy "http://61.151.249.150:33128"
|
|
||||||
git config --global user.name "FastDeployCI"
|
git config --global user.name "FastDeployCI"
|
||||||
git config --global user.email "fastdeploy_ci@example.com"
|
git config --global user.email "fastdeploy_ci@example.com"
|
||||||
git clone --recursive ${REPO} ${REPO_NAME} -b ${BASE_BRANCH}
|
git clone ${REPO} ${REPO_NAME}
|
||||||
cd FastDeploy
|
cd FastDeploy
|
||||||
if [ "${{ github.event_name }}" = "pull_request" ]; then
|
if [ "${{ github.event_name }}" = "pull_request" ]; then
|
||||||
git fetch origin pull/${{ github.event.pull_request.number }}/head:pr/${{ github.event.pull_request.number }}
|
git fetch origin pull/${{ github.event.pull_request.number }}/head:pr/${{ github.event.pull_request.number }}
|
||||||
|
|||||||
174
.github/workflows/ci_image_update.yml
vendored
174
.github/workflows/ci_image_update.yml
vendored
@@ -1,174 +0,0 @@
|
|||||||
name: CI Images Build
|
|
||||||
|
|
||||||
on:
|
|
||||||
workflow_dispatch:
|
|
||||||
schedule:
|
|
||||||
- cron: '0 18 * * *' # 2:00 AM China Standard Time (UTC+8)
|
|
||||||
|
|
||||||
permissions: read-all
|
|
||||||
|
|
||||||
concurrency:
|
|
||||||
group: CI-Images-Build-${{ github.ref }}-${{ github.sha }}
|
|
||||||
cancel-in-progress: true
|
|
||||||
|
|
||||||
|
|
||||||
jobs:
|
|
||||||
clone:
|
|
||||||
environment: CodeSync
|
|
||||||
name: FD-Clone-Linux
|
|
||||||
runs-on: ubuntu-latest
|
|
||||||
outputs:
|
|
||||||
repo_archive_url: ${{ steps.set_output.outputs.repo_archive_url }}
|
|
||||||
steps:
|
|
||||||
- name: Clone FastDeploy
|
|
||||||
uses: actions/checkout@v4
|
|
||||||
with:
|
|
||||||
ref: ${{ github.ref_name }}
|
|
||||||
submodules: 'recursive'
|
|
||||||
fetch-depth: 1000
|
|
||||||
|
|
||||||
- name: Python Setup
|
|
||||||
uses: actions/setup-python@v5
|
|
||||||
with:
|
|
||||||
python-version: '3.10'
|
|
||||||
- name: Code Info Show and Upload
|
|
||||||
id: set_output
|
|
||||||
env:
|
|
||||||
AK: ${{ secrets.BOS_AK }}
|
|
||||||
SK: ${{ secrets.BOS_SK }}
|
|
||||||
run: |
|
|
||||||
git config --unset http.https://github.com/.extraheader
|
|
||||||
git submodule foreach --recursive sh -c "git config --local --unset-all 'http.https://github.com/.extraheader'"
|
|
||||||
git submodule foreach --recursive sh -c "git config remote.origin.fetch '+refs/heads/*:refs/remotes/origin/*'"
|
|
||||||
echo "Current HEAD Log:"
|
|
||||||
git log --oneline -n 5
|
|
||||||
ls
|
|
||||||
cd ..
|
|
||||||
tar -zcf FastDeploy.tar.gz FastDeploy
|
|
||||||
if [[ "${{ github.ref_type }}" == "tag" ]]; then
|
|
||||||
commit_id=${{ github.sha }}
|
|
||||||
tag_name=${{ github.ref_name }}
|
|
||||||
target_path=paddle-qa/TAG/FastDeploy/${tag_name}/${commit_id}
|
|
||||||
else
|
|
||||||
commit_id=${{ github.sha }}
|
|
||||||
branch_name=${{ github.ref_name }}
|
|
||||||
target_path=paddle-qa/BRANCH/FastDeploy/${branch_name}/${commit_id}
|
|
||||||
fi
|
|
||||||
wget -q --no-proxy --no-check-certificate https://paddle-qa.bj.bcebos.com/CodeSync/develop/PaddlePaddle/PaddleTest/tools/bos_tools.py
|
|
||||||
push_file=$(realpath bos_tools.py)
|
|
||||||
python -m pip install bce-python-sdk==0.9.29
|
|
||||||
ls
|
|
||||||
python ${push_file} FastDeploy.tar.gz ${target_path}
|
|
||||||
target_path_stripped="${target_path#paddle-qa/}"
|
|
||||||
REPO_ARCHIVE_URL=https://paddle-qa.bj.bcebos.com/${target_path_stripped}/FastDeploy.tar.gz
|
|
||||||
echo "repo_archive_url=${REPO_ARCHIVE_URL}" >> $GITHUB_OUTPUT
|
|
||||||
|
|
||||||
resultshow:
|
|
||||||
name: Show Code Archive Output
|
|
||||||
needs: clone
|
|
||||||
runs-on: ubuntu-latest
|
|
||||||
steps:
|
|
||||||
- name: Print wheel path
|
|
||||||
run: |
|
|
||||||
echo "The code archive is located at: ${{ needs.clone.outputs.repo_archive_url }}"
|
|
||||||
|
|
||||||
ci_image_build:
|
|
||||||
name: CI Images Build
|
|
||||||
needs: clone
|
|
||||||
uses: ./.github/workflows/_ci_image_build.yml
|
|
||||||
with:
|
|
||||||
CI_DOCKER_IMAGE_NAME: ccr-2vdh3abv-pub.cnc.bj.baidubce.com/paddlepaddle/paddleqa:fastdeploy-ciuse-cuda126-dailyupdate-precheck
|
|
||||||
FASTDEPLOY_ARCHIVE_URL: ${{ needs.clone.outputs.repo_archive_url }}
|
|
||||||
|
|
||||||
|
|
||||||
build_sm8090:
|
|
||||||
name: BUILD_SM8090
|
|
||||||
needs: [clone, ci_image_build]
|
|
||||||
uses: ./.github/workflows/_build_linux.yml
|
|
||||||
with:
|
|
||||||
DOCKER_IMAGE: ${{ needs.ci_image_build.outputs.docker_name_precheck }}
|
|
||||||
FASTDEPLOY_ARCHIVE_URL: ${{ needs.clone.outputs.repo_archive_url }}
|
|
||||||
COMPILE_ARCH: "90"
|
|
||||||
WITH_NIGHTLY_BUILD: ${{ needs.publish_pre_check.outputs.with_nightly_build }}
|
|
||||||
FD_VERSION: ${{ needs.publish_pre_check.outputs.fd_version }}
|
|
||||||
PADDLEVERSION: ${{ needs.publish_pre_check.outputs.compile_use_paddle_version }}
|
|
||||||
PADDLE_WHL_URL: ${{ needs.publish_pre_check.outputs.compile_use_paddle_whl_url }}
|
|
||||||
|
|
||||||
|
|
||||||
unittest_coverage:
|
|
||||||
name: Run FastDeploy Unit Tests and Coverage
|
|
||||||
needs: [clone,build_sm8090,ci_image_build]
|
|
||||||
uses: ./.github/workflows/_unit_test_coverage.yml
|
|
||||||
with:
|
|
||||||
DOCKER_IMAGE: ${{ needs.ci_image_build.outputs.docker_name_precheck }}
|
|
||||||
FASTDEPLOY_ARCHIVE_URL: ${{ needs.clone.outputs.repo_archive_url }}
|
|
||||||
FASTDEPLOY_WHEEL_URL: ${{ needs.build_sm8090.outputs.wheel_path }}
|
|
||||||
MODEL_CACHE_DIR: "/ssd2/actions-runner/ModelData"
|
|
||||||
secrets:
|
|
||||||
github-token: ${{ secrets.GITHUB_TOKEN }}
|
|
||||||
|
|
||||||
logprob_test:
|
|
||||||
name: Run FastDeploy LogProb Tests
|
|
||||||
needs: [build_sm8090,ci_image_build]
|
|
||||||
uses: ./.github/workflows/_logprob_test_linux.yml
|
|
||||||
with:
|
|
||||||
DOCKER_IMAGE: ${{ needs.ci_image_build.outputs.docker_name_precheck }}
|
|
||||||
PADDLETEST_ARCHIVE_URL: "https://xly-devops.bj.bcebos.com/PaddleTest/PaddleTest.tar.gz"
|
|
||||||
FASTDEPLOY_WHEEL_URL: ${{ needs.build_sm8090.outputs.wheel_path }}
|
|
||||||
MODEL_CACHE_DIR: "/ssd2/actions-runner/ModelData"
|
|
||||||
|
|
||||||
pre_ce_test:
|
|
||||||
name: Extracted partial CE model tasks to run in CI.
|
|
||||||
needs: [clone,build_sm8090,ci_image_build]
|
|
||||||
uses: ./.github/workflows/_pre_ce_test.yml
|
|
||||||
with:
|
|
||||||
DOCKER_IMAGE: ${{ needs.ci_image_build.outputs.docker_name_precheck }}
|
|
||||||
FASTDEPLOY_ARCHIVE_URL: ${{ needs.clone.outputs.repo_archive_url }}
|
|
||||||
FASTDEPLOY_WHEEL_URL: ${{ needs.build_sm8090.outputs.wheel_path }}
|
|
||||||
MODEL_CACHE_DIR: "/ssd2/actions-runner/ModelData"
|
|
||||||
|
|
||||||
base_test:
|
|
||||||
name: Run Base Tests
|
|
||||||
needs: [clone,build_sm8090,ci_image_build]
|
|
||||||
uses: ./.github/workflows/_base_test.yml
|
|
||||||
with:
|
|
||||||
DOCKER_IMAGE: ${{ needs.ci_image_build.outputs.docker_name_precheck }}
|
|
||||||
FASTDEPLOY_ARCHIVE_URL: ${{ needs.clone.outputs.repo_archive_url }}
|
|
||||||
FASTDEPLOY_WHEEL_URL: ${{ needs.build_sm8090.outputs.wheel_path }}
|
|
||||||
MODEL_CACHE_DIR: "/ssd2/actions-runner/ModelData"
|
|
||||||
|
|
||||||
accuracy_test:
|
|
||||||
name: Run Accuracy Tests
|
|
||||||
needs: [clone,build_sm8090,ci_image_build]
|
|
||||||
uses: ./.github/workflows/_accuracy_test.yml
|
|
||||||
with:
|
|
||||||
DOCKER_IMAGE: ${{ needs.ci_image_build.outputs.docker_name_precheck }}
|
|
||||||
FASTDEPLOY_ARCHIVE_URL: ${{ needs.clone.outputs.repo_archive_url }}
|
|
||||||
FASTDEPLOY_WHEEL_URL: ${{ needs.build_sm8090.outputs.wheel_path }}
|
|
||||||
MODEL_CACHE_DIR: "/ssd2/actions-runner/ModelData"
|
|
||||||
|
|
||||||
stable_test:
|
|
||||||
name: Run Stable Tests
|
|
||||||
needs: [clone,build_sm8090,ci_image_build]
|
|
||||||
uses: ./.github/workflows/_stable_test.yml
|
|
||||||
with:
|
|
||||||
DOCKER_IMAGE: ${{ needs.ci_image_build.outputs.docker_name_precheck }}
|
|
||||||
FASTDEPLOY_ARCHIVE_URL: ${{ needs.clone.outputs.repo_archive_url }}
|
|
||||||
FASTDEPLOY_WHEEL_URL: ${{ needs.build_sm8090.outputs.wheel_path }}
|
|
||||||
MODEL_CACHE_DIR: "/ssd2/actions-runner/ModelData"
|
|
||||||
|
|
||||||
|
|
||||||
publish_pre_check:
|
|
||||||
name: Publish Docker Images Pre Check
|
|
||||||
needs: [ci_image_build, unittest_coverage,logprob_test,pre_ce_test,base_test,accuracy_test,stable_test]
|
|
||||||
runs-on: [self-hosted, Docker-Build]
|
|
||||||
steps:
|
|
||||||
- name: Images Uploading
|
|
||||||
env:
|
|
||||||
images_name: ${{ needs.ci_image_build.outputs.docker_name_precheck }}
|
|
||||||
ci_image_name: "ccr-2vdh3abv-pub.cnc.bj.baidubce.com/paddlepaddle/paddleqa:fastdeploy-ciuse-cuda126-dailyupdate"
|
|
||||||
run: |
|
|
||||||
echo "images_name=${images_name}"
|
|
||||||
docker images ${ci_image_name}
|
|
||||||
docker tag ${images_name} ${ci_image_name}
|
|
||||||
docker push ${ci_image_name}
|
|
||||||
4
.github/workflows/ci_xpu.yml
vendored
4
.github/workflows/ci_xpu.yml
vendored
@@ -24,7 +24,7 @@ jobs:
|
|||||||
|
|
||||||
- name: Code Checkout
|
- name: Code Checkout
|
||||||
env:
|
env:
|
||||||
docker_image: ccr-2vdh3abv-pub.cnc.bj.baidubce.com/paddlepaddle/fastdeploy-xpu:2.2.0
|
docker_image: ccr-2vdh3abv-pub.cnc.bj.baidubce.com/paddlepaddle/fastdeploy-xpu:2.1.0
|
||||||
run: |
|
run: |
|
||||||
REPO="https://github.com/${{ github.repository }}.git"
|
REPO="https://github.com/${{ github.repository }}.git"
|
||||||
FULL_REPO="${{ github.repository }}"
|
FULL_REPO="${{ github.repository }}"
|
||||||
@@ -55,7 +55,7 @@ jobs:
|
|||||||
|
|
||||||
- name: Run CI unittest
|
- name: Run CI unittest
|
||||||
env:
|
env:
|
||||||
docker_image: ccr-2vdh3abv-pub.cnc.bj.baidubce.com/paddlepaddle/fastdeploy-xpu:2.2.0
|
docker_image: ccr-2vdh3abv-pub.cnc.bj.baidubce.com/paddlepaddle/fastdeploy-xpu:2.1.0
|
||||||
run: |
|
run: |
|
||||||
runner_name="${{ runner.name }}"
|
runner_name="${{ runner.name }}"
|
||||||
last_char="${runner_name: -1}"
|
last_char="${runner_name: -1}"
|
||||||
|
|||||||
2
.github/workflows/pr_build_and_test.yml
vendored
2
.github/workflows/pr_build_and_test.yml
vendored
@@ -21,7 +21,7 @@ jobs:
|
|||||||
with:
|
with:
|
||||||
DOCKER_IMAGE: ccr-2vdh3abv-pub.cnc.bj.baidubce.com/paddlepaddle/paddleqa:fastdeploy-ciuse-cuda126-dailyupdate
|
DOCKER_IMAGE: ccr-2vdh3abv-pub.cnc.bj.baidubce.com/paddlepaddle/paddleqa:fastdeploy-ciuse-cuda126-dailyupdate
|
||||||
FASTDEPLOY_ARCHIVE_URL: ${{ needs.clone.outputs.repo_archive_url }}
|
FASTDEPLOY_ARCHIVE_URL: ${{ needs.clone.outputs.repo_archive_url }}
|
||||||
COMPILE_ARCH: "90"
|
COMPILE_ARCH: "89,90"
|
||||||
WITH_NIGHTLY_BUILD: "OFF"
|
WITH_NIGHTLY_BUILD: "OFF"
|
||||||
FD_VERSION: "0.0.0"
|
FD_VERSION: "0.0.0"
|
||||||
|
|
||||||
|
|||||||
62
.github/workflows/publish_job.yml
vendored
62
.github/workflows/publish_job.yml
vendored
@@ -13,7 +13,7 @@ on:
|
|||||||
permissions: read-all
|
permissions: read-all
|
||||||
|
|
||||||
concurrency:
|
concurrency:
|
||||||
group: Publish-Job-${{ github.ref }}-${{ github.sha }}
|
group: ${{ github.ref }}-${{ github.sha }}
|
||||||
cancel-in-progress: true
|
cancel-in-progress: true
|
||||||
|
|
||||||
|
|
||||||
@@ -268,56 +268,6 @@ jobs:
|
|||||||
ls
|
ls
|
||||||
python ${push_file} ${filename} ${target_path}
|
python ${push_file} ${filename} ${target_path}
|
||||||
|
|
||||||
images_build:
|
|
||||||
name: Run FD Image Build
|
|
||||||
needs: [clone, publish_pre_check, build_sm8090]
|
|
||||||
runs-on: [self-hosted, Docker-Build]
|
|
||||||
if: |
|
|
||||||
github.event.repository.fork == false &&
|
|
||||||
(
|
|
||||||
(github.event_name == 'push' && github.ref_type == 'tag') ||
|
|
||||||
(github.event_name == 'workflow_dispatch' && github.ref_type == 'tag')
|
|
||||||
)
|
|
||||||
env:
|
|
||||||
FD_VERSION: ${{ needs.publish_pre_check.outputs.fd_version }}
|
|
||||||
PADDLEVERSION: ${{ needs.publish_pre_check.outputs.compile_use_paddle_version }}
|
|
||||||
FASTDEPLOY_ARCHIVE_URL: ${{ needs.clone.outputs.repo_archive_url }}
|
|
||||||
steps:
|
|
||||||
- name: Images Build
|
|
||||||
shell: bash
|
|
||||||
env:
|
|
||||||
docker_image: ccr-2vdh3abv-pub.cnc.bj.baidubce.com/paddlepaddle/paddleqa:fastdeploy-ciuse-cuda126-dailyupdate
|
|
||||||
fd_archive_url: ${FASTDEPLOY_ARCHIVE_URL}
|
|
||||||
run: |
|
|
||||||
set -x
|
|
||||||
FULL_REPO="${{ github.repository }}"
|
|
||||||
REPO_NAME="${FULL_REPO##*/}"
|
|
||||||
|
|
||||||
# Clean the repository directory before starting
|
|
||||||
docker run --rm --net=host -v $(pwd):/workspace -w /workspace \
|
|
||||||
-e "REPO_NAME=${REPO_NAME}" \
|
|
||||||
${docker_image} /bin/bash -c '
|
|
||||||
if [ -d ${REPO_NAME} ]; then
|
|
||||||
echo "Directory ${REPO_NAME} exists, removing it..."
|
|
||||||
rm -rf ${REPO_NAME}*
|
|
||||||
fi
|
|
||||||
'
|
|
||||||
wget -q --no-proxy ${fd_archive_url}
|
|
||||||
tar -xf FastDeploy.tar.gz
|
|
||||||
rm -rf FastDeploy.tar.gz
|
|
||||||
cd FastDeploy
|
|
||||||
git config --global user.name "FastDeployCI"
|
|
||||||
git config --global user.email "fastdeploy_ci@example.com"
|
|
||||||
git log -n 3 --oneline
|
|
||||||
|
|
||||||
PRODUCT_NAME=ccr-2vdh3abv-pub.cnc.bj.baidubce.com/paddlepaddle/fastdeploy-cuda-12.6:${FD_VERSION}
|
|
||||||
docker build --no-cache -t ${PRODUCT_NAME} -f Dockerfile.gpu . \
|
|
||||||
--network host \
|
|
||||||
--build-arg PADDLE_VERSION=${PADDLEVERSION} \
|
|
||||||
--build-arg FD_VERSION=${FD_VERSION}
|
|
||||||
|
|
||||||
docker push ${PRODUCT_NAME}
|
|
||||||
|
|
||||||
unittest_coverage:
|
unittest_coverage:
|
||||||
name: Run FastDeploy Unit Tests and Coverage
|
name: Run FastDeploy Unit Tests and Coverage
|
||||||
needs: [clone,build_sm8090]
|
needs: [clone,build_sm8090]
|
||||||
@@ -369,13 +319,3 @@ jobs:
|
|||||||
FASTDEPLOY_ARCHIVE_URL: ${{ needs.clone.outputs.repo_archive_url }}
|
FASTDEPLOY_ARCHIVE_URL: ${{ needs.clone.outputs.repo_archive_url }}
|
||||||
FASTDEPLOY_WHEEL_URL: ${{ needs.build_sm8090.outputs.wheel_path }}
|
FASTDEPLOY_WHEEL_URL: ${{ needs.build_sm8090.outputs.wheel_path }}
|
||||||
MODEL_CACHE_DIR: "/ssd2/actions-runner/ModelData"
|
MODEL_CACHE_DIR: "/ssd2/actions-runner/ModelData"
|
||||||
|
|
||||||
stable_test:
|
|
||||||
name: Run Stable Tests
|
|
||||||
needs: [clone,build_sm8090]
|
|
||||||
uses: ./.github/workflows/_stable_test.yml
|
|
||||||
with:
|
|
||||||
DOCKER_IMAGE: ccr-2vdh3abv-pub.cnc.bj.baidubce.com/paddlepaddle/paddleqa:fastdeploy-ciuse-cuda126-dailyupdate
|
|
||||||
FASTDEPLOY_ARCHIVE_URL: ${{ needs.clone.outputs.repo_archive_url }}
|
|
||||||
FASTDEPLOY_WHEEL_URL: ${{ needs.build.outputs.wheel_path }}
|
|
||||||
MODEL_CACHE_DIR: "/ssd2/actions-runner/ModelData"
|
|
||||||
|
|||||||
157
.github/workflows/rerun.yml
vendored
157
.github/workflows/rerun.yml
vendored
@@ -1,157 +0,0 @@
|
|||||||
name: Re-run
|
|
||||||
|
|
||||||
on:
|
|
||||||
issue_comment:
|
|
||||||
types: [created]
|
|
||||||
|
|
||||||
jobs:
|
|
||||||
re-run:
|
|
||||||
if: ${{ github.event.issue.pull_request && contains(github.event.comment.body, '/re-run') && github.event.comment.user.login == github.event.issue.user.login }}
|
|
||||||
runs-on: ubuntu-latest
|
|
||||||
steps:
|
|
||||||
- name: Cleanup
|
|
||||||
run: |
|
|
||||||
rm -rf * .[^.]*
|
|
||||||
|
|
||||||
- name: Checkout code
|
|
||||||
uses: actions/checkout@v5
|
|
||||||
|
|
||||||
- name: Rerun all failed jobs
|
|
||||||
if: ${{ contains(github.event.comment.body, 'all-failed') }}
|
|
||||||
uses: ./.github/actions/rerun-workflow
|
|
||||||
with:
|
|
||||||
PR_ID: ${{ github.event.issue.number }}
|
|
||||||
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
|
|
||||||
OWNER: ${{ github.repository_owner }}
|
|
||||||
REPO: ${{ github.event.repository.name }}
|
|
||||||
JOB_NAME: 'all-failed'
|
|
||||||
|
|
||||||
- name: Rerun Approval
|
|
||||||
if: ${{ contains(github.event.comment.body, 'approval') }}
|
|
||||||
uses: ./.github/actions/rerun-workflow
|
|
||||||
with:
|
|
||||||
PR_ID: ${{ github.event.issue.number }}
|
|
||||||
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
|
|
||||||
OWNER: ${{ github.repository_owner }}
|
|
||||||
REPO: ${{ github.event.repository.name }}
|
|
||||||
JOB_NAME: 'Approval'
|
|
||||||
|
|
||||||
- name: Rerun CI_ILUVATAR
|
|
||||||
if: ${{ contains(github.event.comment.body, 'ci_iluvatar') }}
|
|
||||||
uses: ./.github/actions/rerun-workflow
|
|
||||||
with:
|
|
||||||
PR_ID: ${{ github.event.issue.number }}
|
|
||||||
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
|
|
||||||
OWNER: ${{ github.repository_owner }}
|
|
||||||
REPO: ${{ github.event.repository.name }}
|
|
||||||
JOB_NAME: 'CI_ILUVATAR'
|
|
||||||
|
|
||||||
- name: Rerun CI_XPU
|
|
||||||
if: ${{ contains(github.event.comment.body, 'ci_xpu') }}
|
|
||||||
uses: ./.github/actions/rerun-workflow
|
|
||||||
with:
|
|
||||||
PR_ID: ${{ github.event.issue.number }}
|
|
||||||
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
|
|
||||||
OWNER: ${{ github.repository_owner }}
|
|
||||||
REPO: ${{ github.event.repository.name }}
|
|
||||||
JOB_NAME: 'CI_XPU'
|
|
||||||
|
|
||||||
- name: Rerun Codestyle-check
|
|
||||||
if: ${{ contains(github.event.comment.body, 'codestyle') || contains(github.event.comment.body, 'pre_commit') }}
|
|
||||||
uses: ./.github/actions/rerun-workflow
|
|
||||||
with:
|
|
||||||
PR_ID: ${{ github.event.issue.number }}
|
|
||||||
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
|
|
||||||
OWNER: ${{ github.repository_owner }}
|
|
||||||
REPO: ${{ github.event.repository.name }}
|
|
||||||
JOB_NAME: 'Pre Commit'
|
|
||||||
|
|
||||||
- name: Rerun Clone
|
|
||||||
if: ${{ contains(github.event.comment.body, 'clone') }}
|
|
||||||
uses: ./.github/actions/rerun-workflow
|
|
||||||
with:
|
|
||||||
PR_ID: ${{ github.event.issue.number }}
|
|
||||||
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
|
|
||||||
OWNER: ${{ github.repository_owner }}
|
|
||||||
REPO: ${{ github.event.repository.name }}
|
|
||||||
JOB_NAME: 'FD-Clone-Linux / code-clone'
|
|
||||||
|
|
||||||
- name: Rerun Build
|
|
||||||
if: ${{ contains(github.event.comment.body, 'build') }}
|
|
||||||
uses: ./.github/actions/rerun-workflow
|
|
||||||
with:
|
|
||||||
PR_ID: ${{ github.event.issue.number }}
|
|
||||||
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
|
|
||||||
OWNER: ${{ github.repository_owner }}
|
|
||||||
REPO: ${{ github.event.repository.name }}
|
|
||||||
JOB_NAME: 'FD-Build-Linux / fd-build'
|
|
||||||
|
|
||||||
- name: Rerun run_ce_cases
|
|
||||||
if: ${{ contains(github.event.comment.body, 'run_ce_cases') }}
|
|
||||||
uses: ./.github/actions/rerun-workflow
|
|
||||||
with:
|
|
||||||
PR_ID: ${{ github.event.issue.number }}
|
|
||||||
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
|
|
||||||
OWNER: ${{ github.repository_owner }}
|
|
||||||
REPO: ${{ github.event.repository.name }}
|
|
||||||
JOB_NAME: 'Extracted partial CE model tasks to run in CI. / run_ce_cases'
|
|
||||||
|
|
||||||
- name: Rerun accuracy_tests
|
|
||||||
if: ${{ contains(github.event.comment.body, 'accuracy_tests') }}
|
|
||||||
uses: ./.github/actions/rerun-workflow
|
|
||||||
with:
|
|
||||||
PR_ID: ${{ github.event.issue.number }}
|
|
||||||
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
|
|
||||||
OWNER: ${{ github.repository_owner }}
|
|
||||||
REPO: ${{ github.event.repository.name }}
|
|
||||||
JOB_NAME: 'Run Accuracy Tests / accuracy_tests'
|
|
||||||
|
|
||||||
- name: Rerun base_tests
|
|
||||||
if: ${{ contains(github.event.comment.body, 'base_tests') }}
|
|
||||||
uses: ./.github/actions/rerun-workflow
|
|
||||||
with:
|
|
||||||
PR_ID: ${{ github.event.issue.number }}
|
|
||||||
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
|
|
||||||
OWNER: ${{ github.repository_owner }}
|
|
||||||
REPO: ${{ github.event.repository.name }}
|
|
||||||
JOB_NAME: 'Run Base Tests / base_tests'
|
|
||||||
|
|
||||||
- name: Rerun run_tests_logprob
|
|
||||||
if: ${{ contains(github.event.comment.body, 'run_tests_logprob') }}
|
|
||||||
uses: ./.github/actions/rerun-workflow
|
|
||||||
with:
|
|
||||||
PR_ID: ${{ github.event.issue.number }}
|
|
||||||
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
|
|
||||||
OWNER: ${{ github.repository_owner }}
|
|
||||||
REPO: ${{ github.event.repository.name }}
|
|
||||||
JOB_NAME: 'Run FastDeploy LogProb Tests / run_tests_logprob'
|
|
||||||
|
|
||||||
- name: Rerun run_tests_with_coverage
|
|
||||||
if: ${{ contains(github.event.comment.body, 'run_tests_with_coverage') }}
|
|
||||||
uses: ./.github/actions/rerun-workflow
|
|
||||||
with:
|
|
||||||
PR_ID: ${{ github.event.issue.number }}
|
|
||||||
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
|
|
||||||
OWNER: ${{ github.repository_owner }}
|
|
||||||
REPO: ${{ github.event.repository.name }}
|
|
||||||
JOB_NAME: 'Run FastDeploy Unit Tests and Coverage / run_tests_with_coverage'
|
|
||||||
|
|
||||||
- name: Rerun diff_coverage_report
|
|
||||||
if: ${{ contains(github.event.comment.body, 'diff_coverage_report') }}
|
|
||||||
uses: ./.github/actions/rerun-workflow
|
|
||||||
with:
|
|
||||||
PR_ID: ${{ github.event.issue.number }}
|
|
||||||
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
|
|
||||||
OWNER: ${{ github.repository_owner }}
|
|
||||||
REPO: ${{ github.event.repository.name }}
|
|
||||||
JOB_NAME: 'Run FastDeploy Unit Tests and Coverage / diff_coverage_report'
|
|
||||||
|
|
||||||
- name: Rerun stable_tests
|
|
||||||
if: ${{ contains(github.event.comment.body, 'stable_tests') }}
|
|
||||||
uses: ./.github/actions/rerun-workflow
|
|
||||||
with:
|
|
||||||
PR_ID: ${{ github.event.issue.number }}
|
|
||||||
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
|
|
||||||
OWNER: ${{ github.repository_owner }}
|
|
||||||
REPO: ${{ github.event.repository.name }}
|
|
||||||
JOB_NAME: 'Run Stable Tests / stable_tests'
|
|
||||||
1
.gitmodules
vendored
1
.gitmodules
vendored
@@ -1,7 +1,6 @@
|
|||||||
[submodule "custom_ops/third_party/DeepGEMM"]
|
[submodule "custom_ops/third_party/DeepGEMM"]
|
||||||
path = custom_ops/third_party/DeepGEMM
|
path = custom_ops/third_party/DeepGEMM
|
||||||
url = https://github.com/deepseek-ai/DeepGEMM.git
|
url = https://github.com/deepseek-ai/DeepGEMM.git
|
||||||
ignore = all
|
|
||||||
[submodule "custom_ops/third_party/cutlass"]
|
[submodule "custom_ops/third_party/cutlass"]
|
||||||
path = custom_ops/third_party/cutlass
|
path = custom_ops/third_party/cutlass
|
||||||
url = https://github.com/NVIDIA/cutlass.git
|
url = https://github.com/NVIDIA/cutlass.git
|
||||||
|
|||||||
@@ -1,7 +1,3 @@
|
|||||||
exclude: |
|
|
||||||
(?x)^(
|
|
||||||
dockerfiles/.+
|
|
||||||
)$
|
|
||||||
default_install_hook_types:
|
default_install_hook_types:
|
||||||
- pre-commit
|
- pre-commit
|
||||||
- commit-msg
|
- commit-msg
|
||||||
@@ -31,15 +27,6 @@ repos:
|
|||||||
hooks:
|
hooks:
|
||||||
- id: ruff
|
- id: ruff
|
||||||
args: [--output-format, github, --fix, --line-length=120, --config, pyproject.toml]
|
args: [--output-format, github, --fix, --line-length=120, --config, pyproject.toml]
|
||||||
# For C++ files
|
|
||||||
- repo: local
|
|
||||||
hooks:
|
|
||||||
- id: clang-format
|
|
||||||
name: clang-format
|
|
||||||
description: Format files with ClangFormat.
|
|
||||||
entry: clang-format -i
|
|
||||||
language: system
|
|
||||||
files: \.(c|cc|cxx|cpp|cu|h|cuh|hpp|hxx|xpu|kps)$
|
|
||||||
# # 拼写检查
|
# # 拼写检查
|
||||||
# - repo: https://github.com/codespell-project/codespell
|
# - repo: https://github.com/codespell-project/codespell
|
||||||
# rev: v2.4.1
|
# rev: v2.4.1
|
||||||
|
|||||||
@@ -43,7 +43,7 @@ English | [简体中文](README_CN.md)
|
|||||||
- 🤝 **OpenAI API Server and vLLM Compatible**: One-command deployment with [vLLM](https://github.com/vllm-project/vllm/) interface compatibility.
|
- 🤝 **OpenAI API Server and vLLM Compatible**: One-command deployment with [vLLM](https://github.com/vllm-project/vllm/) interface compatibility.
|
||||||
- 🧮 **Comprehensive Quantization Format Support**: W8A16, W8A8, W4A16, W4A8, W2A16, FP8, and more.
|
- 🧮 **Comprehensive Quantization Format Support**: W8A16, W8A8, W4A16, W4A8, W2A16, FP8, and more.
|
||||||
- ⏩ **Advanced Acceleration Techniques**: Speculative decoding, Multi-Token Prediction (MTP) and Chunked Prefill.
|
- ⏩ **Advanced Acceleration Techniques**: Speculative decoding, Multi-Token Prediction (MTP) and Chunked Prefill.
|
||||||
- 🖥️ **Multi-Hardware Support**: NVIDIA GPU, Kunlunxin XPU, Hygon DCU, Ascend NPU, Iluvatar GPU, Enflame GCU, MetaX GPU, Intel Gaudi etc.
|
- 🖥️ **Multi-Hardware Support**: NVIDIA GPU, Kunlunxin XPU, Hygon DCU, Ascend NPU, Iluvatar GPU, Enflame GCU, MetaX GPU etc.
|
||||||
|
|
||||||
## Requirements
|
## Requirements
|
||||||
|
|
||||||
@@ -59,8 +59,7 @@ FastDeploy supports inference deployment on **NVIDIA GPUs**, **Kunlunxin XPUs**,
|
|||||||
- [Iluvatar GPU](./docs/get_started/installation/iluvatar_gpu.md)
|
- [Iluvatar GPU](./docs/get_started/installation/iluvatar_gpu.md)
|
||||||
- [Enflame GCU](./docs/get_started/installation/Enflame_gcu.md)
|
- [Enflame GCU](./docs/get_started/installation/Enflame_gcu.md)
|
||||||
- [Hygon DCU](./docs/get_started/installation/hygon_dcu.md)
|
- [Hygon DCU](./docs/get_started/installation/hygon_dcu.md)
|
||||||
- [MetaX GPU](./docs/get_started/installation/metax_gpu.md)
|
- [MetaX GPU](./docs/get_started/installation/metax_gpu.md.md)
|
||||||
- [Intel Gaudi](./docs/get_started/installation/intel_gaudi.md)
|
|
||||||
|
|
||||||
**Note:** We are actively working on expanding hardware support. Additional hardware platforms including Ascend NPU are currently under development and testing. Stay tuned for updates!
|
**Note:** We are actively working on expanding hardware support. Additional hardware platforms including Ascend NPU are currently under development and testing. Stay tuned for updates!
|
||||||
|
|
||||||
|
|||||||
@@ -41,7 +41,7 @@
|
|||||||
- 🤝 **OpenAI API服务与vLLM兼容**:单命令部署,兼容[vLLM](https://github.com/vllm-project/vllm/)接口
|
- 🤝 **OpenAI API服务与vLLM兼容**:单命令部署,兼容[vLLM](https://github.com/vllm-project/vllm/)接口
|
||||||
- 🧮 **全量化格式支持**:W8A16、W8A8、W4A16、W4A8、W2A16、FP8等
|
- 🧮 **全量化格式支持**:W8A16、W8A8、W4A16、W4A8、W2A16、FP8等
|
||||||
- ⏩ **高级加速技术**:推测解码、多令牌预测(MTP)及分块预填充
|
- ⏩ **高级加速技术**:推测解码、多令牌预测(MTP)及分块预填充
|
||||||
- 🖥️ **多硬件支持**:NVIDIA GPU、昆仑芯XPU、海光DCU、昇腾NPU、天数智芯GPU、燧原GCU、沐曦GPU、英特尔Gaudi等
|
- 🖥️ **多硬件支持**:NVIDIA GPU、昆仑芯XPU、海光DCU、昇腾NPU、天数智芯GPU、燧原GCU、沐曦GPU等
|
||||||
|
|
||||||
## 要求
|
## 要求
|
||||||
|
|
||||||
@@ -57,8 +57,7 @@ FastDeploy 支持在**英伟达(NVIDIA)GPU**、**昆仑芯(Kunlunxin)XPU
|
|||||||
- [天数 CoreX](./docs/zh/get_started/installation/iluvatar_gpu.md)
|
- [天数 CoreX](./docs/zh/get_started/installation/iluvatar_gpu.md)
|
||||||
- [燧原 S60](./docs/zh/get_started/installation/Enflame_gcu.md)
|
- [燧原 S60](./docs/zh/get_started/installation/Enflame_gcu.md)
|
||||||
- [海光 DCU](./docs/zh/get_started/installation/hygon_dcu.md)
|
- [海光 DCU](./docs/zh/get_started/installation/hygon_dcu.md)
|
||||||
- [沐曦 GPU](./docs/zh/get_started/installation/metax_gpu.md)
|
- [沐曦 GPU](./docs/zh/get_started/installation/metax_gpu.md.md)
|
||||||
- [英特尔 Gaudi](./docs/zh/get_started/installation/intel_gaudi.md)
|
|
||||||
|
|
||||||
**注意:** 我们正在积极拓展硬件支持范围。目前,包括昇腾(Ascend)NPU 等其他硬件平台正在开发测试中。敬请关注更新!
|
**注意:** 我们正在积极拓展硬件支持范围。目前,包括昇腾(Ascend)NPU 等其他硬件平台正在开发测试中。敬请关注更新!
|
||||||
|
|
||||||
|
|||||||
@@ -58,12 +58,10 @@ class RequestFuncOutput:
|
|||||||
"""Output for requesting LLMs via API"""
|
"""Output for requesting LLMs via API"""
|
||||||
|
|
||||||
no: int = 0
|
no: int = 0
|
||||||
request_id: str = ""
|
|
||||||
generated_text: str = ""
|
generated_text: str = ""
|
||||||
reasoning_content: str = ""
|
reasoning_content: str = ""
|
||||||
success: bool = False
|
success: bool = False
|
||||||
latency: float = 0.0
|
latency: float = 0.0
|
||||||
end_timestamp: float = 0.0 # 模型完全返回的时间戳(秒, perf_counter基准)
|
|
||||||
output_tokens: int = 0
|
output_tokens: int = 0
|
||||||
ttft: float = 0.0 # Time to first token
|
ttft: float = 0.0 # Time to first token
|
||||||
arrival_time: list = field(default_factory=list) # arrival_time
|
arrival_time: list = field(default_factory=list) # arrival_time
|
||||||
@@ -112,14 +110,12 @@ async def async_request_eb_openai_chat_completions(
|
|||||||
output = RequestFuncOutput()
|
output = RequestFuncOutput()
|
||||||
output.prompt_len = 0
|
output.prompt_len = 0
|
||||||
output.no = request_func_input.no
|
output.no = request_func_input.no
|
||||||
request_id = "None"
|
|
||||||
|
|
||||||
ttft = 0.0
|
ttft = 0.0
|
||||||
st = time.perf_counter()
|
st = time.perf_counter()
|
||||||
most_recent_timestamp = st
|
most_recent_timestamp = st
|
||||||
try:
|
try:
|
||||||
async with session.post(url=api_url, json=payload, headers=headers) as response:
|
async with session.post(url=api_url, json=payload, headers=headers) as response:
|
||||||
data = {}
|
|
||||||
if response.status == 200:
|
if response.status == 200:
|
||||||
async for chunk_bytes in response.content:
|
async for chunk_bytes in response.content:
|
||||||
chunk_bytes = chunk_bytes.strip()
|
chunk_bytes = chunk_bytes.strip()
|
||||||
@@ -128,13 +124,10 @@ async def async_request_eb_openai_chat_completions(
|
|||||||
|
|
||||||
chunk = chunk_bytes.decode("utf-8").removeprefix("data: ")
|
chunk = chunk_bytes.decode("utf-8").removeprefix("data: ")
|
||||||
if chunk != "[DONE]":
|
if chunk != "[DONE]":
|
||||||
#print("####chunk:", chunk, type(chunk))
|
# print("####chunk:", chunk, type(chunk))
|
||||||
timestamp = time.perf_counter()
|
timestamp = time.perf_counter()
|
||||||
data = json.loads(chunk)
|
data = json.loads(chunk)
|
||||||
|
|
||||||
if request_id == "None" and "id" in data:
|
|
||||||
request_id = data["id"]
|
|
||||||
|
|
||||||
if choices := data.get("choices"):
|
if choices := data.get("choices"):
|
||||||
content = choices[0]["delta"].get("content")
|
content = choices[0]["delta"].get("content")
|
||||||
reason_content = choices[0]["delta"].get("reasoning_content")
|
reason_content = choices[0]["delta"].get("reasoning_content")
|
||||||
@@ -143,12 +136,9 @@ async def async_request_eb_openai_chat_completions(
|
|||||||
ttft = timestamp - st
|
ttft = timestamp - st
|
||||||
output.ttft = ttft
|
output.ttft = ttft
|
||||||
# cached_tokens
|
# cached_tokens
|
||||||
if data["usage"] and data["usage"].get("prompt_tokens_details", {}):
|
output.prompt_len = (
|
||||||
output.prompt_len = (
|
data["usage"].get("prompt_tokens_details", {}).get("cached_tokens", 0)
|
||||||
data["usage"].get("prompt_tokens_details", {}).get("cached_tokens", 0)
|
)
|
||||||
)
|
|
||||||
else:
|
|
||||||
output.prompt_len = 0
|
|
||||||
|
|
||||||
# Decoding phase
|
# Decoding phase
|
||||||
else:
|
else:
|
||||||
@@ -160,13 +150,10 @@ async def async_request_eb_openai_chat_completions(
|
|||||||
elif usage := data.get("usage", {}):
|
elif usage := data.get("usage", {}):
|
||||||
output.output_tokens = usage.get("completion_tokens", 0)
|
output.output_tokens = usage.get("completion_tokens", 0)
|
||||||
output.prompt_tokens = usage.get("prompt_tokens", 0)
|
output.prompt_tokens = usage.get("prompt_tokens", 0)
|
||||||
|
|
||||||
|
|
||||||
most_recent_timestamp = timestamp
|
most_recent_timestamp = timestamp
|
||||||
|
|
||||||
# output.generated_text = generated_text
|
# output.generated_text = generated_text
|
||||||
# 在流式结束时,记录最后一个 chunk 收到的时间戳
|
|
||||||
output.end_timestamp = most_recent_timestamp
|
|
||||||
if output.generated_text.strip() == "":
|
if output.generated_text.strip() == "":
|
||||||
output.success = False
|
output.success = False
|
||||||
output.error = "No generated text found!"
|
output.error = "No generated text found!"
|
||||||
@@ -188,8 +175,6 @@ async def async_request_eb_openai_chat_completions(
|
|||||||
exc_info = sys.exc_info()
|
exc_info = sys.exc_info()
|
||||||
output.error = "".join(traceback.format_exception(*exc_info))
|
output.error = "".join(traceback.format_exception(*exc_info))
|
||||||
|
|
||||||
output.request_id = request_id
|
|
||||||
|
|
||||||
# 保存失败请求结果
|
# 保存失败请求结果
|
||||||
if not output.success:
|
if not output.success:
|
||||||
with open("error_output.txt", "a") as f:
|
with open("error_output.txt", "a") as f:
|
||||||
|
|||||||
@@ -98,7 +98,7 @@ def main(args):
|
|||||||
raise ValueError("--max_concurrency should be same length as --s_itl_base_model")
|
raise ValueError("--max_concurrency should be same length as --s_itl_base_model")
|
||||||
|
|
||||||
for max_concurrency, s_itl in zip(args.max_concurrency, args.s_itl_base_model):
|
for max_concurrency, s_itl in zip(args.max_concurrency, args.s_itl_base_model):
|
||||||
# Warmup
|
# Wramup
|
||||||
print("Starting warmup...")
|
print("Starting warmup...")
|
||||||
with open(os.devnull, "w") as f:
|
with open(os.devnull, "w") as f:
|
||||||
with contextlib.redirect_stdout(f):
|
with contextlib.redirect_stdout(f):
|
||||||
|
|||||||
@@ -150,7 +150,7 @@ async def get_request(
|
|||||||
|
|
||||||
|
|
||||||
def calculate_metrics(
|
def calculate_metrics(
|
||||||
# input_requests: list[SampleRequest],
|
input_requests: list[SampleRequest],
|
||||||
outputs: list[RequestFuncOutput],
|
outputs: list[RequestFuncOutput],
|
||||||
dur_s: float,
|
dur_s: float,
|
||||||
selected_percentiles: list[float],
|
selected_percentiles: list[float],
|
||||||
@@ -177,7 +177,7 @@ def calculate_metrics(
|
|||||||
output_len = outputs[i].output_tokens
|
output_len = outputs[i].output_tokens
|
||||||
|
|
||||||
if not output_len:
|
if not output_len:
|
||||||
print("no output_len", outputs[i])
|
print("no output_len")
|
||||||
# We use the tokenizer to count the number of output tokens
|
# We use the tokenizer to count the number of output tokens
|
||||||
# for some serving backends instead of looking at
|
# for some serving backends instead of looking at
|
||||||
# len(outputs[i].itl) since multiple output tokens may be
|
# len(outputs[i].itl) since multiple output tokens may be
|
||||||
@@ -395,7 +395,6 @@ async def benchmark(
|
|||||||
print(f"Traffic request rate: {request_rate}")
|
print(f"Traffic request rate: {request_rate}")
|
||||||
print(f"Burstiness factor: {burstiness} ({distribution})")
|
print(f"Burstiness factor: {burstiness} ({distribution})")
|
||||||
print(f"Maximum request concurrency: {max_concurrency}")
|
print(f"Maximum request concurrency: {max_concurrency}")
|
||||||
print(f"Drop ratio: {args.drop_ratio}")
|
|
||||||
|
|
||||||
pbar = None if disable_tqdm else tqdm(total=len(input_requests))
|
pbar = None if disable_tqdm else tqdm(total=len(input_requests))
|
||||||
|
|
||||||
@@ -444,8 +443,6 @@ async def benchmark(
|
|||||||
tasks.append(asyncio.create_task(limited_request_func(request_func_input=request_func_input, pbar=pbar)))
|
tasks.append(asyncio.create_task(limited_request_func(request_func_input=request_func_input, pbar=pbar)))
|
||||||
outputs: list[RequestFuncOutput] = await asyncio.gather(*tasks)
|
outputs: list[RequestFuncOutput] = await asyncio.gather(*tasks)
|
||||||
|
|
||||||
outputs.sort(key=lambda x: x.end_timestamp)
|
|
||||||
|
|
||||||
if profile:
|
if profile:
|
||||||
print("Stopping profiler...")
|
print("Stopping profiler...")
|
||||||
profile_input = RequestFuncInput(
|
profile_input = RequestFuncInput(
|
||||||
@@ -463,35 +460,12 @@ async def benchmark(
|
|||||||
if pbar is not None:
|
if pbar is not None:
|
||||||
pbar.close()
|
pbar.close()
|
||||||
|
|
||||||
benchmark_outputs = outputs
|
benchmark_duration = time.perf_counter() - benchmark_start_time
|
||||||
drop_ratio = args.drop_ratio
|
print("benchmark_duration:", benchmark_duration)
|
||||||
if 0.0 < drop_ratio < 1:
|
|
||||||
# 按drop_ratio头尾各舍弃一半请求,不计入benchmark统计
|
|
||||||
n = len(outputs)
|
|
||||||
drop_count = int(n * drop_ratio)
|
|
||||||
half = drop_count // 2
|
|
||||||
if half > 0:
|
|
||||||
benchmark_outputs = outputs[half : n - half]
|
|
||||||
|
|
||||||
# 先过滤掉 end_timestamp == 0.0 的请求(失败请求)
|
|
||||||
benchmark_outputs = [o for o in benchmark_outputs if o.end_timestamp != 0.0]
|
|
||||||
|
|
||||||
# 根据收到最后一个chunk的时间戳计算总时长
|
|
||||||
if len(benchmark_outputs) >= 2:
|
|
||||||
benchmark_duration = benchmark_outputs[-1].end_timestamp - benchmark_outputs[0].end_timestamp
|
|
||||||
else:
|
|
||||||
benchmark_duration = 0.0
|
|
||||||
|
|
||||||
print(f"丢弃前数量: {n}")
|
|
||||||
print(f"丢弃后数量: {len(benchmark_outputs)}")
|
|
||||||
print(f"benchmark_duration: {benchmark_duration} 秒")
|
|
||||||
else:
|
|
||||||
benchmark_duration = time.perf_counter() - benchmark_start_time
|
|
||||||
print(f"benchmark_duration: {benchmark_duration} 秒")
|
|
||||||
|
|
||||||
metrics, actual_output_lens = calculate_metrics(
|
metrics, actual_output_lens = calculate_metrics(
|
||||||
# input_requests=input_requests,
|
input_requests=input_requests,
|
||||||
outputs=benchmark_outputs,
|
outputs=outputs,
|
||||||
dur_s=benchmark_duration,
|
dur_s=benchmark_duration,
|
||||||
# tokenizer=tokenizer,
|
# tokenizer=tokenizer,
|
||||||
selected_percentiles=selected_percentiles,
|
selected_percentiles=selected_percentiles,
|
||||||
@@ -520,7 +494,7 @@ async def benchmark(
|
|||||||
"total_token_throughput": metrics.total_token_throughput,
|
"total_token_throughput": metrics.total_token_throughput,
|
||||||
"input_lens": [output.prompt_len for output in outputs],
|
"input_lens": [output.prompt_len for output in outputs],
|
||||||
"infer_input_lens": [output.prompt_tokens for output in outputs],
|
"infer_input_lens": [output.prompt_tokens for output in outputs],
|
||||||
"output_lens": [output.output_tokens for output in outputs],
|
"output_lens": actual_output_lens,
|
||||||
"ttfts": [output.ttft for output in outputs],
|
"ttfts": [output.ttft for output in outputs],
|
||||||
"itls": [output.itl for output in outputs],
|
"itls": [output.itl for output in outputs],
|
||||||
"input_texts": [input.prompt for input in input_requests],
|
"input_texts": [input.prompt for input in input_requests],
|
||||||
@@ -635,7 +609,7 @@ def benchmark_metrics(
|
|||||||
goodput_config_dict = check_goodput_args(args)
|
goodput_config_dict = check_goodput_args(args)
|
||||||
|
|
||||||
metrics, actual_output_lens = calculate_metrics(
|
metrics, actual_output_lens = calculate_metrics(
|
||||||
# input_requests=input_requests,
|
input_requests=input_requests,
|
||||||
outputs=outputs,
|
outputs=outputs,
|
||||||
dur_s=benchmark_duration,
|
dur_s=benchmark_duration,
|
||||||
selected_percentiles=selected_percentiles,
|
selected_percentiles=selected_percentiles,
|
||||||
@@ -991,7 +965,7 @@ if __name__ == "__main__":
|
|||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--backend",
|
"--backend",
|
||||||
type=str,
|
type=str,
|
||||||
default="openai-chat",
|
default="vllm",
|
||||||
choices=list(ASYNC_REQUEST_FUNCS.keys()),
|
choices=list(ASYNC_REQUEST_FUNCS.keys()),
|
||||||
)
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
@@ -1107,12 +1081,6 @@ if __name__ == "__main__":
|
|||||||
action="store_true",
|
action="store_true",
|
||||||
help="shuffle dataset",
|
help="shuffle dataset",
|
||||||
)
|
)
|
||||||
parser.add_argument(
|
|
||||||
"--drop-ratio",
|
|
||||||
type=float,
|
|
||||||
default=0.0,
|
|
||||||
help="Drop ratio of the outputs. [0, 1)",
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--trust-remote-code",
|
"--trust-remote-code",
|
||||||
action="store_true",
|
action="store_true",
|
||||||
|
|||||||
@@ -1,5 +0,0 @@
|
|||||||
max_model_len: 32768
|
|
||||||
max_num_seqs: 128
|
|
||||||
tensor_parallel_size: 4
|
|
||||||
use_cudagraph: True
|
|
||||||
load_choices: "default_v1"
|
|
||||||
@@ -1,6 +0,0 @@
|
|||||||
max_model_len: 32768
|
|
||||||
max_num_seqs: 128
|
|
||||||
tensor_parallel_size: 4
|
|
||||||
use_cudagraph: True
|
|
||||||
load_choices: "default_v1"
|
|
||||||
quantization: wfp8afp8
|
|
||||||
@@ -1,9 +0,0 @@
|
|||||||
quantization: wint4
|
|
||||||
load_choices: "default_v1"
|
|
||||||
graph_optimization_config:
|
|
||||||
use_cudagraph: True
|
|
||||||
use_unique_memory_pool: True
|
|
||||||
enable_prefix_caching: False
|
|
||||||
max_num_seqs: 256
|
|
||||||
max_model_len: 32768
|
|
||||||
tensor_parallel_size: 8
|
|
||||||
@@ -6,4 +6,3 @@ tensor_parallel_size: 8
|
|||||||
max_num_batched_tokens: 4096
|
max_num_batched_tokens: 4096
|
||||||
max_num_partial_prefills: 3
|
max_num_partial_prefills: 3
|
||||||
max_long_partial_prefills: 3
|
max_long_partial_prefills: 3
|
||||||
quantization: wint4
|
|
||||||
|
|||||||
@@ -1,6 +0,0 @@
|
|||||||
tensor_parallel_size: 1
|
|
||||||
max_model_len: 131072
|
|
||||||
max_num_seqs: 32
|
|
||||||
quantization: wint4
|
|
||||||
max_num_batched_tokens: 8192
|
|
||||||
plas_attention_config: '{"plas_encoder_top_k_left": 50, "plas_encoder_top_k_right": 60, "plas_decoder_top_k_left": 100, "plas_decoder_top_k_right": 120}'
|
|
||||||
@@ -6,4 +6,3 @@ tensor_parallel_size: 8
|
|||||||
max_num_batched_tokens: 4096
|
max_num_batched_tokens: 4096
|
||||||
max_num_partial_prefills: 3
|
max_num_partial_prefills: 3
|
||||||
max_long_partial_prefills: 3
|
max_long_partial_prefills: 3
|
||||||
quantization: wint8
|
|
||||||
|
|||||||
@@ -1,5 +0,0 @@
|
|||||||
max_model_len: 32768
|
|
||||||
max_num_seqs: 256
|
|
||||||
kv_cache_ratio: 0.75
|
|
||||||
tensor_parallel_size: 4
|
|
||||||
gpu_memory_utilization: 0.9
|
|
||||||
@@ -1,6 +1,6 @@
|
|||||||
max_model_len: 32768
|
max_model_len: 32768
|
||||||
max_num_seqs: 96
|
max_num_seqs: 96
|
||||||
gpu_memory_utilization: 0.85
|
gpu_memory_utilization: 0.9
|
||||||
kv_cache_ratio: 0.71
|
kv_cache_ratio: 0.71
|
||||||
tensor_parallel_size: 4
|
tensor_parallel_size: 4
|
||||||
quantization: wint4
|
quantization: wint4
|
||||||
|
|||||||
@@ -1,6 +1,6 @@
|
|||||||
max_model_len: 32768
|
max_model_len: 32768
|
||||||
max_num_seqs: 96
|
max_num_seqs: 96
|
||||||
gpu_memory_utilization: 0.85
|
gpu_memory_utilization: 0.9
|
||||||
kv_cache_ratio: 0.71
|
kv_cache_ratio: 0.71
|
||||||
tensor_parallel_size: 4
|
tensor_parallel_size: 4
|
||||||
quantization: wint4
|
quantization: wint4
|
||||||
|
|||||||
@@ -13,4 +13,3 @@ pd_comm_port: "2334"
|
|||||||
max_num_batched_tokens: 384
|
max_num_batched_tokens: 384
|
||||||
max_num_partial_prefills: 3
|
max_num_partial_prefills: 3
|
||||||
max_long_partial_prefills: 3
|
max_long_partial_prefills: 3
|
||||||
quantization: wint4
|
|
||||||
|
|||||||
@@ -10,4 +10,3 @@ engine_worker_queue_port: 6677
|
|||||||
cache_transfer_protocol: "rdma,ipc"
|
cache_transfer_protocol: "rdma,ipc"
|
||||||
rdma_comm_ports: "7675,7676,7677,7678"
|
rdma_comm_ports: "7675,7676,7677,7678"
|
||||||
pd_comm_port: "2333"
|
pd_comm_port: "2333"
|
||||||
quantization: wint4
|
|
||||||
|
|||||||
@@ -1,6 +1,6 @@
|
|||||||
max_model_len: 32768
|
max_model_len: 32768
|
||||||
max_num_seqs: 96
|
max_num_seqs: 96
|
||||||
gpu_memory_utilization: 0.85
|
gpu_memory_utilization: 0.9
|
||||||
kv_cache_ratio: 0.71
|
kv_cache_ratio: 0.71
|
||||||
tensor_parallel_size: 8
|
tensor_parallel_size: 8
|
||||||
quantization: wint8
|
quantization: wint8
|
||||||
|
|||||||
@@ -1,11 +0,0 @@
|
|||||||
enable_mm: True
|
|
||||||
max_model_len: 131072
|
|
||||||
max_num_seqs: 56
|
|
||||||
gpu_memory_utilization: 0.8
|
|
||||||
kv_cache_ratio: 0.8
|
|
||||||
tensor_parallel_size: 8
|
|
||||||
quantization: wint4
|
|
||||||
limit_mm_per_prompt: '{"image": 100, "video": 100}'
|
|
||||||
enable_chunked_prefill: True
|
|
||||||
max_num_batched_tokens: 384
|
|
||||||
reasoning_parser: ernie-45-vl
|
|
||||||
@@ -1,7 +1,7 @@
|
|||||||
enable_mm: True
|
enable_mm: True
|
||||||
max_model_len: 32768
|
max_model_len: 32768
|
||||||
max_num_seqs: 36
|
max_num_seqs: 36
|
||||||
gpu_memory_utilization: 0.9
|
gpu_memory_utilization: 0.95
|
||||||
kv_cache_ratio: 0.8
|
kv_cache_ratio: 0.8
|
||||||
tensor_parallel_size: 8
|
tensor_parallel_size: 8
|
||||||
quantization: wint8
|
quantization: wint8
|
||||||
|
|||||||
@@ -1,7 +1,7 @@
|
|||||||
enable_mm: True
|
enable_mm: True
|
||||||
max_model_len: 32768
|
max_model_len: 32768
|
||||||
max_num_seqs: 36
|
max_num_seqs: 36
|
||||||
gpu_memory_utilization: 0.85
|
gpu_memory_utilization: 0.8
|
||||||
kv_cache_ratio: 0.8
|
kv_cache_ratio: 0.8
|
||||||
tensor_parallel_size: 8
|
tensor_parallel_size: 8
|
||||||
quantization: wint8
|
quantization: wint8
|
||||||
|
|||||||
@@ -1,9 +0,0 @@
|
|||||||
enable_mm: True
|
|
||||||
max_model_len: 32768
|
|
||||||
max_num_seqs: 128
|
|
||||||
gpu_memory_utilization: 0.9
|
|
||||||
kv_cache_ratio: 0.71
|
|
||||||
tensor_parallel_size: 1
|
|
||||||
enable_chunked_prefill: True
|
|
||||||
max_num_batched_tokens: 384
|
|
||||||
reasoning_parser: ernie-45-vl
|
|
||||||
@@ -1,10 +0,0 @@
|
|||||||
enable_mm: True
|
|
||||||
max_model_len: 32768
|
|
||||||
max_num_seqs: 128
|
|
||||||
gpu_memory_utilization: 0.9
|
|
||||||
kv_cache_ratio: 0.71
|
|
||||||
tensor_parallel_size: 1
|
|
||||||
enable_chunked_prefill: True
|
|
||||||
max_num_batched_tokens: 384
|
|
||||||
quantization: wint4
|
|
||||||
reasoning_parser: ernie-45-vl
|
|
||||||
@@ -1,10 +0,0 @@
|
|||||||
enable_mm: True
|
|
||||||
max_model_len: 32768
|
|
||||||
max_num_seqs: 128
|
|
||||||
gpu_memory_utilization: 0.9
|
|
||||||
kv_cache_ratio: 0.71
|
|
||||||
tensor_parallel_size: 1
|
|
||||||
enable_chunked_prefill: True
|
|
||||||
max_num_batched_tokens: 384
|
|
||||||
quantization: wint8
|
|
||||||
reasoning_parser: ernie-45-vl
|
|
||||||
@@ -1 +0,0 @@
|
|||||||
max_tokens: 131071
|
|
||||||
@@ -1 +0,0 @@
|
|||||||
max_tokens: 12288
|
|
||||||
@@ -1,11 +0,0 @@
|
|||||||
top_p: 0.8
|
|
||||||
temperature: 0.8
|
|
||||||
max_tokens: 12288
|
|
||||||
repetition_penalty: 1.0
|
|
||||||
frequency_penalty: 0
|
|
||||||
presence_penalty: 0
|
|
||||||
metadata:
|
|
||||||
enable_thinking: false
|
|
||||||
min_tokens: 1
|
|
||||||
chat_template_kwargs:
|
|
||||||
enable_thinking: false
|
|
||||||
@@ -1,10 +1,11 @@
|
|||||||
temperature: 0.8
|
top_p: 1.0
|
||||||
top_p: 0.8
|
temperature: 1.0
|
||||||
presence_penalty: 0
|
|
||||||
repetition_penalty: 1.0
|
|
||||||
frequency_penalty: 0
|
|
||||||
max_tokens: 12288
|
|
||||||
metadata:
|
metadata:
|
||||||
min_tokens: 1
|
min_tokens: 1
|
||||||
|
max_tokens: 30721
|
||||||
|
repetition_penalty: 1.0
|
||||||
|
frequency_penalty: 0
|
||||||
|
presence_penalty: 0
|
||||||
|
skip_special_tokens: false
|
||||||
chat_template_kwargs:
|
chat_template_kwargs:
|
||||||
enable_thinking: false
|
enable_thinking: true
|
||||||
@@ -1,8 +0,0 @@
|
|||||||
top_p: 0.95
|
|
||||||
temperature: 0.6
|
|
||||||
metadata:
|
|
||||||
min_tokens: 1
|
|
||||||
max_tokens: 131071
|
|
||||||
repetition_penalty: 1.0
|
|
||||||
frequency_penalty: 0
|
|
||||||
presence_penalty: 0
|
|
||||||
@@ -2,7 +2,7 @@ top_p: 0.95
|
|||||||
temperature: 0.6
|
temperature: 0.6
|
||||||
metadata:
|
metadata:
|
||||||
min_tokens: 1
|
min_tokens: 1
|
||||||
max_tokens: 12288
|
max_tokens: 65535
|
||||||
repetition_penalty: 1.0
|
repetition_penalty: 1.0
|
||||||
frequency_penalty: 0
|
frequency_penalty: 0
|
||||||
presence_penalty: 0
|
presence_penalty: 0
|
||||||
@@ -1,10 +1,10 @@
|
|||||||
reasoning-parser: ernie-x1
|
reasoning-parser: ernie_x1
|
||||||
tool_call_parser: ernie-x1
|
tool_call_parser: ernie_x1
|
||||||
tensor_parallel_size: 4
|
tensor_parallel_size: 4
|
||||||
max_model_len: 65536
|
max_model_len: 65536
|
||||||
max_num_seqs: 128
|
max_num_seqs: 128
|
||||||
enable_prefix_caching: True
|
enable_prefix_caching: True
|
||||||
enable_chunked_prefill: True
|
enable_chunked_prefill: True
|
||||||
gpu_memory_utilization: 0.85
|
gpu_memory_utilization: 0.85
|
||||||
graph_optimization_config:
|
use_cudagraph: True
|
||||||
use_cudagraph: True
|
enable_custom_all_reduce: True
|
||||||
|
|||||||
@@ -1,7 +0,0 @@
|
|||||||
tensor_parallel_size: 1
|
|
||||||
max_model_len: 131072
|
|
||||||
max_num_seqs: 32
|
|
||||||
reasoning_parser: ernie-x1
|
|
||||||
tool_call_parser: ernie-x1
|
|
||||||
load_choices: "default_v1"
|
|
||||||
quantization: wint8
|
|
||||||
14
build.sh
14
build.sh
@@ -128,12 +128,6 @@ function copy_ops(){
|
|||||||
echo -e "MACA ops have been copy to fastdeploy"
|
echo -e "MACA ops have been copy to fastdeploy"
|
||||||
return
|
return
|
||||||
fi
|
fi
|
||||||
is_intel_hpu=`$python -c "import paddle; print(paddle.is_compiled_with_custom_device('intel_hpu'))"`
|
|
||||||
if [ "$is_intel_hpu" = "True" ]; then
|
|
||||||
DEVICE_TYPE="intel-hpu"
|
|
||||||
echo -e "intel_hpu ops have been copy to fastdeploy"
|
|
||||||
return
|
|
||||||
fi
|
|
||||||
|
|
||||||
DEVICE_TYPE="cpu"
|
DEVICE_TYPE="cpu"
|
||||||
cd ../../../../
|
cd ../../../../
|
||||||
@@ -149,9 +143,9 @@ function build_and_install_ops() {
|
|||||||
TMP_DIR_REAL_PATH=`readlink -f ${OPS_TMP_DIR}`
|
TMP_DIR_REAL_PATH=`readlink -f ${OPS_TMP_DIR}`
|
||||||
is_xpu=`$python -c "import paddle; print(paddle.is_compiled_with_xpu())"`
|
is_xpu=`$python -c "import paddle; print(paddle.is_compiled_with_xpu())"`
|
||||||
if [ "$is_xpu" = "True" ]; then
|
if [ "$is_xpu" = "True" ]; then
|
||||||
cd xpu_ops
|
cd xpu_ops/src
|
||||||
bash build.sh ${TMP_DIR_REAL_PATH}
|
bash build.sh ${TMP_DIR_REAL_PATH}
|
||||||
cd ..
|
cd ../..
|
||||||
elif [ "$FD_CPU_USE_BF16" == "true" ]; then
|
elif [ "$FD_CPU_USE_BF16" == "true" ]; then
|
||||||
if [ "$FD_BUILDING_ARCS" == "" ]; then
|
if [ "$FD_BUILDING_ARCS" == "" ]; then
|
||||||
FD_CPU_USE_BF16=True ${python} setup_ops.py install --install-lib ${OPS_TMP_DIR}
|
FD_CPU_USE_BF16=True ${python} setup_ops.py install --install-lib ${OPS_TMP_DIR}
|
||||||
@@ -165,9 +159,7 @@ function build_and_install_ops() {
|
|||||||
else
|
else
|
||||||
FD_BUILDING_ARCS=${FD_BUILDING_ARCS} ${python} setup_ops.py install --install-lib ${OPS_TMP_DIR}
|
FD_BUILDING_ARCS=${FD_BUILDING_ARCS} ${python} setup_ops.py install --install-lib ${OPS_TMP_DIR}
|
||||||
fi
|
fi
|
||||||
if [ -d "${OPS_TMP_DIR}" ]; then
|
find ${OPS_TMP_DIR} -type f -name "*.o" -exec rm -f {} \;
|
||||||
find ${OPS_TMP_DIR} -type f -name "*.o" -exec rm -f {} \;
|
|
||||||
fi
|
|
||||||
else
|
else
|
||||||
echo "Error: Invalid parameter '$FD_CPU_USE_BF16'. Please use true or false."
|
echo "Error: Invalid parameter '$FD_CPU_USE_BF16'. Please use true or false."
|
||||||
exit 1
|
exit 1
|
||||||
|
|||||||
@@ -19,28 +19,28 @@ std::vector<paddle::Tensor> InvokeAvxWeightOnly(const paddle::Tensor &x,
|
|||||||
const paddle::Tensor &w_bias,
|
const paddle::Tensor &w_bias,
|
||||||
const std::string &alog,
|
const std::string &alog,
|
||||||
bool trans) {
|
bool trans) {
|
||||||
auto out_shape = x.shape();
|
auto out_shape = x.shape();
|
||||||
out_shape[out_shape.size() - 1] = weight.shape()[1];
|
out_shape[out_shape.size() - 1] = weight.shape()[1];
|
||||||
auto out = paddle::empty(out_shape, x.dtype(), paddle::CPUPlace());
|
auto out = paddle::empty(out_shape, x.dtype(), paddle::CPUPlace());
|
||||||
return {out};
|
return {out};
|
||||||
}
|
}
|
||||||
|
|
||||||
std::vector<std::vector<int64_t>> AvxWeightOnlyInferShape(
|
std::vector<std::vector<int64_t>> AvxWeightOnlyInferShape(
|
||||||
std::vector<int64_t> x_shape,
|
std::vector<int64_t> x_shape,
|
||||||
std::vector<int64_t> weigh_shape,
|
std::vector<int64_t> weigh_shape,
|
||||||
std::vector<int64_t> weigh_bias_shape) {
|
std::vector<int64_t> weigh_bias_shape) {
|
||||||
int m = 1;
|
int m = 1;
|
||||||
for (int i = 0; i < x_shape.size() - 1; i++) {
|
for (int i = 0; i < x_shape.size() - 1; i++) {
|
||||||
m = m * x_shape[i];
|
m = m * x_shape[i];
|
||||||
}
|
}
|
||||||
return {std::vector<int64_t>{m, weigh_shape[1]}};
|
return {std::vector<int64_t>{m, weigh_shape[1]}};
|
||||||
}
|
}
|
||||||
|
|
||||||
std::vector<paddle::DataType> AvxWeightOnlyInferDtype(
|
std::vector<paddle::DataType> AvxWeightOnlyInferDtype(
|
||||||
paddle::DataType x_dtype,
|
paddle::DataType x_dtype,
|
||||||
paddle::DataType weight_dtype,
|
paddle::DataType weight_dtype,
|
||||||
paddle::DataType weight_bias_dtype) {
|
paddle::DataType weight_bias_dtype) {
|
||||||
return {x_dtype};
|
return {x_dtype};
|
||||||
}
|
}
|
||||||
|
|
||||||
PD_BUILD_STATIC_OP(avx_weight_only)
|
PD_BUILD_STATIC_OP(avx_weight_only)
|
||||||
|
|||||||
@@ -20,13 +20,13 @@ void remove_padding(int64_t *output_data,
|
|||||||
const int *cum_offsets,
|
const int *cum_offsets,
|
||||||
const int sequence_length,
|
const int sequence_length,
|
||||||
const int bsz) {
|
const int bsz) {
|
||||||
for (int bi = 0; bi < bsz; ++bi) {
|
for (int bi = 0; bi < bsz; ++bi) {
|
||||||
for (int i = 0; i < seq_lens[bi]; ++i) {
|
for (int i = 0; i < seq_lens[bi]; ++i) {
|
||||||
const int tgt_seq_id = bi * sequence_length - cum_offsets[bi] + i;
|
const int tgt_seq_id = bi * sequence_length - cum_offsets[bi] + i;
|
||||||
const int src_seq_id = bi * sequence_length + i;
|
const int src_seq_id = bi * sequence_length + i;
|
||||||
output_data[tgt_seq_id] = input_data[src_seq_id];
|
output_data[tgt_seq_id] = input_data[src_seq_id];
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
void get_padding_offset_kernel(int *padding_offset,
|
void get_padding_offset_kernel(int *padding_offset,
|
||||||
@@ -37,53 +37,56 @@ void get_padding_offset_kernel(int *padding_offset,
|
|||||||
const int *seq_lens,
|
const int *seq_lens,
|
||||||
const int max_seq_len,
|
const int max_seq_len,
|
||||||
const int bsz) {
|
const int bsz) {
|
||||||
for (int bi = 0; bi < bsz; ++bi) {
|
for (int bi = 0; bi < bsz; ++bi) {
|
||||||
int cum_offset = bi == 0 ? 0 : cum_offsets[bi - 1];
|
int cum_offset = bi == 0 ? 0 : cum_offsets[bi - 1];
|
||||||
auto seq_len_now = seq_lens[bi];
|
auto seq_len_now = seq_lens[bi];
|
||||||
for (int i = 0; i < seq_len_now; ++i) {
|
for (int i = 0; i < seq_len_now; ++i) {
|
||||||
padding_offset[bi * max_seq_len - cum_offset + i] = cum_offset;
|
padding_offset[bi * max_seq_len - cum_offset + i] = cum_offset;
|
||||||
|
}
|
||||||
|
cum_offsets_out[bi] = cum_offset;
|
||||||
|
int cum_seq_len = (bi + 1) * max_seq_len - cum_offsets[bi];
|
||||||
|
cu_seqlens_q[bi + 1] = cum_seq_len;
|
||||||
|
cu_seqlens_k[bi + 1] = cum_seq_len;
|
||||||
}
|
}
|
||||||
cum_offsets_out[bi] = cum_offset;
|
|
||||||
int cum_seq_len = (bi + 1) * max_seq_len - cum_offsets[bi];
|
|
||||||
cu_seqlens_q[bi + 1] = cum_seq_len;
|
|
||||||
cu_seqlens_k[bi + 1] = cum_seq_len;
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
std::vector<paddle::Tensor> GetPaddingOffset(const paddle::Tensor &input_ids,
|
std::vector<paddle::Tensor> GetPaddingOffset(const paddle::Tensor &input_ids,
|
||||||
const paddle::Tensor &cum_offsets,
|
const paddle::Tensor &cum_offsets,
|
||||||
const paddle::Tensor &token_num,
|
const paddle::Tensor &token_num,
|
||||||
const paddle::Tensor &seq_len) {
|
const paddle::Tensor &seq_len) {
|
||||||
std::vector<int64_t> input_ids_shape = input_ids.shape();
|
std::vector<int64_t> input_ids_shape = input_ids.shape();
|
||||||
const int bsz = seq_len.shape()[0];
|
const int bsz = seq_len.shape()[0];
|
||||||
const int seq_length = input_ids_shape[1];
|
const int seq_length = input_ids_shape[1];
|
||||||
auto cum_offsets_out = cum_offsets.copy_to(paddle::CPUPlace(), false);
|
auto cum_offsets_out = cum_offsets.copy_to(paddle::CPUPlace(), false);
|
||||||
auto cpu_token_num = token_num.copy_to(paddle::CPUPlace(), false);
|
auto cpu_token_num = token_num.copy_to(paddle::CPUPlace(), false);
|
||||||
|
|
||||||
const int token_num_data = cpu_token_num.data<int64_t>()[0];
|
const int token_num_data = cpu_token_num.data<int64_t>()[0];
|
||||||
auto x_remove_padding = paddle::empty(
|
auto x_remove_padding = paddle::empty(
|
||||||
{token_num_data}, paddle::DataType::INT64, input_ids.place());
|
{token_num_data}, paddle::DataType::INT64, input_ids.place());
|
||||||
auto padding_offset = paddle::empty(
|
auto padding_offset = paddle::empty(
|
||||||
{token_num_data}, paddle::DataType::INT32, input_ids.place());
|
{token_num_data}, paddle::DataType::INT32, input_ids.place());
|
||||||
auto cu_seqlens_q =
|
auto cu_seqlens_q =
|
||||||
paddle::full({bsz + 1}, 0, paddle::DataType::INT32, input_ids.place());
|
paddle::full({bsz + 1}, 0, paddle::DataType::INT32, input_ids.place());
|
||||||
auto cu_seqlens_k =
|
auto cu_seqlens_k =
|
||||||
paddle::full({bsz + 1}, 0, paddle::DataType::INT32, input_ids.place());
|
paddle::full({bsz + 1}, 0, paddle::DataType::INT32, input_ids.place());
|
||||||
get_padding_offset_kernel(padding_offset.data<int>(),
|
get_padding_offset_kernel(padding_offset.data<int>(),
|
||||||
cum_offsets_out.data<int>(),
|
cum_offsets_out.data<int>(),
|
||||||
cu_seqlens_q.data<int>(),
|
cu_seqlens_q.data<int>(),
|
||||||
cu_seqlens_k.data<int>(),
|
cu_seqlens_k.data<int>(),
|
||||||
cum_offsets.data<int>(),
|
cum_offsets.data<int>(),
|
||||||
seq_len.data<int>(),
|
seq_len.data<int>(),
|
||||||
seq_length,
|
seq_length,
|
||||||
bsz);
|
bsz);
|
||||||
remove_padding(x_remove_padding.data<int64_t>(),
|
remove_padding(x_remove_padding.data<int64_t>(),
|
||||||
input_ids.data<int64_t>(),
|
input_ids.data<int64_t>(),
|
||||||
seq_len.data<int>(),
|
seq_len.data<int>(),
|
||||||
cum_offsets_out.data<int>(),
|
cum_offsets_out.data<int>(),
|
||||||
seq_length,
|
seq_length,
|
||||||
bsz);
|
bsz);
|
||||||
return {x_remove_padding, padding_offset, cu_seqlens_q, cu_seqlens_k};
|
return {x_remove_padding,
|
||||||
|
padding_offset,
|
||||||
|
cu_seqlens_q,
|
||||||
|
cu_seqlens_k};
|
||||||
}
|
}
|
||||||
|
|
||||||
std::vector<std::vector<int64_t>> GetPaddingOffsetInferShape(
|
std::vector<std::vector<int64_t>> GetPaddingOffsetInferShape(
|
||||||
@@ -91,9 +94,9 @@ std::vector<std::vector<int64_t>> GetPaddingOffsetInferShape(
|
|||||||
const std::vector<int64_t> &cum_offsets_shape,
|
const std::vector<int64_t> &cum_offsets_shape,
|
||||||
const std::vector<int64_t> &token_num_shape,
|
const std::vector<int64_t> &token_num_shape,
|
||||||
const std::vector<int64_t> &seq_len_shape) {
|
const std::vector<int64_t> &seq_len_shape) {
|
||||||
int64_t bsz = seq_len_shape[0];
|
int64_t bsz = seq_len_shape[0];
|
||||||
int64_t seq_len = input_ids_shape[1];
|
int64_t seq_len = input_ids_shape[1];
|
||||||
return {{-1}, {-1}, {bsz + 1}, {bsz + 1}};
|
return {{-1}, {-1}, {bsz + 1}, {bsz + 1}};
|
||||||
}
|
}
|
||||||
|
|
||||||
std::vector<paddle::DataType> GetPaddingOffsetInferDtype(
|
std::vector<paddle::DataType> GetPaddingOffsetInferDtype(
|
||||||
@@ -101,13 +104,18 @@ std::vector<paddle::DataType> GetPaddingOffsetInferDtype(
|
|||||||
const paddle::DataType &cum_offsets_dtype,
|
const paddle::DataType &cum_offsets_dtype,
|
||||||
const paddle::DataType &token_num_dtype,
|
const paddle::DataType &token_num_dtype,
|
||||||
const paddle::DataType &seq_len_dtype) {
|
const paddle::DataType &seq_len_dtype) {
|
||||||
return {input_ids_dtype, seq_len_dtype, seq_len_dtype, seq_len_dtype};
|
return {input_ids_dtype,
|
||||||
|
seq_len_dtype,
|
||||||
|
seq_len_dtype,
|
||||||
|
seq_len_dtype};
|
||||||
}
|
}
|
||||||
|
|
||||||
PD_BUILD_STATIC_OP(get_padding_offset_cpu)
|
PD_BUILD_STATIC_OP(get_padding_offset_cpu)
|
||||||
.Inputs({"input_ids", "cum_offsets", "token_num", "seq_len"})
|
.Inputs({"input_ids", "cum_offsets", "token_num", "seq_len"})
|
||||||
.Outputs(
|
.Outputs({"x_remove_padding",
|
||||||
{"x_remove_padding", "padding_offset", "cu_seqlens_q", "cu_seqlens_k"})
|
"padding_offset",
|
||||||
|
"cu_seqlens_q",
|
||||||
|
"cu_seqlens_k"})
|
||||||
.SetKernelFn(PD_KERNEL(GetPaddingOffset))
|
.SetKernelFn(PD_KERNEL(GetPaddingOffset))
|
||||||
.SetInferShapeFn(PD_INFER_SHAPE(GetPaddingOffsetInferShape))
|
.SetInferShapeFn(PD_INFER_SHAPE(GetPaddingOffsetInferShape))
|
||||||
.SetInferDtypeFn(PD_INFER_DTYPE(GetPaddingOffsetInferDtype));
|
.SetInferDtypeFn(PD_INFER_DTYPE(GetPaddingOffsetInferDtype));
|
||||||
|
|||||||
@@ -19,6 +19,7 @@
|
|||||||
#define PD_BUILD_STATIC_OP(name) PD_BUILD_OP(static_op_##name)
|
#define PD_BUILD_STATIC_OP(name) PD_BUILD_OP(static_op_##name)
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
|
|
||||||
template <typename T>
|
template <typename T>
|
||||||
void RebuildPaddingCPUImpl(T *output_data,
|
void RebuildPaddingCPUImpl(T *output_data,
|
||||||
const T *input_data,
|
const T *input_data,
|
||||||
@@ -29,27 +30,27 @@ void RebuildPaddingCPUImpl(T *output_data,
|
|||||||
int max_input_length,
|
int max_input_length,
|
||||||
int dim_embed,
|
int dim_embed,
|
||||||
const int elem_nums) {
|
const int elem_nums) {
|
||||||
for (int i = 0; i < elem_nums; ++i) {
|
for (int i = 0; i < elem_nums; ++i) {
|
||||||
const int bi = i / dim_embed;
|
const int bi = i / dim_embed;
|
||||||
const int bias_idx = i % dim_embed;
|
const int bias_idx = i % dim_embed;
|
||||||
int seq_id = 0;
|
int seq_id = 0;
|
||||||
|
|
||||||
if (seq_len_this_time_data[bi] == 0) {
|
if (seq_len_this_time_data[bi] == 0) {
|
||||||
continue;
|
continue;
|
||||||
|
}
|
||||||
|
if (seq_lens_decoder_data[bi] == 0 && seq_lens_encoder_data[bi] == 0) {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (seq_lens_encoder_data[bi] > 0) {
|
||||||
|
seq_id = seq_lens_encoder_data[bi] - 1;
|
||||||
|
}
|
||||||
|
|
||||||
|
const int ori_token_idx = cu_seqlens_q_data[bi] + seq_id;
|
||||||
|
const int src_offset = ori_token_idx * dim_embed + bias_idx;
|
||||||
|
|
||||||
|
output_data[i] = input_data[src_offset];
|
||||||
}
|
}
|
||||||
if (seq_lens_decoder_data[bi] == 0 && seq_lens_encoder_data[bi] == 0) {
|
|
||||||
continue;
|
|
||||||
}
|
|
||||||
|
|
||||||
if (seq_lens_encoder_data[bi] > 0) {
|
|
||||||
seq_id = seq_lens_encoder_data[bi] - 1;
|
|
||||||
}
|
|
||||||
|
|
||||||
const int ori_token_idx = cu_seqlens_q_data[bi] + seq_id;
|
|
||||||
const int src_offset = ori_token_idx * dim_embed + bias_idx;
|
|
||||||
|
|
||||||
output_data[i] = input_data[src_offset];
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
template <typename T>
|
template <typename T>
|
||||||
@@ -63,25 +64,27 @@ void RebuildAppendPaddingCPUImpl(T *output_data,
|
|||||||
const int max_input_length,
|
const int max_input_length,
|
||||||
const int dim_embed,
|
const int dim_embed,
|
||||||
const int64_t output_elem_nums) {
|
const int64_t output_elem_nums) {
|
||||||
for (int i = 0; i < output_elem_nums; ++i) {
|
for (int i = 0; i < output_elem_nums; ++i) {
|
||||||
int out_token_id = i / dim_embed;
|
int out_token_id = i / dim_embed;
|
||||||
int ori_token_id = out_token_id + output_padding_offset_data[out_token_id];
|
int ori_token_id =
|
||||||
int bi = ori_token_id / max_input_length;
|
out_token_id + output_padding_offset_data[out_token_id];
|
||||||
if (seq_len_this_time_data[bi] == 0 ||
|
int bi = ori_token_id / max_input_length;
|
||||||
(seq_lens_decoder_data[bi] == 0 && seq_lens_encoder_data[bi] == 0)) {
|
if (seq_len_this_time_data[bi] == 0 ||
|
||||||
continue;
|
(seq_lens_decoder_data[bi] == 0 &&
|
||||||
}
|
seq_lens_encoder_data[bi] == 0)) {
|
||||||
int seq_id = 0;
|
continue;
|
||||||
|
}
|
||||||
|
int seq_id = 0;
|
||||||
|
|
||||||
if (seq_lens_encoder_data[bi] > 0) {
|
if (seq_lens_encoder_data[bi] > 0) {
|
||||||
seq_id = seq_lens_encoder_data[bi] - 1;
|
seq_id = seq_lens_encoder_data[bi] - 1;
|
||||||
}
|
}
|
||||||
int input_token_id = cu_seqlens_q_data[bi] + seq_id;
|
int input_token_id = cu_seqlens_q_data[bi] + seq_id;
|
||||||
int bias_idx = i % dim_embed;
|
int bias_idx = i % dim_embed;
|
||||||
int src_offset = input_token_id * dim_embed + bias_idx;
|
int src_offset = input_token_id * dim_embed + bias_idx;
|
||||||
|
|
||||||
output_data[i] = input_data[src_offset];
|
output_data[i] = input_data[src_offset];
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
std::vector<paddle::Tensor> RebuildPaddingCPU(
|
std::vector<paddle::Tensor> RebuildPaddingCPU(
|
||||||
@@ -92,139 +95,140 @@ std::vector<paddle::Tensor> RebuildPaddingCPU(
|
|||||||
const paddle::Tensor &seq_lens_encoder,
|
const paddle::Tensor &seq_lens_encoder,
|
||||||
const paddle::optional<paddle::Tensor> &output_padding_offset,
|
const paddle::optional<paddle::Tensor> &output_padding_offset,
|
||||||
int max_input_length) {
|
int max_input_length) {
|
||||||
auto tmp_out_cpu = tmp_out.copy_to(paddle::CPUPlace(), true);
|
auto tmp_out_cpu = tmp_out.copy_to(paddle::CPUPlace(), true);
|
||||||
auto cu_seqlens_q_cpu = cu_seqlens_q.copy_to(paddle::CPUPlace(), true);
|
auto cu_seqlens_q_cpu = cu_seqlens_q.copy_to(paddle::CPUPlace(), true);
|
||||||
auto seq_len_this_time_cpu =
|
auto seq_len_this_time_cpu =
|
||||||
seq_len_this_time.copy_to(paddle::CPUPlace(), true);
|
seq_len_this_time.copy_to(paddle::CPUPlace(), true);
|
||||||
auto seq_lens_decoder_cpu =
|
auto seq_lens_decoder_cpu =
|
||||||
seq_lens_decoder.copy_to(paddle::CPUPlace(), true);
|
seq_lens_decoder.copy_to(paddle::CPUPlace(), true);
|
||||||
auto seq_lens_encoder_cpu =
|
auto seq_lens_encoder_cpu =
|
||||||
seq_lens_encoder.copy_to(paddle::CPUPlace(), true);
|
seq_lens_encoder.copy_to(paddle::CPUPlace(), true);
|
||||||
paddle::optional<paddle::Tensor> output_padding_offset_cpu;
|
paddle::optional<paddle::Tensor> output_padding_offset_cpu;
|
||||||
if (output_padding_offset) {
|
if (output_padding_offset) {
|
||||||
output_padding_offset_cpu =
|
output_padding_offset_cpu =
|
||||||
output_padding_offset->copy_to(paddle::CPUPlace(), true);
|
output_padding_offset->copy_to(paddle::CPUPlace(), true);
|
||||||
}
|
|
||||||
|
|
||||||
int token_num = tmp_out_cpu.shape()[0];
|
|
||||||
int dim_embed = tmp_out_cpu.shape()[1];
|
|
||||||
int bsz = cu_seqlens_q_cpu.shape()[0] - 1;
|
|
||||||
|
|
||||||
paddle::Tensor out;
|
|
||||||
if (output_padding_offset_cpu) {
|
|
||||||
int need_delete_token_num = 0;
|
|
||||||
for (int i = 0; i < bsz; ++i) {
|
|
||||||
if (seq_lens_encoder_cpu.data<int>()[i] > 0) {
|
|
||||||
need_delete_token_num += seq_lens_encoder_cpu.data<int>()[i] - 1;
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
int output_token_num = token_num - need_delete_token_num;
|
|
||||||
out = paddle::full({output_token_num, dim_embed},
|
|
||||||
0,
|
|
||||||
tmp_out_cpu.dtype(),
|
|
||||||
paddle::CPUPlace());
|
|
||||||
} else {
|
|
||||||
out = paddle::full(
|
|
||||||
{bsz, dim_embed}, 0, tmp_out_cpu.dtype(), paddle::CPUPlace());
|
|
||||||
}
|
|
||||||
|
|
||||||
const int *cu_seqlens_q_data = cu_seqlens_q_cpu.data<int>();
|
int token_num = tmp_out_cpu.shape()[0];
|
||||||
const int *seq_len_this_time_data = seq_len_this_time_cpu.data<int>();
|
int dim_embed = tmp_out_cpu.shape()[1];
|
||||||
const int *seq_lens_decoder_data = seq_lens_decoder_cpu.data<int>();
|
int bsz = cu_seqlens_q_cpu.shape()[0] - 1;
|
||||||
const int *seq_lens_encoder_data = seq_lens_encoder_cpu.data<int>();
|
|
||||||
int elem_nums = out.numel();
|
|
||||||
|
|
||||||
if (output_padding_offset_cpu) {
|
paddle::Tensor out;
|
||||||
const int *output_padding_offset_data =
|
if (output_padding_offset_cpu) {
|
||||||
output_padding_offset_cpu->data<int>();
|
int need_delete_token_num = 0;
|
||||||
switch (tmp_out_cpu.dtype()) {
|
for (int i = 0; i < bsz; ++i) {
|
||||||
case paddle::DataType::FLOAT32:
|
if (seq_lens_encoder_cpu.data<int>()[i] > 0) {
|
||||||
RebuildAppendPaddingCPUImpl<float>(out.data<float>(),
|
need_delete_token_num +=
|
||||||
tmp_out_cpu.data<float>(),
|
seq_lens_encoder_cpu.data<int>()[i] - 1;
|
||||||
cu_seqlens_q_data,
|
}
|
||||||
seq_len_this_time_data,
|
}
|
||||||
seq_lens_decoder_data,
|
int output_token_num = token_num - need_delete_token_num;
|
||||||
seq_lens_encoder_data,
|
out = paddle::full({output_token_num, dim_embed},
|
||||||
output_padding_offset_data,
|
0,
|
||||||
max_input_length,
|
tmp_out_cpu.dtype(),
|
||||||
dim_embed,
|
paddle::CPUPlace());
|
||||||
elem_nums);
|
} else {
|
||||||
break;
|
out = paddle::full(
|
||||||
case paddle::DataType::FLOAT16:
|
{bsz, dim_embed}, 0, tmp_out_cpu.dtype(), paddle::CPUPlace());
|
||||||
RebuildAppendPaddingCPUImpl<paddle::float16>(
|
|
||||||
out.data<paddle::float16>(),
|
|
||||||
tmp_out_cpu.data<paddle::float16>(),
|
|
||||||
cu_seqlens_q_data,
|
|
||||||
seq_len_this_time_data,
|
|
||||||
seq_lens_decoder_data,
|
|
||||||
seq_lens_encoder_data,
|
|
||||||
output_padding_offset_data,
|
|
||||||
max_input_length,
|
|
||||||
dim_embed,
|
|
||||||
elem_nums);
|
|
||||||
break;
|
|
||||||
case paddle::DataType::BFLOAT16:
|
|
||||||
RebuildAppendPaddingCPUImpl<paddle::bfloat16>(
|
|
||||||
out.data<paddle::bfloat16>(),
|
|
||||||
tmp_out_cpu.data<paddle::bfloat16>(),
|
|
||||||
cu_seqlens_q_data,
|
|
||||||
seq_len_this_time_data,
|
|
||||||
seq_lens_decoder_data,
|
|
||||||
seq_lens_encoder_data,
|
|
||||||
output_padding_offset_data,
|
|
||||||
max_input_length,
|
|
||||||
dim_embed,
|
|
||||||
elem_nums);
|
|
||||||
break;
|
|
||||||
default:
|
|
||||||
PD_THROW(
|
|
||||||
"Unsupported data type for rebuild_padding_cpu. "
|
|
||||||
"Only float32, float16, and bfloat16 are supported.");
|
|
||||||
}
|
}
|
||||||
} else {
|
|
||||||
switch (tmp_out_cpu.dtype()) {
|
const int *cu_seqlens_q_data = cu_seqlens_q_cpu.data<int>();
|
||||||
case paddle::DataType::FLOAT32:
|
const int *seq_len_this_time_data = seq_len_this_time_cpu.data<int>();
|
||||||
RebuildPaddingCPUImpl<float>(out.data<float>(),
|
const int *seq_lens_decoder_data = seq_lens_decoder_cpu.data<int>();
|
||||||
tmp_out_cpu.data<float>(),
|
const int *seq_lens_encoder_data = seq_lens_encoder_cpu.data<int>();
|
||||||
cu_seqlens_q_data,
|
int elem_nums = out.numel();
|
||||||
seq_len_this_time_data,
|
|
||||||
seq_lens_decoder_data,
|
if (output_padding_offset_cpu) {
|
||||||
seq_lens_encoder_data,
|
const int *output_padding_offset_data =
|
||||||
max_input_length,
|
output_padding_offset_cpu->data<int>();
|
||||||
dim_embed,
|
switch (tmp_out_cpu.dtype()) {
|
||||||
elem_nums);
|
case paddle::DataType::FLOAT32:
|
||||||
break;
|
RebuildAppendPaddingCPUImpl<float>(out.data<float>(),
|
||||||
case paddle::DataType::FLOAT16:
|
tmp_out_cpu.data<float>(),
|
||||||
RebuildPaddingCPUImpl<paddle::float16>(
|
cu_seqlens_q_data,
|
||||||
out.data<paddle::float16>(),
|
seq_len_this_time_data,
|
||||||
tmp_out_cpu.data<paddle::float16>(),
|
seq_lens_decoder_data,
|
||||||
cu_seqlens_q_data,
|
seq_lens_encoder_data,
|
||||||
seq_len_this_time_data,
|
output_padding_offset_data,
|
||||||
seq_lens_decoder_data,
|
max_input_length,
|
||||||
seq_lens_encoder_data,
|
dim_embed,
|
||||||
max_input_length,
|
elem_nums);
|
||||||
dim_embed,
|
break;
|
||||||
elem_nums);
|
case paddle::DataType::FLOAT16:
|
||||||
break;
|
RebuildAppendPaddingCPUImpl<paddle::float16>(
|
||||||
case paddle::DataType::BFLOAT16:
|
out.data<paddle::float16>(),
|
||||||
RebuildPaddingCPUImpl<paddle::bfloat16>(
|
tmp_out_cpu.data<paddle::float16>(),
|
||||||
out.data<paddle::bfloat16>(),
|
cu_seqlens_q_data,
|
||||||
tmp_out_cpu.data<paddle::bfloat16>(),
|
seq_len_this_time_data,
|
||||||
cu_seqlens_q_data,
|
seq_lens_decoder_data,
|
||||||
seq_len_this_time_data,
|
seq_lens_encoder_data,
|
||||||
seq_lens_decoder_data,
|
output_padding_offset_data,
|
||||||
seq_lens_encoder_data,
|
max_input_length,
|
||||||
max_input_length,
|
dim_embed,
|
||||||
dim_embed,
|
elem_nums);
|
||||||
elem_nums);
|
break;
|
||||||
break;
|
case paddle::DataType::BFLOAT16:
|
||||||
default:
|
RebuildAppendPaddingCPUImpl<paddle::bfloat16>(
|
||||||
PD_THROW(
|
out.data<paddle::bfloat16>(),
|
||||||
"Unsupported data type for rebuild_padding_cpu. "
|
tmp_out_cpu.data<paddle::bfloat16>(),
|
||||||
"Only float32, float16, and bfloat16 are supported.");
|
cu_seqlens_q_data,
|
||||||
|
seq_len_this_time_data,
|
||||||
|
seq_lens_decoder_data,
|
||||||
|
seq_lens_encoder_data,
|
||||||
|
output_padding_offset_data,
|
||||||
|
max_input_length,
|
||||||
|
dim_embed,
|
||||||
|
elem_nums);
|
||||||
|
break;
|
||||||
|
default:
|
||||||
|
PD_THROW(
|
||||||
|
"Unsupported data type for rebuild_padding_cpu. "
|
||||||
|
"Only float32, float16, and bfloat16 are supported.");
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
switch (tmp_out_cpu.dtype()) {
|
||||||
|
case paddle::DataType::FLOAT32:
|
||||||
|
RebuildPaddingCPUImpl<float>(out.data<float>(),
|
||||||
|
tmp_out_cpu.data<float>(),
|
||||||
|
cu_seqlens_q_data,
|
||||||
|
seq_len_this_time_data,
|
||||||
|
seq_lens_decoder_data,
|
||||||
|
seq_lens_encoder_data,
|
||||||
|
max_input_length,
|
||||||
|
dim_embed,
|
||||||
|
elem_nums);
|
||||||
|
break;
|
||||||
|
case paddle::DataType::FLOAT16:
|
||||||
|
RebuildPaddingCPUImpl<paddle::float16>(
|
||||||
|
out.data<paddle::float16>(),
|
||||||
|
tmp_out_cpu.data<paddle::float16>(),
|
||||||
|
cu_seqlens_q_data,
|
||||||
|
seq_len_this_time_data,
|
||||||
|
seq_lens_decoder_data,
|
||||||
|
seq_lens_encoder_data,
|
||||||
|
max_input_length,
|
||||||
|
dim_embed,
|
||||||
|
elem_nums);
|
||||||
|
break;
|
||||||
|
case paddle::DataType::BFLOAT16:
|
||||||
|
RebuildPaddingCPUImpl<paddle::bfloat16>(
|
||||||
|
out.data<paddle::bfloat16>(),
|
||||||
|
tmp_out_cpu.data<paddle::bfloat16>(),
|
||||||
|
cu_seqlens_q_data,
|
||||||
|
seq_len_this_time_data,
|
||||||
|
seq_lens_decoder_data,
|
||||||
|
seq_lens_encoder_data,
|
||||||
|
max_input_length,
|
||||||
|
dim_embed,
|
||||||
|
elem_nums);
|
||||||
|
break;
|
||||||
|
default:
|
||||||
|
PD_THROW(
|
||||||
|
"Unsupported data type for rebuild_padding_cpu. "
|
||||||
|
"Only float32, float16, and bfloat16 are supported.");
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
return {out};
|
||||||
return {out};
|
|
||||||
}
|
}
|
||||||
|
|
||||||
std::vector<std::vector<int64_t>> RebuildPaddingInferShape(
|
std::vector<std::vector<int64_t>> RebuildPaddingInferShape(
|
||||||
@@ -234,13 +238,13 @@ std::vector<std::vector<int64_t>> RebuildPaddingInferShape(
|
|||||||
const std::vector<int64_t> &seq_lens_decoder_shape,
|
const std::vector<int64_t> &seq_lens_decoder_shape,
|
||||||
const std::vector<int64_t> &seq_lens_encoder_shape,
|
const std::vector<int64_t> &seq_lens_encoder_shape,
|
||||||
const paddle::optional<std::vector<int64_t>> &output_padding_offset_shape) {
|
const paddle::optional<std::vector<int64_t>> &output_padding_offset_shape) {
|
||||||
int64_t dim_embed = tmp_out_shape[1];
|
int64_t dim_embed = tmp_out_shape[1];
|
||||||
if (output_padding_offset_shape) {
|
if (output_padding_offset_shape) {
|
||||||
return {{-1, dim_embed}};
|
return {{-1, dim_embed}};
|
||||||
} else {
|
} else {
|
||||||
int64_t bsz = cu_seqlens_q_shape[0] - 1;
|
int64_t bsz = cu_seqlens_q_shape[0] - 1;
|
||||||
return {{bsz, dim_embed}};
|
return {{bsz, dim_embed}};
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
std::vector<paddle::DataType> RebuildPaddingInferDtype(
|
std::vector<paddle::DataType> RebuildPaddingInferDtype(
|
||||||
@@ -250,7 +254,7 @@ std::vector<paddle::DataType> RebuildPaddingInferDtype(
|
|||||||
const paddle::DataType &seq_lens_decoder_dtype,
|
const paddle::DataType &seq_lens_decoder_dtype,
|
||||||
const paddle::DataType &seq_lens_encoder_dtype,
|
const paddle::DataType &seq_lens_encoder_dtype,
|
||||||
const paddle::optional<paddle::DataType> &output_padding_offset_dtype) {
|
const paddle::optional<paddle::DataType> &output_padding_offset_dtype) {
|
||||||
return {tmp_out_dtype};
|
return {tmp_out_dtype};
|
||||||
}
|
}
|
||||||
|
|
||||||
PD_BUILD_STATIC_OP(rebuild_padding_cpu)
|
PD_BUILD_STATIC_OP(rebuild_padding_cpu)
|
||||||
|
|||||||
@@ -14,28 +14,28 @@
|
|||||||
|
|
||||||
#include "paddle/extension.h"
|
#include "paddle/extension.h"
|
||||||
|
|
||||||
void set_value_by_flags_and_idx(const bool *stop_flags,
|
void set_value_by_flag_and_id(const bool *stop_flags,
|
||||||
int64_t *pre_ids_all,
|
int64_t *pre_ids_all,
|
||||||
const int64_t *input_ids,
|
const int64_t *input_ids,
|
||||||
const int *seq_lens_encoder,
|
const int *seq_lens_encoder,
|
||||||
const int *seq_lens_decoder,
|
const int *seq_lens_decoder,
|
||||||
const int64_t *step_idx,
|
const int64_t *step_idx,
|
||||||
int bs,
|
int bs,
|
||||||
int length,
|
int length,
|
||||||
int length_input_ids) {
|
int length_input_ids) {
|
||||||
for (int bi = 0; bi < bs; bi++) {
|
for (int bi = 0; bi < bs; bi++) {
|
||||||
if (!stop_flags[bi]) {
|
if (!stop_flags[bi]) {
|
||||||
const int seq_len_dec = seq_lens_decoder[bi];
|
const int seq_len_dec = seq_lens_decoder[bi];
|
||||||
const int seq_len_enc = seq_lens_encoder[bi];
|
const int seq_len_enc = seq_lens_encoder[bi];
|
||||||
int64_t *pre_ids_all_now = pre_ids_all + bi * length;
|
int64_t *pre_ids_all_now = pre_ids_all + bi * length;
|
||||||
const int64_t *input_ids_now = input_ids + bi * length_input_ids;
|
const int64_t *input_ids_now = input_ids + bi * length_input_ids;
|
||||||
if (seq_len_dec == 0) {
|
if (seq_len_dec == 0) {
|
||||||
pre_ids_all_now[step_idx[bi]] = input_ids_now[seq_len_enc - 1];
|
pre_ids_all_now[step_idx[bi]] = input_ids_now[seq_len_enc - 1];
|
||||||
} else {
|
} else {
|
||||||
pre_ids_all_now[step_idx[bi]] = input_ids_now[0];
|
pre_ids_all_now[step_idx[bi]] = input_ids_now[0];
|
||||||
}
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
void SetValueByFlagsAndIdx(const paddle::Tensor &pre_ids_all,
|
void SetValueByFlagsAndIdx(const paddle::Tensor &pre_ids_all,
|
||||||
@@ -45,12 +45,12 @@ void SetValueByFlagsAndIdx(const paddle::Tensor &pre_ids_all,
|
|||||||
const paddle::Tensor &seq_lens_decoder,
|
const paddle::Tensor &seq_lens_decoder,
|
||||||
const paddle::Tensor &step_idx,
|
const paddle::Tensor &step_idx,
|
||||||
const paddle::Tensor &stop_flags) {
|
const paddle::Tensor &stop_flags) {
|
||||||
std::vector<int64_t> pre_ids_all_shape = pre_ids_all.shape();
|
std::vector<int64_t> pre_ids_all_shape = pre_ids_all.shape();
|
||||||
int bs = seq_lens_this_time.shape()[0];
|
int bs = seq_lens_this_time.shape()[0];
|
||||||
int length = pre_ids_all_shape[1];
|
int length = pre_ids_all_shape[1];
|
||||||
int length_input_ids = input_ids.shape()[1];
|
int length_input_ids = input_ids.shape()[1];
|
||||||
|
|
||||||
set_value_by_flags_and_idx(stop_flags.data<bool>(),
|
set_value_by_flag_and_id(stop_flags.data<bool>(),
|
||||||
const_cast<int64_t *>(pre_ids_all.data<int64_t>()),
|
const_cast<int64_t *>(pre_ids_all.data<int64_t>()),
|
||||||
input_ids.data<int64_t>(),
|
input_ids.data<int64_t>(),
|
||||||
seq_lens_encoder.data<int>(),
|
seq_lens_encoder.data<int>(),
|
||||||
|
|||||||
@@ -21,45 +21,45 @@ void probs_sort(const float *probs,
|
|||||||
float *ProbsVals,
|
float *ProbsVals,
|
||||||
int vocab_size,
|
int vocab_size,
|
||||||
int bsz) {
|
int bsz) {
|
||||||
float cursum = 0;
|
float cursum = 0;
|
||||||
std::vector<int64_t> elementsIds(vocab_size);
|
std::vector<int64_t> elementsIds(vocab_size);
|
||||||
std::vector<float> elementsProbs(vocab_size);
|
std::vector<float> elementsProbs(vocab_size);
|
||||||
#pragma omp parallel for
|
#pragma omp parallel for
|
||||||
for (int j = 0; j < vocab_size; j++) {
|
for (int j = 0; j < vocab_size; j++) {
|
||||||
elementsIds[j] = j;
|
elementsIds[j] = j;
|
||||||
elementsProbs[j] = probs[j];
|
elementsProbs[j] = probs[j];
|
||||||
}
|
}
|
||||||
x86simdsortStatic::keyvalue_qsort(
|
x86simdsortStatic::keyvalue_qsort(
|
||||||
elementsProbs.data(), elementsIds.data(), vocab_size, false, true);
|
elementsProbs.data(), elementsIds.data(), vocab_size, false, true);
|
||||||
#pragma omp parallel for
|
#pragma omp parallel for
|
||||||
for (int j = 0; j < vocab_size; ++j) {
|
for (int j = 0; j < vocab_size; ++j) {
|
||||||
ProbsVals[j] = elementsProbs[j];
|
ProbsVals[j] = elementsProbs[j];
|
||||||
ProbsIds[j] = elementsIds[j];
|
ProbsIds[j] = elementsIds[j];
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
std::vector<paddle::Tensor> SimdSort(const paddle::Tensor &probs) {
|
std::vector<paddle::Tensor> SimdSort(const paddle::Tensor &probs) {
|
||||||
const int bsz = probs.shape()[0];
|
const int bsz = probs.shape()[0];
|
||||||
const int vocab_size = probs.shape()[1];
|
const int vocab_size = probs.shape()[1];
|
||||||
auto sorted_indices =
|
auto sorted_indices = paddle::empty(
|
||||||
paddle::empty({bsz, vocab_size}, paddle::DataType::INT64, probs.place());
|
{bsz, vocab_size}, paddle::DataType::INT64, probs.place());
|
||||||
auto sorted_probs = paddle::empty(
|
auto sorted_probs = paddle::empty(
|
||||||
{bsz, vocab_size}, paddle::DataType::FLOAT32, probs.place());
|
{bsz, vocab_size}, paddle::DataType::FLOAT32, probs.place());
|
||||||
probs_sort(probs.data<float>(),
|
probs_sort(probs.data<float>(),
|
||||||
const_cast<int64_t *>(sorted_indices.data<int64_t>()),
|
const_cast<int64_t *>(sorted_indices.data<int64_t>()),
|
||||||
const_cast<float *>(sorted_probs.data<float>()),
|
const_cast<float *>(sorted_probs.data<float>()),
|
||||||
vocab_size,
|
vocab_size,
|
||||||
bsz);
|
bsz);
|
||||||
return {sorted_indices, sorted_probs};
|
return {sorted_indices, sorted_probs};
|
||||||
}
|
}
|
||||||
std::vector<std::vector<int64_t>> SimdSortInferShape(
|
std::vector<std::vector<int64_t>> SimdSortInferShape(
|
||||||
const std::vector<int64_t> &probs_shape) {
|
const std::vector<int64_t> &probs_shape) {
|
||||||
int64_t bsz = probs_shape[0];
|
int64_t bsz = probs_shape[0];
|
||||||
int64_t vocab_size = probs_shape[1];
|
int64_t vocab_size = probs_shape[1];
|
||||||
return {{bsz, vocab_size}, {bsz, vocab_size}};
|
return {{bsz, vocab_size}, {bsz, vocab_size}};
|
||||||
}
|
}
|
||||||
std::vector<paddle::DataType> SimdSortInferDtype(
|
std::vector<paddle::DataType> SimdSortInferDtype(
|
||||||
const paddle::DataType &probs_dtype) {
|
const paddle::DataType &probs_dtype) {
|
||||||
return {paddle::DataType::INT64, paddle::DataType::FLOAT32};
|
return {paddle::DataType::INT64, paddle::DataType::FLOAT32};
|
||||||
}
|
}
|
||||||
PD_BUILD_STATIC_OP(simd_sort)
|
PD_BUILD_STATIC_OP(simd_sort)
|
||||||
.Inputs({"probs"})
|
.Inputs({"probs"})
|
||||||
|
|||||||
@@ -16,23 +16,23 @@
|
|||||||
#include "paddle/extension.h"
|
#include "paddle/extension.h"
|
||||||
|
|
||||||
std::vector<paddle::Tensor> SimdSort(const paddle::Tensor &probs) {
|
std::vector<paddle::Tensor> SimdSort(const paddle::Tensor &probs) {
|
||||||
const int bsz = probs.shape()[0];
|
const int bsz = probs.shape()[0];
|
||||||
const int vocab_size = probs.shape()[1];
|
const int vocab_size = probs.shape()[1];
|
||||||
auto sorted_indices =
|
auto sorted_indices = paddle::empty(
|
||||||
paddle::empty({bsz, vocab_size}, paddle::DataType::INT64, probs.place());
|
{bsz, vocab_size}, paddle::DataType::INT64, probs.place());
|
||||||
auto sorted_probs = paddle::empty(
|
auto sorted_probs = paddle::empty(
|
||||||
{bsz, vocab_size}, paddle::DataType::FLOAT32, probs.place());
|
{bsz, vocab_size}, paddle::DataType::FLOAT32, probs.place());
|
||||||
return {sorted_indices, sorted_probs};
|
return {sorted_indices, sorted_probs};
|
||||||
}
|
}
|
||||||
std::vector<std::vector<int64_t>> SimdSortInferShape(
|
std::vector<std::vector<int64_t>> SimdSortInferShape(
|
||||||
const std::vector<int64_t> &probs_shape) {
|
const std::vector<int64_t> &probs_shape) {
|
||||||
int64_t bsz = probs_shape[0];
|
int64_t bsz = probs_shape[0];
|
||||||
int64_t vocab_size = probs_shape[1];
|
int64_t vocab_size = probs_shape[1];
|
||||||
return {{bsz, vocab_size}, {bsz, vocab_size}};
|
return {{bsz, vocab_size}, {bsz, vocab_size}};
|
||||||
}
|
}
|
||||||
std::vector<paddle::DataType> SimdSortInferDtype(
|
std::vector<paddle::DataType> SimdSortInferDtype(
|
||||||
const paddle::DataType &probs_dtype) {
|
const paddle::DataType &probs_dtype) {
|
||||||
return {paddle::DataType::INT64, paddle::DataType::FLOAT32};
|
return {paddle::DataType::INT64, paddle::DataType::FLOAT32};
|
||||||
}
|
}
|
||||||
PD_BUILD_STATIC_OP(simd_sort)
|
PD_BUILD_STATIC_OP(simd_sort)
|
||||||
.Inputs({"probs"})
|
.Inputs({"probs"})
|
||||||
|
|||||||
@@ -18,18 +18,14 @@
|
|||||||
#include <stdio.h>
|
#include <stdio.h>
|
||||||
#include "paddle/extension.h"
|
#include "paddle/extension.h"
|
||||||
|
|
||||||
#ifndef PD_BUILD_STATIC_OP
|
|
||||||
#define PD_BUILD_STATIC_OP(name) PD_BUILD_OP(static_op_##name)
|
|
||||||
#endif
|
|
||||||
|
|
||||||
bool is_in_end(const int64_t id, const int64_t *end_ids, int length) {
|
bool is_in_end(const int64_t id, const int64_t *end_ids, int length) {
|
||||||
bool flag = false;
|
bool flag = false;
|
||||||
for (int i = 0; i < length; i++) {
|
for (int i = 0; i < length; i++) {
|
||||||
if (id == end_ids[i]) {
|
if (id == end_ids[i]) {
|
||||||
return true;
|
return true;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
return flag;
|
||||||
return flag;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
void set_value_by_flags(bool *stop_flags,
|
void set_value_by_flags(bool *stop_flags,
|
||||||
@@ -40,23 +36,21 @@ void set_value_by_flags(bool *stop_flags,
|
|||||||
const int bs,
|
const int bs,
|
||||||
const int end_length,
|
const int end_length,
|
||||||
bool beam_search) {
|
bool beam_search) {
|
||||||
for (int bi = 0; bi < bs; bi++) {
|
for (int bi = 0; bi < bs; bi++) {
|
||||||
if (stop_flags[bi]) {
|
if (stop_flags[bi]) {
|
||||||
if ((seq_lens[bi] == 0)) {
|
if ((seq_lens[bi] == 0)) {
|
||||||
topk_ids[bi] = -1;
|
topk_ids[bi] = -1;
|
||||||
} else {
|
} else {
|
||||||
topk_ids[bi] = end_ids[0];
|
topk_ids[bi] = end_ids[0];
|
||||||
next_tokens[bi] = end_ids[0];
|
next_tokens[bi] = end_ids[0];
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
next_tokens[bi] = topk_ids[bi];
|
next_tokens[bi] = topk_ids[bi];
|
||||||
|
}
|
||||||
|
if (!beam_search && is_in_end(topk_ids[bi], end_ids, end_length)) {
|
||||||
|
stop_flags[bi] = true;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
if (!beam_search && is_in_end(topk_ids[bi], end_ids, end_length)) {
|
|
||||||
stop_flags[bi] = true;
|
|
||||||
topk_ids[bi] = end_ids[0];
|
|
||||||
next_tokens[bi] = end_ids[0];
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
void GetStopFlagsMulti(const paddle::Tensor &topk_ids,
|
void GetStopFlagsMulti(const paddle::Tensor &topk_ids,
|
||||||
@@ -65,17 +59,17 @@ void GetStopFlagsMulti(const paddle::Tensor &topk_ids,
|
|||||||
const paddle::Tensor &end_ids,
|
const paddle::Tensor &end_ids,
|
||||||
const paddle::Tensor &next_tokens,
|
const paddle::Tensor &next_tokens,
|
||||||
const bool beam_search) {
|
const bool beam_search) {
|
||||||
std::vector<int64_t> shape = topk_ids.shape();
|
std::vector<int64_t> shape = topk_ids.shape();
|
||||||
int64_t bs_now = shape[0];
|
int64_t bs_now = shape[0];
|
||||||
int64_t end_length = end_ids.shape()[0];
|
int64_t end_length = end_ids.shape()[0];
|
||||||
set_value_by_flags(const_cast<bool *>(stop_flags.data<bool>()),
|
set_value_by_flags(const_cast<bool *>(stop_flags.data<bool>()),
|
||||||
const_cast<int64_t *>(topk_ids.data<int64_t>()),
|
const_cast<int64_t *>(topk_ids.data<int64_t>()),
|
||||||
const_cast<int64_t *>(next_tokens.data<int64_t>()),
|
const_cast<int64_t *>(next_tokens.data<int64_t>()),
|
||||||
end_ids.data<int64_t>(),
|
end_ids.data<int64_t>(),
|
||||||
seq_lens.data<int>(),
|
seq_lens.data<int>(),
|
||||||
bs_now,
|
bs_now,
|
||||||
end_length,
|
end_length,
|
||||||
false);
|
false);
|
||||||
}
|
}
|
||||||
|
|
||||||
PD_BUILD_STATIC_OP(set_stop_value_multi_ends_cpu)
|
PD_BUILD_STATIC_OP(set_stop_value_multi_ends_cpu)
|
||||||
|
|||||||
@@ -23,16 +23,16 @@ void min_length_logits_process(float *logits,
|
|||||||
const int64_t bs,
|
const int64_t bs,
|
||||||
const int64_t length,
|
const int64_t length,
|
||||||
const int64_t end_length) {
|
const int64_t end_length) {
|
||||||
for (int bi = 0; bi < bs; ++bi) {
|
for (int bi = 0; bi < bs; ++bi) {
|
||||||
if (cur_len[bi] < 0) {
|
if (cur_len[bi] < 0) {
|
||||||
continue;
|
continue;
|
||||||
|
}
|
||||||
|
if (cur_len[bi] < min_len[bi]) {
|
||||||
|
for (int i = 0; i < end_length; ++i) {
|
||||||
|
logits[bi * length + eos_token_id[i]] = -1e10;
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
if (cur_len[bi] < min_len[bi]) {
|
|
||||||
for (int i = 0; i < end_length; ++i) {
|
|
||||||
logits[bi * length + eos_token_id[i]] = -1e10;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
void update_repeat_times(const int64_t *pre_ids,
|
void update_repeat_times(const int64_t *pre_ids,
|
||||||
@@ -41,20 +41,20 @@ void update_repeat_times(const int64_t *pre_ids,
|
|||||||
const int64_t bs,
|
const int64_t bs,
|
||||||
const int64_t length,
|
const int64_t length,
|
||||||
const int64_t length_id) {
|
const int64_t length_id) {
|
||||||
for (int bi = 0; bi < bs; ++bi) {
|
for (int bi = 0; bi < bs; ++bi) {
|
||||||
if (cur_len[bi] < 0) {
|
if (cur_len[bi] < 0) {
|
||||||
continue;
|
continue;
|
||||||
|
}
|
||||||
|
const int64_t *pre_ids_now = pre_ids + bi * length_id;
|
||||||
|
int *repeat_times_now = repeat_times + bi * length;
|
||||||
|
for (int i = 0; i < length_id; i++) {
|
||||||
|
int64_t id = pre_ids_now[i];
|
||||||
|
if (id < 0) {
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
repeat_times_now[id] += 1;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
const int64_t *pre_ids_now = pre_ids + bi * length_id;
|
|
||||||
int *repeat_times_now = repeat_times + bi * length;
|
|
||||||
for (int i = 0; i < length_id; i++) {
|
|
||||||
int64_t id = pre_ids_now[i];
|
|
||||||
if (id < 0) {
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
repeat_times_now[id] += 1;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
void update_value_by_repeat_times(const int *repeat_times,
|
void update_value_by_repeat_times(const int *repeat_times,
|
||||||
@@ -65,22 +65,24 @@ void update_value_by_repeat_times(const int *repeat_times,
|
|||||||
float *logits,
|
float *logits,
|
||||||
const int64_t bs,
|
const int64_t bs,
|
||||||
const int64_t length) {
|
const int64_t length) {
|
||||||
for (int bi = 0; bi < bs; ++bi) {
|
for (int bi = 0; bi < bs; ++bi) {
|
||||||
float *logits_now = logits + bi * length;
|
float *logits_now = logits + bi * length;
|
||||||
const int *repeat_times_now = repeat_times + bi * length;
|
const int *repeat_times_now = repeat_times + bi * length;
|
||||||
float alpha = static_cast<float>(penalty_scores[bi]);
|
float alpha = static_cast<float>(penalty_scores[bi]);
|
||||||
float beta = static_cast<float>(frequency_score[bi]);
|
float beta = static_cast<float>(frequency_score[bi]);
|
||||||
float gamma = static_cast<float>(presence_score[bi]);
|
float gamma = static_cast<float>(presence_score[bi]);
|
||||||
for (int i = 0; i < length; ++i) {
|
for (int i = 0; i < length; ++i) {
|
||||||
int times = repeat_times_now[i];
|
int times = repeat_times_now[i];
|
||||||
float logit_now = static_cast<float>(logits_now[i]);
|
float logit_now = static_cast<float>(logits_now[i]);
|
||||||
if (times == 0) {
|
if (times == 0) {
|
||||||
logits_now[i] = static_cast<float>(logit_now / temperatures[bi]);
|
logits_now[i] =
|
||||||
}
|
static_cast<float>(logit_now / temperatures[bi]);
|
||||||
logit_now = logit_now < 0 ? logit_now * alpha : logit_now / alpha;
|
}
|
||||||
logits_now[i] = static_cast<float>(logit_now - times * beta - gamma);
|
logit_now = logit_now < 0 ? logit_now * alpha : logit_now / alpha;
|
||||||
|
logits_now[i] =
|
||||||
|
static_cast<float>(logit_now - times * beta - gamma);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
void ban_bad_words(float *logits,
|
void ban_bad_words(float *logits,
|
||||||
@@ -88,14 +90,15 @@ void ban_bad_words(float *logits,
|
|||||||
const int64_t bs,
|
const int64_t bs,
|
||||||
const int64_t length,
|
const int64_t length,
|
||||||
const int64_t bad_words_length) {
|
const int64_t bad_words_length) {
|
||||||
for (int bi = 0; bi < bs; ++bi) {
|
for (int bi = 0; bi < bs; ++bi) {
|
||||||
float *logits_now = logits + bi * length;
|
float *logits_now = logits + bi * length;
|
||||||
for (int bwid = 0; bwid < bad_words_length; ++bwid) {
|
for (int bwid = 0; bwid < bad_words_length; ++bwid) {
|
||||||
const int64_t bad_words_token_id = bad_words_list[bwid];
|
const int64_t bad_words_token_id = bad_words_list[bwid];
|
||||||
if (bad_words_token_id >= length || bad_words_token_id < 0) continue;
|
if (bad_words_token_id >= length || bad_words_token_id < 0)
|
||||||
logits_now[bad_words_token_id] = -1e10;
|
continue;
|
||||||
|
logits_now[bad_words_token_id] = -1e10;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
template <paddle::DataType D>
|
template <paddle::DataType D>
|
||||||
@@ -109,44 +112,44 @@ void token_penalty_multi_scores_kernel(const paddle::Tensor &pre_ids,
|
|||||||
const paddle::Tensor &cur_len,
|
const paddle::Tensor &cur_len,
|
||||||
const paddle::Tensor &min_len,
|
const paddle::Tensor &min_len,
|
||||||
const paddle::Tensor &eos_token_id) {
|
const paddle::Tensor &eos_token_id) {
|
||||||
std::vector<int64_t> shape = logits.shape();
|
std::vector<int64_t> shape = logits.shape();
|
||||||
auto repeat_times =
|
auto repeat_times =
|
||||||
paddle::full(shape, 0, paddle::DataType::INT32, pre_ids.place());
|
paddle::full(shape, 0, paddle::DataType::INT32, pre_ids.place());
|
||||||
int64_t bs = shape[0];
|
int64_t bs = shape[0];
|
||||||
int64_t length = shape[1];
|
int64_t length = shape[1];
|
||||||
int64_t length_id = pre_ids.shape()[1];
|
int64_t length_id = pre_ids.shape()[1];
|
||||||
int64_t end_length = eos_token_id.shape()[0];
|
int64_t end_length = eos_token_id.shape()[0];
|
||||||
int64_t length_bad_words = bad_tokens.shape()[0];
|
int64_t length_bad_words = bad_tokens.shape()[0];
|
||||||
|
|
||||||
min_length_logits_process(const_cast<float *>(logits.data<float>()),
|
min_length_logits_process(const_cast<float *>(logits.data<float>()),
|
||||||
cur_len.data<int64_t>(),
|
cur_len.data<int64_t>(),
|
||||||
min_len.data<int64_t>(),
|
min_len.data<int64_t>(),
|
||||||
eos_token_id.data<int64_t>(),
|
eos_token_id.data<int64_t>(),
|
||||||
bs,
|
bs,
|
||||||
length,
|
length,
|
||||||
end_length);
|
end_length);
|
||||||
|
|
||||||
update_repeat_times(pre_ids.data<int64_t>(),
|
update_repeat_times(pre_ids.data<int64_t>(),
|
||||||
cur_len.data<int64_t>(),
|
cur_len.data<int64_t>(),
|
||||||
repeat_times.data<int>(),
|
repeat_times.data<int>(),
|
||||||
bs,
|
bs,
|
||||||
length,
|
length,
|
||||||
length_id);
|
length_id);
|
||||||
|
|
||||||
update_value_by_repeat_times(repeat_times.data<int>(),
|
update_value_by_repeat_times(repeat_times.data<int>(),
|
||||||
penalty_scores.data<float>(),
|
penalty_scores.data<float>(),
|
||||||
frequency_score.data<float>(),
|
frequency_score.data<float>(),
|
||||||
presence_score.data<float>(),
|
presence_score.data<float>(),
|
||||||
temperatures.data<float>(),
|
temperatures.data<float>(),
|
||||||
const_cast<float *>(logits.data<float>()),
|
const_cast<float *>(logits.data<float>()),
|
||||||
bs,
|
bs,
|
||||||
length);
|
length);
|
||||||
|
|
||||||
ban_bad_words(const_cast<float *>(logits.data<float>()),
|
ban_bad_words(const_cast<float *>(logits.data<float>()),
|
||||||
bad_tokens.data<int64_t>(),
|
bad_tokens.data<int64_t>(),
|
||||||
bs,
|
bs,
|
||||||
length,
|
length,
|
||||||
length_bad_words);
|
length_bad_words);
|
||||||
}
|
}
|
||||||
|
|
||||||
void TokenPenaltyMultiScores(const paddle::Tensor &pre_ids,
|
void TokenPenaltyMultiScores(const paddle::Tensor &pre_ids,
|
||||||
@@ -159,17 +162,17 @@ void TokenPenaltyMultiScores(const paddle::Tensor &pre_ids,
|
|||||||
const paddle::Tensor &cur_len,
|
const paddle::Tensor &cur_len,
|
||||||
const paddle::Tensor &min_len,
|
const paddle::Tensor &min_len,
|
||||||
const paddle::Tensor &eos_token_id) {
|
const paddle::Tensor &eos_token_id) {
|
||||||
return token_penalty_multi_scores_kernel<paddle::DataType::FLOAT32>(
|
return token_penalty_multi_scores_kernel<paddle::DataType::FLOAT32>(
|
||||||
pre_ids,
|
pre_ids,
|
||||||
logits,
|
logits,
|
||||||
penalty_scores,
|
penalty_scores,
|
||||||
frequency_scores,
|
frequency_scores,
|
||||||
presence_scores,
|
presence_scores,
|
||||||
temperatures,
|
temperatures,
|
||||||
bad_tokens,
|
bad_tokens,
|
||||||
cur_len,
|
cur_len,
|
||||||
min_len,
|
min_len,
|
||||||
eos_token_id);
|
eos_token_id);
|
||||||
}
|
}
|
||||||
|
|
||||||
PD_BUILD_STATIC_OP(get_token_penalty_multi_scores_cpu)
|
PD_BUILD_STATIC_OP(get_token_penalty_multi_scores_cpu)
|
||||||
|
|||||||
@@ -24,50 +24,50 @@ void update_inputs_kernel(bool *not_need_stop,
|
|||||||
const int64_t *next_tokens,
|
const int64_t *next_tokens,
|
||||||
const int bsz,
|
const int bsz,
|
||||||
const int input_ids_stride) {
|
const int input_ids_stride) {
|
||||||
int64_t stop_sum = 0;
|
int64_t stop_sum = 0;
|
||||||
for (int bi = 0; bi < bsz; ++bi) {
|
for (int bi = 0; bi < bsz; ++bi) {
|
||||||
bool stop_flag_now = false;
|
bool stop_flag_now = false;
|
||||||
int64_t stop_flag_now_int = 0;
|
int64_t stop_flag_now_int = 0;
|
||||||
stop_flag_now = stop_flags[bi];
|
stop_flag_now = stop_flags[bi];
|
||||||
stop_flag_now_int = static_cast<int64_t>(stop_flag_now);
|
stop_flag_now_int = static_cast<int64_t>(stop_flag_now);
|
||||||
auto seq_len_this_time = seq_lens_this_time[bi];
|
auto seq_len_this_time = seq_lens_this_time[bi];
|
||||||
auto seq_len_encoder = seq_lens_encoder[bi];
|
auto seq_len_encoder = seq_lens_encoder[bi];
|
||||||
auto seq_len_decoder = seq_lens_decoder[bi];
|
auto seq_len_decoder = seq_lens_decoder[bi];
|
||||||
seq_lens_decoder[bi] =
|
seq_lens_decoder[bi] =
|
||||||
stop_flag_now
|
stop_flag_now ? 0
|
||||||
? 0
|
: (seq_len_decoder == 0 ? seq_len_encoder
|
||||||
: (seq_len_decoder == 0 ? seq_len_encoder : seq_len_decoder + 1);
|
: seq_len_decoder + 1);
|
||||||
seq_lens_this_time[bi] = stop_flag_now ? 0 : 1;
|
seq_lens_this_time[bi] = stop_flag_now ? 0 : 1;
|
||||||
seq_lens_encoder[bi] = 0;
|
seq_lens_encoder[bi] = 0;
|
||||||
int64_t *input_ids_now = input_ids + bi * input_ids_stride;
|
int64_t *input_ids_now = input_ids + bi * input_ids_stride;
|
||||||
input_ids_now[0] = next_tokens[bi];
|
input_ids_now[0] = next_tokens[bi];
|
||||||
stop_sum += stop_flag_now_int;
|
stop_sum += stop_flag_now_int;
|
||||||
}
|
}
|
||||||
not_need_stop[0] = stop_sum < stop_nums[0];
|
not_need_stop[0] = stop_sum < stop_nums[0];
|
||||||
}
|
}
|
||||||
|
|
||||||
void UpdateInputs(const paddle::Tensor &stop_flags,
|
void UpdateInputes(const paddle::Tensor &stop_flags,
|
||||||
const paddle::Tensor ¬_need_stop,
|
const paddle::Tensor ¬_need_stop,
|
||||||
const paddle::Tensor &seq_lens_this_time,
|
const paddle::Tensor &seq_lens_this_time,
|
||||||
const paddle::Tensor &seq_lens_encoder,
|
const paddle::Tensor &seq_lens_encoder,
|
||||||
const paddle::Tensor &seq_lens_decoder,
|
const paddle::Tensor &seq_lens_decoder,
|
||||||
const paddle::Tensor &input_ids,
|
const paddle::Tensor &input_ids,
|
||||||
const paddle::Tensor &stop_nums,
|
const paddle::Tensor &stop_nums,
|
||||||
const paddle::Tensor &next_tokens,
|
const paddle::Tensor &next_tokens,
|
||||||
const paddle::Tensor &is_block_step) {
|
const paddle::Tensor &is_block_step) {
|
||||||
const int bsz = input_ids.shape()[0];
|
const int bsz = input_ids.shape()[0];
|
||||||
const int input_ids_stride = input_ids.shape()[1];
|
const int input_ids_stride = input_ids.shape()[1];
|
||||||
update_inputs_kernel(const_cast<bool *>(not_need_stop.data<bool>()),
|
update_inputs_kernel(const_cast<bool *>(not_need_stop.data<bool>()),
|
||||||
const_cast<int *>(seq_lens_this_time.data<int>()),
|
const_cast<int *>(seq_lens_this_time.data<int>()),
|
||||||
const_cast<int *>(seq_lens_encoder.data<int>()),
|
const_cast<int *>(seq_lens_encoder.data<int>()),
|
||||||
const_cast<int *>(seq_lens_decoder.data<int>()),
|
const_cast<int *>(seq_lens_decoder.data<int>()),
|
||||||
const_cast<int64_t *>(input_ids.data<int64_t>()),
|
const_cast<int64_t *>(input_ids.data<int64_t>()),
|
||||||
stop_nums.data<int64_t>(),
|
stop_nums.data<int64_t>(),
|
||||||
stop_flags.data<bool>(),
|
stop_flags.data<bool>(),
|
||||||
is_block_step.data<bool>(),
|
is_block_step.data<bool>(),
|
||||||
next_tokens.data<int64_t>(),
|
next_tokens.data<int64_t>(),
|
||||||
bsz,
|
bsz,
|
||||||
input_ids_stride);
|
input_ids_stride);
|
||||||
}
|
}
|
||||||
|
|
||||||
PD_BUILD_STATIC_OP(update_inputs_cpu)
|
PD_BUILD_STATIC_OP(update_inputs_cpu)
|
||||||
@@ -90,4 +90,4 @@ PD_BUILD_STATIC_OP(update_inputs_cpu)
|
|||||||
{"seq_lens_encoder", "seq_lens_encoder_out"},
|
{"seq_lens_encoder", "seq_lens_encoder_out"},
|
||||||
{"seq_lens_decoder", "seq_lens_decoder_out"},
|
{"seq_lens_decoder", "seq_lens_decoder_out"},
|
||||||
{"input_ids", "input_ids_out"}})
|
{"input_ids", "input_ids_out"}})
|
||||||
.SetKernelFn(PD_KERNEL(UpdateInputs));
|
.SetKernelFn(PD_KERNEL(UpdateInputes));
|
||||||
|
|||||||
@@ -45,18 +45,18 @@ std::vector<paddle::Tensor> InvokeAllLLaMALayer(
|
|||||||
int maxPositions,
|
int maxPositions,
|
||||||
int maxPosEmbed,
|
int maxPosEmbed,
|
||||||
int intermediateSize) {
|
int intermediateSize) {
|
||||||
auto out = paddle::empty_like(input);
|
auto out = paddle::empty_like(input);
|
||||||
return {out};
|
return {out};
|
||||||
}
|
}
|
||||||
|
|
||||||
std::vector<std::vector<int64_t>> AllLLaMALayerInferShape(
|
std::vector<std::vector<int64_t>> AllLLaMALayerInferShape(
|
||||||
std::vector<int64_t> x_shape) {
|
std::vector<int64_t> x_shape) {
|
||||||
return {x_shape};
|
return {x_shape};
|
||||||
}
|
}
|
||||||
|
|
||||||
std::vector<paddle::DataType> AllLLaMALayerInferDtype(
|
std::vector<paddle::DataType> AllLLaMALayerInferDtype(
|
||||||
paddle::DataType x_dtype) {
|
paddle::DataType x_dtype) {
|
||||||
return {x_dtype};
|
return {x_dtype};
|
||||||
}
|
}
|
||||||
|
|
||||||
PD_BUILD_STATIC_OP(xft_llama_all_layer)
|
PD_BUILD_STATIC_OP(xft_llama_all_layer)
|
||||||
|
|||||||
@@ -16,20 +16,20 @@
|
|||||||
#include "paddle/extension.h"
|
#include "paddle/extension.h"
|
||||||
|
|
||||||
std::vector<paddle::Tensor> XftGreedySearch(const paddle::Tensor &probs) {
|
std::vector<paddle::Tensor> XftGreedySearch(const paddle::Tensor &probs) {
|
||||||
const int bsz = probs.shape()[0];
|
const int bsz = probs.shape()[0];
|
||||||
const int vocab_size = probs.shape()[1];
|
const int vocab_size = probs.shape()[1];
|
||||||
auto next_tokens =
|
auto next_tokens =
|
||||||
paddle::empty({bsz, 1}, paddle::DataType::INT64, probs.place());
|
paddle::empty({bsz, 1}, paddle::DataType::INT64, probs.place());
|
||||||
return {next_tokens};
|
return {next_tokens};
|
||||||
}
|
}
|
||||||
std::vector<std::vector<int64_t>> XftGreedySearchInferShape(
|
std::vector<std::vector<int64_t>> XftGreedySearchInferShape(
|
||||||
const std::vector<int64_t> &probs_shape) {
|
const std::vector<int64_t> &probs_shape) {
|
||||||
int64_t bsz = probs_shape[0];
|
int64_t bsz = probs_shape[0];
|
||||||
return {{bsz, 1}};
|
return {{bsz, 1}};
|
||||||
}
|
}
|
||||||
std::vector<paddle::DataType> XftGreedySearchInferDtype(
|
std::vector<paddle::DataType> XftGreedySearchInferDtype(
|
||||||
const paddle::DataType &probs_dtype) {
|
const paddle::DataType &probs_dtype) {
|
||||||
return {paddle::DataType::INT64};
|
return {paddle::DataType::INT64};
|
||||||
}
|
}
|
||||||
PD_BUILD_STATIC_OP(xft_greedy_search)
|
PD_BUILD_STATIC_OP(xft_greedy_search)
|
||||||
.Inputs({"probs"})
|
.Inputs({"probs"})
|
||||||
|
|||||||
@@ -59,6 +59,7 @@ void AppendAttentionKernel(
|
|||||||
const paddle::Tensor& decoder_tile_ids_per_batch,
|
const paddle::Tensor& decoder_tile_ids_per_batch,
|
||||||
const paddle::Tensor& decoder_num_blocks,
|
const paddle::Tensor& decoder_num_blocks,
|
||||||
const paddle::Tensor& set_max_lengths,
|
const paddle::Tensor& set_max_lengths,
|
||||||
|
const paddle::Tensor& max_len_kv,
|
||||||
paddle::Tensor& fmha_out,
|
paddle::Tensor& fmha_out,
|
||||||
const paddle::optional<paddle::Tensor>& rotary_embs,
|
const paddle::optional<paddle::Tensor>& rotary_embs,
|
||||||
const paddle::optional<paddle::Tensor>& attn_mask,
|
const paddle::optional<paddle::Tensor>& attn_mask,
|
||||||
@@ -72,10 +73,10 @@ void AppendAttentionKernel(
|
|||||||
const paddle::optional<paddle::Tensor>& cache_v_zp,
|
const paddle::optional<paddle::Tensor>& cache_v_zp,
|
||||||
const paddle::optional<paddle::Tensor>& out_linear_shifts,
|
const paddle::optional<paddle::Tensor>& out_linear_shifts,
|
||||||
const paddle::optional<paddle::Tensor>& out_linear_smooths,
|
const paddle::optional<paddle::Tensor>& out_linear_smooths,
|
||||||
|
const paddle::optional<paddle::Tensor>& mask_offset,
|
||||||
const paddle::optional<paddle::Tensor>& kv_signal_data,
|
const paddle::optional<paddle::Tensor>& kv_signal_data,
|
||||||
const paddle::optional<paddle::Tensor>& q_norm_weight,
|
const paddle::optional<paddle::Tensor>& q_norm_weight,
|
||||||
const paddle::optional<paddle::Tensor>& k_norm_weight,
|
const paddle::optional<paddle::Tensor>& k_norm_weight,
|
||||||
const paddle::optional<paddle::Tensor>& sinks,
|
|
||||||
const float rms_norm_eps,
|
const float rms_norm_eps,
|
||||||
const std::string& cache_quant_type_str,
|
const std::string& cache_quant_type_str,
|
||||||
const bool use_neox_rotary_style,
|
const bool use_neox_rotary_style,
|
||||||
@@ -90,8 +91,7 @@ void AppendAttentionKernel(
|
|||||||
const int encoder_max_partition_size,
|
const int encoder_max_partition_size,
|
||||||
const int speculate_max_draft_token_num,
|
const int speculate_max_draft_token_num,
|
||||||
const bool causal,
|
const bool causal,
|
||||||
const bool speculate_decoder,
|
const bool speculate_decoder) {
|
||||||
const int sliding_window) {
|
|
||||||
typedef PDTraits<D> traits_;
|
typedef PDTraits<D> traits_;
|
||||||
typedef typename traits_::DataType DataType_;
|
typedef typename traits_::DataType DataType_;
|
||||||
typedef typename traits_::data_t data_t;
|
typedef typename traits_::data_t data_t;
|
||||||
@@ -103,7 +103,6 @@ void AppendAttentionKernel(
|
|||||||
int max_dec_len_this_time = set_max_lengths.data<int>()[2];
|
int max_dec_len_this_time = set_max_lengths.data<int>()[2];
|
||||||
int max_enc_dec_len_this_time = set_max_lengths.data<int>()[3];
|
int max_enc_dec_len_this_time = set_max_lengths.data<int>()[3];
|
||||||
int max_just_dec_len_this_time = set_max_lengths.data<int>()[4];
|
int max_just_dec_len_this_time = set_max_lengths.data<int>()[4];
|
||||||
int max_kv_len_this_time = set_max_lengths.data<int>()[8];
|
|
||||||
|
|
||||||
auto main_stream = qkv.stream();
|
auto main_stream = qkv.stream();
|
||||||
static cudaEvent_t main_event;
|
static cudaEvent_t main_event;
|
||||||
@@ -141,13 +140,12 @@ void AppendAttentionKernel(
|
|||||||
key_cache,
|
key_cache,
|
||||||
value_cache,
|
value_cache,
|
||||||
attn_mask,
|
attn_mask,
|
||||||
cache_quant_type_str == "block_wise_fp8" ? cache_k_quant_scales : cache_k_dequant_scales,
|
cache_k_dequant_scales,
|
||||||
cache_quant_type_str == "block_wise_fp8" ? cache_v_quant_scales : cache_v_dequant_scales,
|
cache_v_dequant_scales,
|
||||||
cache_k_zp,
|
cache_k_zp,
|
||||||
cache_v_zp,
|
cache_v_zp,
|
||||||
out_linear_shifts,
|
out_linear_shifts,
|
||||||
out_linear_smooths,
|
out_linear_smooths,
|
||||||
sinks,
|
|
||||||
seq_lens_this_time,
|
seq_lens_this_time,
|
||||||
seq_lens_decoder,
|
seq_lens_decoder,
|
||||||
seq_lens_encoder,
|
seq_lens_encoder,
|
||||||
@@ -171,8 +169,7 @@ void AppendAttentionKernel(
|
|||||||
lambda_is_decoder,
|
lambda_is_decoder,
|
||||||
lambda_enable_prefill,
|
lambda_enable_prefill,
|
||||||
lambda_stream,
|
lambda_stream,
|
||||||
&fmha_out,
|
&fmha_out);
|
||||||
sliding_window);
|
|
||||||
};
|
};
|
||||||
|
|
||||||
if (max_enc_len_this_time > 0) {
|
if (max_enc_len_this_time > 0) {
|
||||||
@@ -248,6 +245,7 @@ void AppendAttentionKernel(
|
|||||||
|
|
||||||
if (max_just_dec_len_this_time > 0) {
|
if (max_just_dec_len_this_time > 0) {
|
||||||
int decoder_num_blocks_data = decoder_num_blocks.data<int>()[0];
|
int decoder_num_blocks_data = decoder_num_blocks.data<int>()[0];
|
||||||
|
int max_len_kv_data = max_len_kv.data<int>()[0];
|
||||||
|
|
||||||
cudaStream_t exec_stream;
|
cudaStream_t exec_stream;
|
||||||
if (max_enc_len_this_time > 0) {
|
if (max_enc_len_this_time > 0) {
|
||||||
@@ -275,15 +273,11 @@ void AppendAttentionKernel(
|
|||||||
cache_v_zp,
|
cache_v_zp,
|
||||||
cache_quant_type_str,
|
cache_quant_type_str,
|
||||||
use_neox_rotary_style,
|
use_neox_rotary_style,
|
||||||
rope_3d,
|
|
||||||
max_input_length,
|
max_input_length,
|
||||||
exec_stream,
|
exec_stream,
|
||||||
&qkv_out,
|
&qkv_out,
|
||||||
const_cast<paddle::Tensor*>(&key_cache),
|
const_cast<paddle::Tensor*>(&key_cache),
|
||||||
const_cast<paddle::Tensor*>(&value_cache),
|
const_cast<paddle::Tensor*>(&value_cache));
|
||||||
q_norm_weight,
|
|
||||||
k_norm_weight,
|
|
||||||
rms_norm_eps);
|
|
||||||
} else {
|
} else {
|
||||||
SpeculateWriteCacheWithRoPEKernel<data_t, data_t>(
|
SpeculateWriteCacheWithRoPEKernel<data_t, data_t>(
|
||||||
meta_data,
|
meta_data,
|
||||||
@@ -302,15 +296,11 @@ void AppendAttentionKernel(
|
|||||||
cache_v_zp,
|
cache_v_zp,
|
||||||
cache_quant_type_str,
|
cache_quant_type_str,
|
||||||
use_neox_rotary_style,
|
use_neox_rotary_style,
|
||||||
rope_3d,
|
|
||||||
max_input_length,
|
max_input_length,
|
||||||
exec_stream,
|
exec_stream,
|
||||||
&qkv_out,
|
&qkv_out,
|
||||||
const_cast<paddle::Tensor*>(&key_cache),
|
const_cast<paddle::Tensor*>(&key_cache),
|
||||||
const_cast<paddle::Tensor*>(&value_cache),
|
const_cast<paddle::Tensor*>(&value_cache));
|
||||||
q_norm_weight,
|
|
||||||
k_norm_weight,
|
|
||||||
rms_norm_eps);
|
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
if (qkv_out_scales) {
|
if (qkv_out_scales) {
|
||||||
@@ -319,6 +309,7 @@ void AppendAttentionKernel(
|
|||||||
qkv, // [token_num, num_heads, head_dim]
|
qkv, // [token_num, num_heads, head_dim]
|
||||||
seq_lens_decoder,
|
seq_lens_decoder,
|
||||||
seq_lens_encoder,
|
seq_lens_encoder,
|
||||||
|
batch_id_per_token,
|
||||||
cu_seqlens_q,
|
cu_seqlens_q,
|
||||||
block_tables,
|
block_tables,
|
||||||
rotary_embs,
|
rotary_embs,
|
||||||
@@ -345,6 +336,7 @@ void AppendAttentionKernel(
|
|||||||
qkv_out, // [token_num, num_heads, head_dim]
|
qkv_out, // [token_num, num_heads, head_dim]
|
||||||
seq_lens_decoder,
|
seq_lens_decoder,
|
||||||
seq_lens_encoder,
|
seq_lens_encoder,
|
||||||
|
batch_id_per_token,
|
||||||
cu_seqlens_q,
|
cu_seqlens_q,
|
||||||
block_tables,
|
block_tables,
|
||||||
rotary_embs,
|
rotary_embs,
|
||||||
@@ -373,20 +365,20 @@ void AppendAttentionKernel(
|
|||||||
case paddle::DataType::INT8:{
|
case paddle::DataType::INT8:{
|
||||||
int8_t tmp;
|
int8_t tmp;
|
||||||
dispatch_CascadeAppendAttentionKernel(tmp, decoder_batch_ids, decoder_tile_ids_per_batch, decoder_num_blocks_data,
|
dispatch_CascadeAppendAttentionKernel(tmp, decoder_batch_ids, decoder_tile_ids_per_batch, decoder_num_blocks_data,
|
||||||
decoder_block_shape_q, max_kv_len_this_time, !speculate_decoder, !speculate_decoder, exec_stream);
|
decoder_block_shape_q, max_len_kv_data, !speculate_decoder, !speculate_decoder, exec_stream);
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
case paddle::DataType::FLOAT8_E4M3FN:{
|
case paddle::DataType::FLOAT8_E4M3FN:{
|
||||||
phi::dtype::float8_e4m3fn tmp;
|
phi::dtype::float8_e4m3fn tmp;
|
||||||
dispatch_CascadeAppendAttentionKernel(tmp, decoder_batch_ids, decoder_tile_ids_per_batch, decoder_num_blocks_data,
|
dispatch_CascadeAppendAttentionKernel(tmp, decoder_batch_ids, decoder_tile_ids_per_batch, decoder_num_blocks_data,
|
||||||
decoder_block_shape_q, max_kv_len_this_time, !speculate_decoder, !speculate_decoder, exec_stream);
|
decoder_block_shape_q, max_len_kv_data, !speculate_decoder, !speculate_decoder, exec_stream);
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
data_t tmp;
|
data_t tmp;
|
||||||
dispatch_CascadeAppendAttentionKernel(tmp, decoder_batch_ids, decoder_tile_ids_per_batch, decoder_num_blocks_data,
|
dispatch_CascadeAppendAttentionKernel(tmp, decoder_batch_ids, decoder_tile_ids_per_batch, decoder_num_blocks_data,
|
||||||
decoder_block_shape_q, max_kv_len_this_time, !speculate_decoder, !speculate_decoder, exec_stream);
|
decoder_block_shape_q, max_len_kv_data, !speculate_decoder, !speculate_decoder, exec_stream);
|
||||||
}
|
}
|
||||||
if (max_enc_len_this_time > 0) {
|
if (max_enc_len_this_time > 0) {
|
||||||
cudaEventRecord(decoder_event, exec_stream);
|
cudaEventRecord(decoder_event, exec_stream);
|
||||||
@@ -415,6 +407,7 @@ std::vector<paddle::Tensor> AppendAttention(
|
|||||||
const paddle::Tensor& decoder_tile_ids_per_batch,
|
const paddle::Tensor& decoder_tile_ids_per_batch,
|
||||||
const paddle::Tensor& decoder_num_blocks,
|
const paddle::Tensor& decoder_num_blocks,
|
||||||
const paddle::Tensor& set_max_lengths,
|
const paddle::Tensor& set_max_lengths,
|
||||||
|
const paddle::Tensor& max_len_kv,
|
||||||
const paddle::optional<paddle::Tensor>& rotary_embs,
|
const paddle::optional<paddle::Tensor>& rotary_embs,
|
||||||
const paddle::optional<paddle::Tensor>& attn_mask,
|
const paddle::optional<paddle::Tensor>& attn_mask,
|
||||||
const paddle::optional<paddle::Tensor>& qkv_bias,
|
const paddle::optional<paddle::Tensor>& qkv_bias,
|
||||||
@@ -431,7 +424,6 @@ std::vector<paddle::Tensor> AppendAttention(
|
|||||||
const paddle::optional<paddle::Tensor>& kv_signal_data,
|
const paddle::optional<paddle::Tensor>& kv_signal_data,
|
||||||
const paddle::optional<paddle::Tensor>& q_norm_weight,
|
const paddle::optional<paddle::Tensor>& q_norm_weight,
|
||||||
const paddle::optional<paddle::Tensor>& k_norm_weight,
|
const paddle::optional<paddle::Tensor>& k_norm_weight,
|
||||||
const paddle::optional<paddle::Tensor>& sinks,
|
|
||||||
const float rms_norm_eps,
|
const float rms_norm_eps,
|
||||||
const std::string& compute_dtype,
|
const std::string& compute_dtype,
|
||||||
const std::string& cache_quant_type_str,
|
const std::string& cache_quant_type_str,
|
||||||
@@ -447,8 +439,7 @@ std::vector<paddle::Tensor> AppendAttention(
|
|||||||
const int encoder_max_partition_size,
|
const int encoder_max_partition_size,
|
||||||
const int speculate_max_draft_token_num,
|
const int speculate_max_draft_token_num,
|
||||||
const bool causal,
|
const bool causal,
|
||||||
const bool speculate_decoder,
|
const bool speculate_decoder) {
|
||||||
const int sliding_window) {
|
|
||||||
AppendAttnMetaData meta_data;
|
AppendAttnMetaData meta_data;
|
||||||
|
|
||||||
const auto& qkv_dims = qkv.dims();
|
const auto& qkv_dims = qkv.dims();
|
||||||
@@ -497,12 +488,12 @@ std::vector<paddle::Tensor> AppendAttention(
|
|||||||
paddle::Tensor fmha_out;
|
paddle::Tensor fmha_out;
|
||||||
if (out_linear_in_scale > 0.0) {
|
if (out_linear_in_scale > 0.0) {
|
||||||
if (fabs(quant_max_bound - 127.0f) < 0.000001) {
|
if (fabs(quant_max_bound - 127.0f) < 0.000001) {
|
||||||
fmha_out = paddle::zeros(
|
fmha_out = GetEmptyTensor(
|
||||||
{meta_data.token_nums, meta_data.q_num_heads * meta_data.head_dims},
|
{meta_data.token_nums, meta_data.q_num_heads * meta_data.head_dims},
|
||||||
paddle::DataType::INT8,
|
paddle::DataType::INT8,
|
||||||
qkv.place());
|
qkv.place());
|
||||||
} else if (fabs(quant_max_bound - 448.0f) < 0.000001) {
|
} else if (fabs(quant_max_bound - 448.0f) < 0.000001) {
|
||||||
fmha_out = paddle::zeros(
|
fmha_out = GetEmptyTensor(
|
||||||
{meta_data.token_nums, meta_data.q_num_heads * meta_data.head_dims},
|
{meta_data.token_nums, meta_data.q_num_heads * meta_data.head_dims},
|
||||||
paddle::DataType::FLOAT8_E4M3FN,
|
paddle::DataType::FLOAT8_E4M3FN,
|
||||||
qkv.place());
|
qkv.place());
|
||||||
@@ -510,7 +501,7 @@ std::vector<paddle::Tensor> AppendAttention(
|
|||||||
PD_THROW("Only supported attr of quant_max_bound in ['127', '448'].");
|
PD_THROW("Only supported attr of quant_max_bound in ['127', '448'].");
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
fmha_out = paddle::zeros(
|
fmha_out = GetEmptyTensor(
|
||||||
{meta_data.token_nums, meta_data.q_num_heads * meta_data.head_dims},
|
{meta_data.token_nums, meta_data.q_num_heads * meta_data.head_dims},
|
||||||
dtype_id,
|
dtype_id,
|
||||||
qkv.place());
|
qkv.place());
|
||||||
@@ -542,6 +533,7 @@ std::vector<paddle::Tensor> AppendAttention(
|
|||||||
decoder_tile_ids_per_batch,
|
decoder_tile_ids_per_batch,
|
||||||
decoder_num_blocks,
|
decoder_num_blocks,
|
||||||
set_max_lengths,
|
set_max_lengths,
|
||||||
|
max_len_kv,
|
||||||
fmha_out,
|
fmha_out,
|
||||||
rotary_embs,
|
rotary_embs,
|
||||||
attn_mask,
|
attn_mask,
|
||||||
@@ -555,10 +547,10 @@ std::vector<paddle::Tensor> AppendAttention(
|
|||||||
cache_v_zp,
|
cache_v_zp,
|
||||||
out_linear_shifts,
|
out_linear_shifts,
|
||||||
out_linear_smooths,
|
out_linear_smooths,
|
||||||
|
mask_offset,
|
||||||
kv_signal_data,
|
kv_signal_data,
|
||||||
q_norm_weight,
|
q_norm_weight,
|
||||||
k_norm_weight,
|
k_norm_weight,
|
||||||
sinks,
|
|
||||||
rms_norm_eps,
|
rms_norm_eps,
|
||||||
cache_quant_type_str,
|
cache_quant_type_str,
|
||||||
use_neox_rotary_style,
|
use_neox_rotary_style,
|
||||||
@@ -573,8 +565,7 @@ std::vector<paddle::Tensor> AppendAttention(
|
|||||||
encoder_max_partition_size,
|
encoder_max_partition_size,
|
||||||
speculate_max_draft_token_num,
|
speculate_max_draft_token_num,
|
||||||
causal,
|
causal,
|
||||||
speculate_decoder,
|
speculate_decoder);
|
||||||
sliding_window);
|
|
||||||
};
|
};
|
||||||
|
|
||||||
|
|
||||||
@@ -599,7 +590,7 @@ std::vector<paddle::Tensor> AppendAttention(
|
|||||||
return {paddle::Tensor{}};
|
return {paddle::Tensor{}};
|
||||||
}
|
}
|
||||||
|
|
||||||
std::vector<paddle::Tensor> AppendAttentionWithOutput(
|
void AppendAttentionWithOutput(
|
||||||
const paddle::Tensor& qkv,
|
const paddle::Tensor& qkv,
|
||||||
const paddle::Tensor& key_cache,
|
const paddle::Tensor& key_cache,
|
||||||
const paddle::Tensor& value_cache,
|
const paddle::Tensor& value_cache,
|
||||||
@@ -619,6 +610,7 @@ std::vector<paddle::Tensor> AppendAttentionWithOutput(
|
|||||||
const paddle::Tensor& decoder_tile_ids_per_batch,
|
const paddle::Tensor& decoder_tile_ids_per_batch,
|
||||||
const paddle::Tensor& decoder_num_blocks,
|
const paddle::Tensor& decoder_num_blocks,
|
||||||
const paddle::Tensor& set_max_lengths,
|
const paddle::Tensor& set_max_lengths,
|
||||||
|
const paddle::Tensor& max_len_kv,
|
||||||
paddle::Tensor& fmha_out,
|
paddle::Tensor& fmha_out,
|
||||||
const paddle::optional<paddle::Tensor>& rotary_embs,
|
const paddle::optional<paddle::Tensor>& rotary_embs,
|
||||||
const paddle::optional<paddle::Tensor>& attn_mask,
|
const paddle::optional<paddle::Tensor>& attn_mask,
|
||||||
@@ -636,7 +628,6 @@ std::vector<paddle::Tensor> AppendAttentionWithOutput(
|
|||||||
const paddle::optional<paddle::Tensor>& kv_signal_data,
|
const paddle::optional<paddle::Tensor>& kv_signal_data,
|
||||||
const paddle::optional<paddle::Tensor>& q_norm_weight,
|
const paddle::optional<paddle::Tensor>& q_norm_weight,
|
||||||
const paddle::optional<paddle::Tensor>& k_norm_weight,
|
const paddle::optional<paddle::Tensor>& k_norm_weight,
|
||||||
const paddle::optional<paddle::Tensor>& sinks,
|
|
||||||
const float rms_norm_eps,
|
const float rms_norm_eps,
|
||||||
const std::string& compute_dtype,
|
const std::string& compute_dtype,
|
||||||
const std::string& cache_quant_type_str,
|
const std::string& cache_quant_type_str,
|
||||||
@@ -652,8 +643,7 @@ std::vector<paddle::Tensor> AppendAttentionWithOutput(
|
|||||||
const int encoder_max_partition_size,
|
const int encoder_max_partition_size,
|
||||||
const int speculate_max_draft_token_num,
|
const int speculate_max_draft_token_num,
|
||||||
const bool causal,
|
const bool causal,
|
||||||
const bool speculate_decoder,
|
const bool speculate_decoder) {
|
||||||
const int sliding_window) {
|
|
||||||
AppendAttnMetaData meta_data;
|
AppendAttnMetaData meta_data;
|
||||||
|
|
||||||
const auto& qkv_dims = qkv.dims();
|
const auto& qkv_dims = qkv.dims();
|
||||||
@@ -699,6 +689,7 @@ std::vector<paddle::Tensor> AppendAttentionWithOutput(
|
|||||||
decoder_tile_ids_per_batch,
|
decoder_tile_ids_per_batch,
|
||||||
decoder_num_blocks,
|
decoder_num_blocks,
|
||||||
set_max_lengths,
|
set_max_lengths,
|
||||||
|
max_len_kv,
|
||||||
fmha_out,
|
fmha_out,
|
||||||
rotary_embs,
|
rotary_embs,
|
||||||
attn_mask,
|
attn_mask,
|
||||||
@@ -712,10 +703,10 @@ std::vector<paddle::Tensor> AppendAttentionWithOutput(
|
|||||||
cache_v_zp,
|
cache_v_zp,
|
||||||
out_linear_shifts,
|
out_linear_shifts,
|
||||||
out_linear_smooths,
|
out_linear_smooths,
|
||||||
|
mask_offset,
|
||||||
kv_signal_data,
|
kv_signal_data,
|
||||||
q_norm_weight,
|
q_norm_weight,
|
||||||
k_norm_weight,
|
k_norm_weight,
|
||||||
sinks,
|
|
||||||
rms_norm_eps,
|
rms_norm_eps,
|
||||||
cache_quant_type_str,
|
cache_quant_type_str,
|
||||||
use_neox_rotary_style,
|
use_neox_rotary_style,
|
||||||
@@ -730,8 +721,7 @@ std::vector<paddle::Tensor> AppendAttentionWithOutput(
|
|||||||
encoder_max_partition_size,
|
encoder_max_partition_size,
|
||||||
speculate_max_draft_token_num,
|
speculate_max_draft_token_num,
|
||||||
causal,
|
causal,
|
||||||
speculate_decoder,
|
speculate_decoder);
|
||||||
sliding_window);
|
|
||||||
};
|
};
|
||||||
|
|
||||||
phi::dtype::float16 fp16_dtype;
|
phi::dtype::float16 fp16_dtype;
|
||||||
@@ -765,8 +755,6 @@ std::vector<paddle::Tensor> AppendAttentionWithOutput(
|
|||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return {fmha_out};
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@@ -790,6 +778,7 @@ std::vector<std::vector<int64_t>> AppendAttentionInferShape(
|
|||||||
const std::vector<int64_t>& decoder_tile_ids_per_batch_shape,
|
const std::vector<int64_t>& decoder_tile_ids_per_batch_shape,
|
||||||
const std::vector<int64_t>& decoder_num_blocks_shape,
|
const std::vector<int64_t>& decoder_num_blocks_shape,
|
||||||
const std::vector<int64_t>& set_max_lengths_shape,
|
const std::vector<int64_t>& set_max_lengths_shape,
|
||||||
|
const std::vector<int64_t>& max_len_kv_shape,
|
||||||
const paddle::optional<std::vector<int64_t>>& rotary_embs_shape,
|
const paddle::optional<std::vector<int64_t>>& rotary_embs_shape,
|
||||||
const paddle::optional<std::vector<int64_t>>& attn_mask_shape,
|
const paddle::optional<std::vector<int64_t>>& attn_mask_shape,
|
||||||
const paddle::optional<std::vector<int64_t>>& qkv_bias_shape,
|
const paddle::optional<std::vector<int64_t>>& qkv_bias_shape,
|
||||||
@@ -806,7 +795,6 @@ std::vector<std::vector<int64_t>> AppendAttentionInferShape(
|
|||||||
const paddle::optional<std::vector<int64_t>>& kv_signal_data_shape,
|
const paddle::optional<std::vector<int64_t>>& kv_signal_data_shape,
|
||||||
const paddle::optional<std::vector<int64_t>>& q_norm_weight_shape,
|
const paddle::optional<std::vector<int64_t>>& q_norm_weight_shape,
|
||||||
const paddle::optional<std::vector<int64_t>>& k_norm_weight_shape,
|
const paddle::optional<std::vector<int64_t>>& k_norm_weight_shape,
|
||||||
const paddle::optional<std::vector<int64_t>>& sinks_shape,
|
|
||||||
const float rms_norm_eps,
|
const float rms_norm_eps,
|
||||||
const std::string& compute_dtype,
|
const std::string& compute_dtype,
|
||||||
const std::string& cache_quant_type_str,
|
const std::string& cache_quant_type_str,
|
||||||
@@ -822,8 +810,7 @@ std::vector<std::vector<int64_t>> AppendAttentionInferShape(
|
|||||||
const int encoder_max_partition_size,
|
const int encoder_max_partition_size,
|
||||||
const int speculate_max_draft_token_num,
|
const int speculate_max_draft_token_num,
|
||||||
const bool causal,
|
const bool causal,
|
||||||
const bool speculate_decoder,
|
const bool speculate_decoder) {
|
||||||
const int sliding_window) {
|
|
||||||
const int token_num = qkv_shape[0];
|
const int token_num = qkv_shape[0];
|
||||||
const int kv_num_heads = key_cache_shape[1];
|
const int kv_num_heads = key_cache_shape[1];
|
||||||
int head_dim = key_cache_shape[3];
|
int head_dim = key_cache_shape[3];
|
||||||
@@ -855,6 +842,7 @@ std::vector<paddle::DataType> AppendAttentionInferDtype(
|
|||||||
const paddle::DataType& decoder_tile_ids_per_batch_dtype,
|
const paddle::DataType& decoder_tile_ids_per_batch_dtype,
|
||||||
const paddle::DataType& decoder_num_blocks_dtype,
|
const paddle::DataType& decoder_num_blocks_dtype,
|
||||||
const paddle::DataType& set_max_lengths_dtype,
|
const paddle::DataType& set_max_lengths_dtype,
|
||||||
|
const paddle::DataType& max_len_kv_dtype,
|
||||||
const paddle::optional<paddle::DataType>& rotary_embs_dtype,
|
const paddle::optional<paddle::DataType>& rotary_embs_dtype,
|
||||||
const paddle::optional<paddle::DataType>& attn_mask_dtype,
|
const paddle::optional<paddle::DataType>& attn_mask_dtype,
|
||||||
const paddle::optional<paddle::DataType>& qkv_bias_dtype,
|
const paddle::optional<paddle::DataType>& qkv_bias_dtype,
|
||||||
@@ -871,7 +859,6 @@ std::vector<paddle::DataType> AppendAttentionInferDtype(
|
|||||||
const paddle::optional<paddle::DataType>& kv_signal_data_dtype,
|
const paddle::optional<paddle::DataType>& kv_signal_data_dtype,
|
||||||
const paddle::optional<paddle::DataType>& q_norm_weight_dtype,
|
const paddle::optional<paddle::DataType>& q_norm_weight_dtype,
|
||||||
const paddle::optional<paddle::DataType>& k_norm_weight_dtype,
|
const paddle::optional<paddle::DataType>& k_norm_weight_dtype,
|
||||||
const paddle::optional<paddle::DataType>& sinks_dtype,
|
|
||||||
const float rms_norm_eps,
|
const float rms_norm_eps,
|
||||||
const std::string& compute_dtype,
|
const std::string& compute_dtype,
|
||||||
const std::string& cache_quant_type_str,
|
const std::string& cache_quant_type_str,
|
||||||
@@ -887,8 +874,7 @@ std::vector<paddle::DataType> AppendAttentionInferDtype(
|
|||||||
const int encoder_max_partition_size,
|
const int encoder_max_partition_size,
|
||||||
const int speculate_max_draft_token_num,
|
const int speculate_max_draft_token_num,
|
||||||
const bool causal,
|
const bool causal,
|
||||||
const bool speculate_decoder,
|
const bool speculate_decoder) {
|
||||||
const int sliding_window) {
|
|
||||||
if (compute_dtype == "bf16") {
|
if (compute_dtype == "bf16") {
|
||||||
if (out_linear_in_scale > 0.0) {
|
if (out_linear_in_scale > 0.0) {
|
||||||
if (fabs(quant_max_bound - 127.0f) < 0.000001) {
|
if (fabs(quant_max_bound - 127.0f) < 0.000001) {
|
||||||
@@ -938,6 +924,7 @@ std::vector<std::vector<int64_t>> AppendAttentionWithOutputInferShape(
|
|||||||
const std::vector<int64_t>& decoder_tile_ids_per_batch_shape,
|
const std::vector<int64_t>& decoder_tile_ids_per_batch_shape,
|
||||||
const std::vector<int64_t>& decoder_num_blocks_shape,
|
const std::vector<int64_t>& decoder_num_blocks_shape,
|
||||||
const std::vector<int64_t>& set_max_lengths_shape,
|
const std::vector<int64_t>& set_max_lengths_shape,
|
||||||
|
const std::vector<int64_t>& max_len_kv_shape,
|
||||||
const std::vector<int64_t>& fmha_out_shape,
|
const std::vector<int64_t>& fmha_out_shape,
|
||||||
const paddle::optional<std::vector<int64_t>>& rotary_embs_shape,
|
const paddle::optional<std::vector<int64_t>>& rotary_embs_shape,
|
||||||
const paddle::optional<std::vector<int64_t>>& attn_mask_shape,
|
const paddle::optional<std::vector<int64_t>>& attn_mask_shape,
|
||||||
@@ -955,7 +942,6 @@ std::vector<std::vector<int64_t>> AppendAttentionWithOutputInferShape(
|
|||||||
const paddle::optional<std::vector<int64_t>>& kv_signal_data_shape,
|
const paddle::optional<std::vector<int64_t>>& kv_signal_data_shape,
|
||||||
const paddle::optional<std::vector<int64_t>>& q_norm_weight_shape,
|
const paddle::optional<std::vector<int64_t>>& q_norm_weight_shape,
|
||||||
const paddle::optional<std::vector<int64_t>>& k_norm_weight_shape,
|
const paddle::optional<std::vector<int64_t>>& k_norm_weight_shape,
|
||||||
const paddle::optional<std::vector<int64_t>>& sinks_shape,
|
|
||||||
const float rms_norm_eps,
|
const float rms_norm_eps,
|
||||||
const std::string& compute_dtype,
|
const std::string& compute_dtype,
|
||||||
const std::string& cache_quant_type_str,
|
const std::string& cache_quant_type_str,
|
||||||
@@ -971,8 +957,7 @@ std::vector<std::vector<int64_t>> AppendAttentionWithOutputInferShape(
|
|||||||
const int encoder_max_partition_size,
|
const int encoder_max_partition_size,
|
||||||
const int speculate_max_draft_token_num,
|
const int speculate_max_draft_token_num,
|
||||||
const bool causal,
|
const bool causal,
|
||||||
const bool speculate_decoder,
|
const bool speculate_decoder) {
|
||||||
const int sliding_window) {
|
|
||||||
return {fmha_out_shape};
|
return {fmha_out_shape};
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -996,6 +981,7 @@ std::vector<paddle::DataType> AppendAttentionWithOutputInferDtype(
|
|||||||
const paddle::DataType& decoder_tile_ids_per_batch_dtype,
|
const paddle::DataType& decoder_tile_ids_per_batch_dtype,
|
||||||
const paddle::DataType& decoder_num_blocks_dtype,
|
const paddle::DataType& decoder_num_blocks_dtype,
|
||||||
const paddle::DataType& set_max_lengths_dtype,
|
const paddle::DataType& set_max_lengths_dtype,
|
||||||
|
const paddle::DataType& max_len_kv_dtype,
|
||||||
const paddle::DataType& fmha_out_dtype,
|
const paddle::DataType& fmha_out_dtype,
|
||||||
const paddle::optional<paddle::DataType>& rotary_embs_dtype,
|
const paddle::optional<paddle::DataType>& rotary_embs_dtype,
|
||||||
const paddle::optional<paddle::DataType>& attn_mask_dtype,
|
const paddle::optional<paddle::DataType>& attn_mask_dtype,
|
||||||
@@ -1013,7 +999,6 @@ std::vector<paddle::DataType> AppendAttentionWithOutputInferDtype(
|
|||||||
const paddle::optional<paddle::DataType>& kv_signal_data_dtype,
|
const paddle::optional<paddle::DataType>& kv_signal_data_dtype,
|
||||||
const paddle::optional<paddle::DataType>& q_norm_weight_dtype,
|
const paddle::optional<paddle::DataType>& q_norm_weight_dtype,
|
||||||
const paddle::optional<paddle::DataType>& k_norm_weight_dtype,
|
const paddle::optional<paddle::DataType>& k_norm_weight_dtype,
|
||||||
const paddle::optional<paddle::DataType>& sinks_dtype,
|
|
||||||
const float rms_norm_eps,
|
const float rms_norm_eps,
|
||||||
const std::string& compute_dtype,
|
const std::string& compute_dtype,
|
||||||
const std::string& cache_quant_type_str,
|
const std::string& cache_quant_type_str,
|
||||||
@@ -1029,8 +1014,7 @@ std::vector<paddle::DataType> AppendAttentionWithOutputInferDtype(
|
|||||||
const int encoder_max_partition_size,
|
const int encoder_max_partition_size,
|
||||||
const int speculate_max_draft_token_num,
|
const int speculate_max_draft_token_num,
|
||||||
const bool causal,
|
const bool causal,
|
||||||
const bool speculate_decoder,
|
const bool speculate_decoder) {
|
||||||
const int sliding_window) {
|
|
||||||
return {fmha_out_dtype};
|
return {fmha_out_dtype};
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -1056,6 +1040,7 @@ PD_BUILD_STATIC_OP(append_attention)
|
|||||||
"decoder_tile_ids_per_batch",
|
"decoder_tile_ids_per_batch",
|
||||||
"decoder_num_blocks",
|
"decoder_num_blocks",
|
||||||
"set_max_lengths",
|
"set_max_lengths",
|
||||||
|
"max_len_kv",
|
||||||
paddle::Optional("rotary_embs"),
|
paddle::Optional("rotary_embs"),
|
||||||
paddle::Optional("attn_mask"),
|
paddle::Optional("attn_mask"),
|
||||||
paddle::Optional("qkv_bias"),
|
paddle::Optional("qkv_bias"),
|
||||||
@@ -1071,9 +1056,10 @@ PD_BUILD_STATIC_OP(append_attention)
|
|||||||
paddle::Optional("mask_offset"),
|
paddle::Optional("mask_offset"),
|
||||||
paddle::Optional("kv_signal_data"),
|
paddle::Optional("kv_signal_data"),
|
||||||
paddle::Optional("q_norm_weight"),
|
paddle::Optional("q_norm_weight"),
|
||||||
paddle::Optional("k_norm_weight"),
|
paddle::Optional("k_norm_weight")})
|
||||||
paddle::Optional("sinks")})
|
.Outputs({"fmha_out", "key_cache_out", "value_cache_out"})
|
||||||
.Outputs({"fmha_out"})
|
.SetInplaceMap({{"key_cache", "key_cache_out"},
|
||||||
|
{"value_cache", "value_cache_out"}})
|
||||||
.Attrs({"rms_norm_eps: float",
|
.Attrs({"rms_norm_eps: float",
|
||||||
"compute_type: std::string",
|
"compute_type: std::string",
|
||||||
"cache_quant_type: std::string",
|
"cache_quant_type: std::string",
|
||||||
@@ -1090,7 +1076,6 @@ PD_BUILD_STATIC_OP(append_attention)
|
|||||||
"speculate_max_draft_token_num: int",
|
"speculate_max_draft_token_num: int",
|
||||||
"causal: bool",
|
"causal: bool",
|
||||||
"speculate_decoder: bool",
|
"speculate_decoder: bool",
|
||||||
"sliding_window: int",
|
|
||||||
})
|
})
|
||||||
.SetKernelFn(PD_KERNEL(AppendAttention))
|
.SetKernelFn(PD_KERNEL(AppendAttention))
|
||||||
.SetInferShapeFn(PD_INFER_SHAPE(AppendAttentionInferShape))
|
.SetInferShapeFn(PD_INFER_SHAPE(AppendAttentionInferShape))
|
||||||
@@ -1116,6 +1101,7 @@ PD_BUILD_STATIC_OP(append_attention_with_output)
|
|||||||
"decoder_tile_ids_per_batch",
|
"decoder_tile_ids_per_batch",
|
||||||
"decoder_num_blocks",
|
"decoder_num_blocks",
|
||||||
"set_max_lengths",
|
"set_max_lengths",
|
||||||
|
"max_len_kv",
|
||||||
"fmha_out",
|
"fmha_out",
|
||||||
paddle::Optional("rotary_embs"),
|
paddle::Optional("rotary_embs"),
|
||||||
paddle::Optional("attn_mask"),
|
paddle::Optional("attn_mask"),
|
||||||
@@ -1132,10 +1118,11 @@ PD_BUILD_STATIC_OP(append_attention_with_output)
|
|||||||
paddle::Optional("mask_offset"),
|
paddle::Optional("mask_offset"),
|
||||||
paddle::Optional("kv_signal_data"),
|
paddle::Optional("kv_signal_data"),
|
||||||
paddle::Optional("q_norm_weight"),
|
paddle::Optional("q_norm_weight"),
|
||||||
paddle::Optional("k_norm_weight"),
|
paddle::Optional("k_norm_weight")})
|
||||||
paddle::Optional("sinks")})
|
.Outputs({"fmha_out_out", "qkv_out", "key_cache_out", "value_cache_out"})
|
||||||
.Outputs({"fmha_out_out"})
|
.SetInplaceMap({{"fmha_out", "fmha_out_out"},
|
||||||
.SetInplaceMap({{"fmha_out", "fmha_out_out"}})
|
{"key_cache", "key_cache_out"},
|
||||||
|
{"value_cache", "value_cache_out"}})
|
||||||
.Attrs({"rms_norm_eps: float",
|
.Attrs({"rms_norm_eps: float",
|
||||||
"compute_type: std::string",
|
"compute_type: std::string",
|
||||||
"cache_quant_type: std::string",
|
"cache_quant_type: std::string",
|
||||||
@@ -1152,7 +1139,6 @@ PD_BUILD_STATIC_OP(append_attention_with_output)
|
|||||||
"speculate_max_draft_token_num: int",
|
"speculate_max_draft_token_num: int",
|
||||||
"causal: bool",
|
"causal: bool",
|
||||||
"speculate_decoder: bool",
|
"speculate_decoder: bool",
|
||||||
"sliding_window: int",
|
|
||||||
})
|
})
|
||||||
.SetKernelFn(PD_KERNEL(AppendAttentionWithOutput))
|
.SetKernelFn(PD_KERNEL(AppendAttentionWithOutput))
|
||||||
.SetInferShapeFn(PD_INFER_SHAPE(AppendAttentionWithOutputInferShape))
|
.SetInferShapeFn(PD_INFER_SHAPE(AppendAttentionWithOutputInferShape))
|
||||||
|
|||||||
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
@@ -77,14 +77,6 @@ struct prefill_softmax_state_t {
|
|||||||
|
|
||||||
__device__ __forceinline__ void normalize() {
|
__device__ __forceinline__ void normalize() {
|
||||||
const T d_t = static_cast<T>(d);
|
const T d_t = static_cast<T>(d);
|
||||||
#pragma unroll
|
|
||||||
for (size_t i = 0; i < vec_size; ++i) {
|
|
||||||
o[i] /= d_t;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
__device__ __forceinline__ void normalize(float current_sink) {
|
|
||||||
const T d_t = static_cast<T>(d + __expf(current_sink - m));
|
|
||||||
#pragma unroll
|
#pragma unroll
|
||||||
for (size_t i = 0; i < vec_size; ++i) {
|
for (size_t i = 0; i < vec_size; ++i) {
|
||||||
o[i] /= d_t;
|
o[i] /= d_t;
|
||||||
@@ -392,113 +384,6 @@ __device__ __forceinline__ void produce_v_blockwise_c8(
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
template<uint32_t block_size,
|
|
||||||
uint32_t num_frags_z,
|
|
||||||
uint32_t NUM_WARP_Q,
|
|
||||||
typename T>
|
|
||||||
__device__ __forceinline__ void produce_k_dynamic_scale(
|
|
||||||
T* k_smem_scale,
|
|
||||||
T* cache_k_reg,
|
|
||||||
const int* block_table_now,
|
|
||||||
const T* cache_k_scale,
|
|
||||||
const uint32_t kv_idx,
|
|
||||||
const uint32_t kv_num_heads,
|
|
||||||
const uint32_t kv_head_idx,
|
|
||||||
const uint32_t chunk_end
|
|
||||||
) {
|
|
||||||
const uint32_t tx = threadIdx.x, ty = threadIdx.y;
|
|
||||||
if constexpr (NUM_WARP_Q == 4) {
|
|
||||||
// 4 warps shared block_size
|
|
||||||
const uint32_t tid = ty * 32 + tx;
|
|
||||||
int block_id = __ldg(&block_table_now[kv_idx / block_size]);
|
|
||||||
if (block_id < 0) block_id = 0;
|
|
||||||
const T* cache_k_scale_now = cache_k_scale + block_id * kv_num_heads * block_size + kv_head_idx * block_size;
|
|
||||||
if (tid < block_size) {
|
|
||||||
k_smem_scale[tid] = cache_k_scale_now[tid];
|
|
||||||
}
|
|
||||||
__syncthreads();
|
|
||||||
const uint32_t row_id = tx / 4;
|
|
||||||
for (uint32_t fz = 0; fz < num_frags_z; fz++) {
|
|
||||||
cache_k_reg[fz * 2] = k_smem_scale[fz * 16 + row_id];
|
|
||||||
cache_k_reg[fz * 2 + 1] = k_smem_scale[fz * 16 + row_id + 8];
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
// 1 warp 32 tokens
|
|
||||||
const uint32_t kv_idx_now = kv_idx + block_size * ty / 2;
|
|
||||||
int block_id = __ldg(&block_table_now[kv_idx_now / block_size]);
|
|
||||||
if (block_id < 0) block_id = 0;
|
|
||||||
const T* cache_k_scale_now = cache_k_scale + block_id * kv_num_heads * block_size + kv_head_idx * block_size;
|
|
||||||
const int kv_idx_this_thread = kv_idx + ty * 32 + tx;
|
|
||||||
if (kv_idx_this_thread < chunk_end) {
|
|
||||||
k_smem_scale[ty * 32 + tx] = cache_k_scale_now[(ty % 2) * 32 + tx];
|
|
||||||
} else {
|
|
||||||
k_smem_scale[ty * 32 + tx] = 0;
|
|
||||||
}
|
|
||||||
__syncwarp();
|
|
||||||
const uint32_t row_id = tx / 4;
|
|
||||||
for (uint32_t fz = 0; fz < num_frags_z; fz++) {
|
|
||||||
cache_k_reg[fz * 2] = k_smem_scale[ty * 32 + fz * 16 + row_id];
|
|
||||||
cache_k_reg[fz * 2 + 1] = k_smem_scale[ty * 32 + fz * 16 + row_id + 8];
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
template<uint32_t block_size,
|
|
||||||
uint32_t num_frags_z,
|
|
||||||
uint32_t NUM_WARP_Q,
|
|
||||||
typename T>
|
|
||||||
__device__ __forceinline__ void produce_v_dynamic_scale(
|
|
||||||
T* v_smem_scale,
|
|
||||||
T* cache_v_reg,
|
|
||||||
const int* block_table_now,
|
|
||||||
const T* cache_v_scale,
|
|
||||||
const uint32_t kv_idx,
|
|
||||||
const uint32_t kv_num_heads,
|
|
||||||
const uint32_t kv_head_idx,
|
|
||||||
const uint32_t chunk_end
|
|
||||||
) {
|
|
||||||
const uint32_t tx = threadIdx.x, ty = threadIdx.y;
|
|
||||||
|
|
||||||
if constexpr (NUM_WARP_Q == 4) {
|
|
||||||
// 4 warps shared block_size
|
|
||||||
const uint32_t tid = ty * 32 + tx;
|
|
||||||
int block_id = __ldg(&block_table_now[kv_idx / block_size]);
|
|
||||||
if (block_id < 0) block_id = 0;
|
|
||||||
const T* cache_v_scale_now = cache_v_scale + block_id * kv_num_heads * block_size + kv_head_idx * block_size;
|
|
||||||
if (tid < block_size) {
|
|
||||||
v_smem_scale[tid] = cache_v_scale_now[tid];
|
|
||||||
}
|
|
||||||
__syncthreads();
|
|
||||||
const uint32_t row_id = tx % 4 * 2;
|
|
||||||
for (uint32_t fz = 0; fz < num_frags_z; fz++) {
|
|
||||||
cache_v_reg[fz * 4] = v_smem_scale[fz * 16 + row_id];
|
|
||||||
cache_v_reg[fz * 4 + 1] = v_smem_scale[fz * 16 + row_id + 1];
|
|
||||||
cache_v_reg[fz * 4 + 2] = v_smem_scale[fz * 16 + row_id + 8];
|
|
||||||
cache_v_reg[fz * 4 + 3] = v_smem_scale[fz * 16 + row_id + 9];
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
// 1 warp 32 tokens
|
|
||||||
const uint32_t kv_idx_now = kv_idx + block_size * ty / 2;
|
|
||||||
int block_id = __ldg(&block_table_now[kv_idx_now / block_size]);
|
|
||||||
if (block_id < 0) block_id = 0;
|
|
||||||
const T* cache_v_scale_now = cache_v_scale + block_id * kv_num_heads * block_size + kv_head_idx * block_size;
|
|
||||||
const int kv_idx_this_thread = kv_idx + ty * 32 + tx;
|
|
||||||
if (kv_idx_this_thread < chunk_end) {
|
|
||||||
v_smem_scale[ty * 32 + tx] = cache_v_scale_now[(ty % 2) * 32 + tx];
|
|
||||||
} else {
|
|
||||||
v_smem_scale[ty * 32 + tx] = 0;
|
|
||||||
}
|
|
||||||
__syncwarp();
|
|
||||||
const uint32_t row_id = tx % 4 * 2;
|
|
||||||
for (uint32_t fz = 0; fz < num_frags_z; fz++) {
|
|
||||||
cache_v_reg[fz * 4] = v_smem_scale[ty * 32 + fz * 16 + row_id];
|
|
||||||
cache_v_reg[fz * 4 + 1] = v_smem_scale[ty * 32 + fz * 16 + row_id + 1];
|
|
||||||
cache_v_reg[fz * 4 + 2] = v_smem_scale[ty * 32 + fz * 16 + row_id + 8];
|
|
||||||
cache_v_reg[fz * 4 + 3] = v_smem_scale[ty * 32 + fz * 16 + row_id + 9];
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
template <SharedMemFillMode fill_mode,
|
template <SharedMemFillMode fill_mode,
|
||||||
uint32_t num_warps,
|
uint32_t num_warps,
|
||||||
uint32_t block_size,
|
uint32_t block_size,
|
||||||
@@ -931,8 +816,7 @@ template <uint32_t num_frags_x,
|
|||||||
typename T,
|
typename T,
|
||||||
typename CacheT,
|
typename CacheT,
|
||||||
bool is_scale_channel_wise = false,
|
bool is_scale_channel_wise = false,
|
||||||
bool IsFP8 = false,
|
bool IsFP8=false>
|
||||||
bool IsDynamicC8 = false>
|
|
||||||
__device__ __forceinline__ void compute_qk_c8(smem_t* q_smem,
|
__device__ __forceinline__ void compute_qk_c8(smem_t* q_smem,
|
||||||
uint32_t* q_smem_offset_r,
|
uint32_t* q_smem_offset_r,
|
||||||
smem_t* k_smem,
|
smem_t* k_smem,
|
||||||
@@ -976,27 +860,20 @@ __device__ __forceinline__ void compute_qk_c8(smem_t* q_smem,
|
|||||||
convert_c8<T,IsFP8>(b_frag_dq_T, b_frag[fy * 2]);
|
convert_c8<T,IsFP8>(b_frag_dq_T, b_frag[fy * 2]);
|
||||||
convert_c8<T,IsFP8>(b_frag_dq_T + 4, b_frag[fy * 2 + 1]);
|
convert_c8<T,IsFP8>(b_frag_dq_T + 4, b_frag[fy * 2 + 1]);
|
||||||
// scale zp
|
// scale zp
|
||||||
if constexpr (!IsDynamicC8) {
|
if constexpr (is_scale_channel_wise) {
|
||||||
if constexpr (is_scale_channel_wise) {
|
const int scale_col = (ky * 2 + fy) * 4;
|
||||||
const int scale_col = (ky * 2 + fy) * 4;
|
b_frag_dq_T[0] *= cache_k_scale[scale_col];
|
||||||
b_frag_dq_T[0] *= cache_k_scale[scale_col];
|
b_frag_dq_T[1] *= cache_k_scale[scale_col + 1];
|
||||||
b_frag_dq_T[1] *= cache_k_scale[scale_col + 1];
|
b_frag_dq_T[2] *= cache_k_scale[scale_col + 2];
|
||||||
b_frag_dq_T[2] *= cache_k_scale[scale_col + 2];
|
b_frag_dq_T[3] *= cache_k_scale[scale_col + 3];
|
||||||
b_frag_dq_T[3] *= cache_k_scale[scale_col + 3];
|
b_frag_dq_T[4] *= cache_k_scale[scale_col];
|
||||||
b_frag_dq_T[4] *= cache_k_scale[scale_col];
|
b_frag_dq_T[5] *= cache_k_scale[scale_col + 1];
|
||||||
b_frag_dq_T[5] *= cache_k_scale[scale_col + 1];
|
b_frag_dq_T[6] *= cache_k_scale[scale_col + 2];
|
||||||
b_frag_dq_T[6] *= cache_k_scale[scale_col + 2];
|
b_frag_dq_T[7] *= cache_k_scale[scale_col + 3];
|
||||||
b_frag_dq_T[7] *= cache_k_scale[scale_col + 3];
|
|
||||||
} else {
|
|
||||||
#pragma unroll
|
|
||||||
for (uint32_t b_i = 0; b_i < 8; ++b_i) {
|
|
||||||
b_frag_dq_T[b_i] *= cache_k_scale[0];
|
|
||||||
}
|
|
||||||
}
|
|
||||||
} else {
|
} else {
|
||||||
#pragma unroll
|
#pragma unroll
|
||||||
for (uint32_t b_i = 0; b_i < 8; ++b_i) {
|
for (uint32_t b_i = 0; b_i < 8; ++b_i) {
|
||||||
b_frag_dq_T[b_i] *= cache_k_scale[fz * 2 + b_i / 4];
|
b_frag_dq_T[b_i] *= cache_k_scale[0];
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
#pragma unroll
|
#pragma unroll
|
||||||
@@ -1036,8 +913,7 @@ __device__ __forceinline__ void mask_s(const bool* attn_mask,
|
|||||||
const uint32_t chunk_end,
|
const uint32_t chunk_end,
|
||||||
const uint32_t attn_mask_len,
|
const uint32_t attn_mask_len,
|
||||||
float (*s_frag)[num_frags_z][8],
|
float (*s_frag)[num_frags_z][8],
|
||||||
const int *mask_offset = nullptr,
|
const int *mask_offset = nullptr) {
|
||||||
const int sliding_window = 0) {
|
|
||||||
const uint32_t tx = threadIdx.x;
|
const uint32_t tx = threadIdx.x;
|
||||||
#pragma unroll
|
#pragma unroll
|
||||||
for (uint32_t fx = 0; fx < num_frags_x; ++fx) {
|
for (uint32_t fx = 0; fx < num_frags_x; ++fx) {
|
||||||
@@ -1053,22 +929,12 @@ __device__ __forceinline__ void mask_s(const bool* attn_mask,
|
|||||||
8 * (reg_id / 4) + reg_id % 2;
|
8 * (reg_id / 4) + reg_id % 2;
|
||||||
bool out_of_boundary;
|
bool out_of_boundary;
|
||||||
if (mask_offset) {
|
if (mask_offset) {
|
||||||
out_of_boundary = q_idx < qo_len ? (kv_idx >= mask_offset[q_idx * 2 + 1] || kv_idx < mask_offset[q_idx * 2]) : true;
|
out_of_boundary = q_idx < qo_len ? (kv_idx > mask_offset[q_idx]) : true;
|
||||||
}
|
} else {
|
||||||
else if (sliding_window > 0)
|
|
||||||
{
|
|
||||||
bool out_of_window = int(kv_idx) <= (int)kv_len + (int)q_idx - (int)qo_len - sliding_window;
|
|
||||||
out_of_boundary =
|
out_of_boundary =
|
||||||
(causal
|
(causal
|
||||||
? (kv_idx > kv_len + q_idx - qo_len || out_of_window || (kv_idx >= chunk_end))
|
? (kv_idx > kv_len + q_idx - qo_len || (kv_idx >= chunk_end))
|
||||||
: kv_idx >= chunk_end);
|
: kv_idx >= chunk_end);
|
||||||
}
|
|
||||||
else
|
|
||||||
{
|
|
||||||
out_of_boundary =
|
|
||||||
(causal
|
|
||||||
? (kv_idx > kv_len + q_idx - qo_len || (kv_idx >= chunk_end))
|
|
||||||
: kv_idx >= chunk_end);
|
|
||||||
if (attn_mask != nullptr && kv_idx > kv_len - qo_len && kv_idx < chunk_end && q_idx < attn_mask_len) {
|
if (attn_mask != nullptr && kv_idx > kv_len - qo_len && kv_idx < chunk_end && q_idx < attn_mask_len) {
|
||||||
const int32_t mask_idx = q_idx * attn_mask_len + kv_idx - kv_len + qo_len;
|
const int32_t mask_idx = q_idx * attn_mask_len + kv_idx - kv_len + qo_len;
|
||||||
bool mask = attn_mask[mask_idx];
|
bool mask = attn_mask[mask_idx];
|
||||||
@@ -1083,7 +949,7 @@ __device__ __forceinline__ void mask_s(const bool* attn_mask,
|
|||||||
s_frag[fx][fz][reg_id] =
|
s_frag[fx][fz][reg_id] =
|
||||||
out_of_boundary ? -3.0e+30f : s_frag[fx][fz][reg_id];
|
out_of_boundary ? -3.0e+30f : s_frag[fx][fz][reg_id];
|
||||||
}
|
}
|
||||||
|
// printf("tid: %d. qk[%u,%u] = %f, mask: %d \n ", threadIdx.x, kv_idx, q_idx, static_cast<float>(s_frag[fx][fz][reg_id]), int(out_of_boundary));
|
||||||
} else {
|
} else {
|
||||||
const uint32_t q_idx = qo_idx_base,
|
const uint32_t q_idx = qo_idx_base,
|
||||||
kv_idx = kv_idx_base + fz * 16 + 2 * (tx % 4) +
|
kv_idx = kv_idx_base + fz * 16 + 2 * (tx % 4) +
|
||||||
@@ -1227,9 +1093,7 @@ template <uint32_t num_frags_x,
|
|||||||
uint32_t block_size,
|
uint32_t block_size,
|
||||||
typename T,
|
typename T,
|
||||||
typename CacheT,
|
typename CacheT,
|
||||||
bool is_scale_channel_wise = false,
|
bool is_scale_channel_wise = false, bool IsFP8=false>
|
||||||
bool IsFP8 = false,
|
|
||||||
bool IsDynamicC8 = false>
|
|
||||||
__device__ __forceinline__ void compute_sfm_v_c8(
|
__device__ __forceinline__ void compute_sfm_v_c8(
|
||||||
smem_t* v_smem,
|
smem_t* v_smem,
|
||||||
uint32_t* v_smem_offset_r,
|
uint32_t* v_smem_offset_r,
|
||||||
@@ -1271,28 +1135,16 @@ __device__ __forceinline__ void compute_sfm_v_c8(
|
|||||||
convert_c8<T,IsFP8>(b_frag_dq_T, b_frag[fz * 2]);
|
convert_c8<T,IsFP8>(b_frag_dq_T, b_frag[fz * 2]);
|
||||||
convert_c8<T,IsFP8>(b_frag_dq_T + 4, b_frag[fz * 2 + 1]);
|
convert_c8<T,IsFP8>(b_frag_dq_T + 4, b_frag[fz * 2 + 1]);
|
||||||
// scale zp
|
// scale zp
|
||||||
if constexpr (!IsDynamicC8) {
|
if constexpr (is_scale_channel_wise) {
|
||||||
if constexpr (is_scale_channel_wise) {
|
|
||||||
#pragma unroll
|
#pragma unroll
|
||||||
for (uint32_t b_i = 0; b_i < 8; ++b_i) {
|
for (uint32_t b_i = 0; b_i < 8; ++b_i) {
|
||||||
b_frag_dq_T[b_i] *= cache_v_scale[b_i / 4 + fy * 2];
|
b_frag_dq_T[b_i] *= cache_v_scale[b_i / 4 + fy * 2];
|
||||||
}
|
|
||||||
} else {
|
|
||||||
#pragma unroll
|
|
||||||
for (uint32_t b_i = 0; b_i < 8; ++b_i) {
|
|
||||||
b_frag_dq_T[b_i] *= cache_v_scale[0];
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
const int scale_col = (kz * 2 + fz) * 4;
|
#pragma unroll
|
||||||
b_frag_dq_T[0] *= cache_v_scale[scale_col];
|
for (uint32_t b_i = 0; b_i < 8; ++b_i) {
|
||||||
b_frag_dq_T[1] *= cache_v_scale[scale_col + 1];
|
b_frag_dq_T[b_i] *= cache_v_scale[0];
|
||||||
b_frag_dq_T[2] *= cache_v_scale[scale_col + 2];
|
}
|
||||||
b_frag_dq_T[3] *= cache_v_scale[scale_col + 3];
|
|
||||||
b_frag_dq_T[4] *= cache_v_scale[scale_col];
|
|
||||||
b_frag_dq_T[5] *= cache_v_scale[scale_col + 1];
|
|
||||||
b_frag_dq_T[6] *= cache_v_scale[scale_col + 2];
|
|
||||||
b_frag_dq_T[7] *= cache_v_scale[scale_col + 3];
|
|
||||||
}
|
}
|
||||||
#pragma unroll
|
#pragma unroll
|
||||||
for (uint32_t fx = 0; fx < num_frags_x; ++fx) { // m: num_frags_x * 16
|
for (uint32_t fx = 0; fx < num_frags_x; ++fx) { // m: num_frags_x * 16
|
||||||
@@ -1319,9 +1171,7 @@ template <uint32_t num_frags_x,
|
|||||||
uint32_t block_size,
|
uint32_t block_size,
|
||||||
typename T,
|
typename T,
|
||||||
typename CacheT,
|
typename CacheT,
|
||||||
bool is_scale_channel_wise = false,
|
bool is_scale_channel_wise = false, bool IsFP8=false>
|
||||||
bool IsFP8 = false,
|
|
||||||
bool IsDynamicC8 = false>
|
|
||||||
__device__ __forceinline__ void compute_sfm_v_c8_iter_sq_bvec(
|
__device__ __forceinline__ void compute_sfm_v_c8_iter_sq_bvec(
|
||||||
smem_t* v_smem,
|
smem_t* v_smem,
|
||||||
uint32_t* v_smem_offset_r,
|
uint32_t* v_smem_offset_r,
|
||||||
@@ -1365,28 +1215,16 @@ __device__ __forceinline__ void compute_sfm_v_c8_iter_sq_bvec(
|
|||||||
convert_c8<T,IsFP8>(b_frag_dq_T, b_frag[fz * 2]);
|
convert_c8<T,IsFP8>(b_frag_dq_T, b_frag[fz * 2]);
|
||||||
convert_c8<T,IsFP8>(b_frag_dq_T + 4, b_frag[fz * 2 + 1]);
|
convert_c8<T,IsFP8>(b_frag_dq_T + 4, b_frag[fz * 2 + 1]);
|
||||||
// scale zp
|
// scale zp
|
||||||
if constexpr (!IsDynamicC8) {
|
if constexpr (is_scale_channel_wise) {
|
||||||
if constexpr (is_scale_channel_wise) {
|
|
||||||
#pragma unroll
|
#pragma unroll
|
||||||
for (uint32_t b_i = 0; b_i < 8; ++b_i) {
|
for (uint32_t b_i = 0; b_i < 8; ++b_i) {
|
||||||
b_frag_dq_T[b_i] *= cache_v_scale[b_i / 4 + fy * 2];
|
b_frag_dq_T[b_i] *= cache_v_scale[b_i / 4 + fy * 2];
|
||||||
}
|
|
||||||
} else {
|
|
||||||
#pragma unroll
|
|
||||||
for (uint32_t b_i = 0; b_i < 8; ++b_i) {
|
|
||||||
b_frag_dq_T[b_i] *= cache_v_scale[0];
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
const int scale_col = (kz * 2 + fz) * 4;
|
#pragma unroll
|
||||||
b_frag_dq_T[0] *= cache_v_scale[scale_col];
|
for (uint32_t b_i = 0; b_i < 8; ++b_i) {
|
||||||
b_frag_dq_T[1] *= cache_v_scale[scale_col + 1];
|
b_frag_dq_T[b_i] *= cache_v_scale[0];
|
||||||
b_frag_dq_T[2] *= cache_v_scale[scale_col + 2];
|
}
|
||||||
b_frag_dq_T[3] *= cache_v_scale[scale_col + 3];
|
|
||||||
b_frag_dq_T[4] *= cache_v_scale[scale_col];
|
|
||||||
b_frag_dq_T[5] *= cache_v_scale[scale_col + 1];
|
|
||||||
b_frag_dq_T[6] *= cache_v_scale[scale_col + 2];
|
|
||||||
b_frag_dq_T[7] *= cache_v_scale[scale_col + 3];
|
|
||||||
}
|
}
|
||||||
#pragma unroll
|
#pragma unroll
|
||||||
for (uint32_t fx = 0; fx < num_frags_x; ++fx) { // m: num_frags_x * 16
|
for (uint32_t fx = 0; fx < num_frags_x; ++fx) { // m: num_frags_x * 16
|
||||||
@@ -1477,33 +1315,6 @@ __device__ __forceinline__ void normalize_d(float (*o_frag)[num_frags_y][8],
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
template <uint32_t num_frags_x, uint32_t num_frags_y>
|
|
||||||
__device__ __forceinline__ void normalize_d(float (*o_frag)[num_frags_y][8],
|
|
||||||
float (*d)[2],
|
|
||||||
float (*m)[2],
|
|
||||||
float (*current_sinks)[2]) {
|
|
||||||
float d_rcp[num_frags_x][2];
|
|
||||||
#pragma unroll
|
|
||||||
for (uint32_t fx = 0; fx < num_frags_x; ++fx) {
|
|
||||||
#pragma unroll
|
|
||||||
for (uint32_t j = 0; j < 2; ++j) {
|
|
||||||
d_rcp[fx][j] = 1.f / (d[fx][j] + __expf(current_sinks[fx][j] - m[fx][j]));
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
#pragma unroll
|
|
||||||
for (uint32_t fx = 0; fx < num_frags_x; ++fx) {
|
|
||||||
#pragma unroll
|
|
||||||
for (uint32_t fy = 0; fy < num_frags_y; ++fy) {
|
|
||||||
#pragma unroll
|
|
||||||
for (uint32_t reg_id = 0; reg_id < 8; ++reg_id) {
|
|
||||||
o_frag[fx][fy][reg_id] =
|
|
||||||
o_frag[fx][fy][reg_id] * d_rcp[fx][(reg_id % 4) / 2];
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
template <uint32_t num_frags_x,
|
template <uint32_t num_frags_x,
|
||||||
uint32_t num_frags_y,
|
uint32_t num_frags_y,
|
||||||
uint32_t NUM_WARPS,
|
uint32_t NUM_WARPS,
|
||||||
@@ -2317,7 +2128,6 @@ __global__ void merge_multi_chunks_decoder_kernel(
|
|||||||
const int *__restrict__ cu_seqlens_q,
|
const int *__restrict__ cu_seqlens_q,
|
||||||
const T *__restrict__ shift_bias, // [q_num_heads * HEAD_DIM]
|
const T *__restrict__ shift_bias, // [q_num_heads * HEAD_DIM]
|
||||||
const T *__restrict__ smooth_weight, // [q_num_heads * HEAD_DIM]
|
const T *__restrict__ smooth_weight, // [q_num_heads * HEAD_DIM]
|
||||||
const T *__restrict__ sinks, // [q_num_heads]
|
|
||||||
OutT *__restrict__ out,
|
OutT *__restrict__ out,
|
||||||
const float quant_max_bound,
|
const float quant_max_bound,
|
||||||
const float quant_min_bound,
|
const float quant_min_bound,
|
||||||
@@ -2355,11 +2165,17 @@ __global__ void merge_multi_chunks_decoder_kernel(
|
|||||||
using LoadT = AlignedVector<T, vec_size>;
|
using LoadT = AlignedVector<T, vec_size>;
|
||||||
LoadT load_vec;
|
LoadT load_vec;
|
||||||
LoadT res_vec;
|
LoadT res_vec;
|
||||||
|
if constexpr (std::is_same<T, half>::value) {
|
||||||
for (int i = 0; i < vec_size; ++i) {
|
#pragma unroll
|
||||||
res_vec[i] = T(0.f);
|
for (int i = 0; i < vec_size / 2; ++i) {
|
||||||
|
*((half2 *)(&res_vec) + i) = make_half2(0, 0);
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
#pragma unroll
|
||||||
|
for (int i = 0; i < vec_size / 2; ++i) {
|
||||||
|
*((nv_bfloat162 *)(&res_vec) + i) = make_bfloat162(0, 0);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
float m;
|
float m;
|
||||||
float d = 1.f;
|
float d = 1.f;
|
||||||
if constexpr (std::is_same<T, half>::value) {
|
if constexpr (std::is_same<T, half>::value) {
|
||||||
@@ -2375,7 +2191,8 @@ __global__ void merge_multi_chunks_decoder_kernel(
|
|||||||
const float m_now = multi_m[offset];
|
const float m_now = multi_m[offset];
|
||||||
const float d_now = multi_d[offset];
|
const float d_now = multi_d[offset];
|
||||||
m = max(m_prev, m_now);
|
m = max(m_prev, m_now);
|
||||||
offset = offset * head_dim + vid * vec_size;
|
offset = (bid * num_chunks * num_heads + i * num_heads + hid) * head_dim +
|
||||||
|
vid * vec_size;
|
||||||
Load<T, vec_size>(&multi_out[offset], &load_vec);
|
Load<T, vec_size>(&multi_out[offset], &load_vec);
|
||||||
const float scale1 = __expf(m_prev - m), scale2 = __expf(m_now - m);
|
const float scale1 = __expf(m_prev - m), scale2 = __expf(m_now - m);
|
||||||
const T scale1_T = static_cast<T>(scale1),
|
const T scale1_T = static_cast<T>(scale1),
|
||||||
@@ -2401,12 +2218,7 @@ __global__ void merge_multi_chunks_decoder_kernel(
|
|||||||
const float m_tmp = md_smem[2 * i], d_tmp = md_smem[2 * i + 1];
|
const float m_tmp = md_smem[2 * i], d_tmp = md_smem[2 * i + 1];
|
||||||
st.merge(load_vec, m_tmp, d_tmp);
|
st.merge(load_vec, m_tmp, d_tmp);
|
||||||
}
|
}
|
||||||
if (sinks) {
|
st.normalize();
|
||||||
float current_sink = static_cast<float>(sinks[hid]);
|
|
||||||
st.normalize(current_sink);
|
|
||||||
} else {
|
|
||||||
st.normalize();
|
|
||||||
}
|
|
||||||
|
|
||||||
const uint32_t shift_smooth_offset = hid * head_dim + vid * vec_size;
|
const uint32_t shift_smooth_offset = hid * head_dim + vid * vec_size;
|
||||||
AlignedVector<T, vec_size> shift_bias_vec;
|
AlignedVector<T, vec_size> shift_bias_vec;
|
||||||
@@ -2446,7 +2258,6 @@ __global__ void merge_multi_chunks_v2_kernel(
|
|||||||
const int *__restrict__ cu_seqlens_q,
|
const int *__restrict__ cu_seqlens_q,
|
||||||
const T *__restrict__ shift_bias, // [q_num_heads * HEAD_DIM]
|
const T *__restrict__ shift_bias, // [q_num_heads * HEAD_DIM]
|
||||||
const T *__restrict__ smooth_weight, // [q_num_heads * HEAD_DIM]
|
const T *__restrict__ smooth_weight, // [q_num_heads * HEAD_DIM]
|
||||||
const T *__restrict__ sinks, // [q_num_heads]
|
|
||||||
OutT *__restrict__ out,
|
OutT *__restrict__ out,
|
||||||
const float quant_max_bound,
|
const float quant_max_bound,
|
||||||
const float quant_min_bound,
|
const float quant_min_bound,
|
||||||
@@ -2464,9 +2275,6 @@ __global__ void merge_multi_chunks_v2_kernel(
|
|||||||
__shared__ float md_smem[bdy * 2];
|
__shared__ float md_smem[bdy * 2];
|
||||||
for (int qid = blockIdx.x; qid < token_num; qid += gridDim.x) {
|
for (int qid = blockIdx.x; qid < token_num; qid += gridDim.x) {
|
||||||
const uint32_t bid = batch_id_per_token[qid];
|
const uint32_t bid = batch_id_per_token[qid];
|
||||||
if(bid == -1){
|
|
||||||
continue;
|
|
||||||
}
|
|
||||||
const uint32_t local_seq_id = qid - cu_seqlens_q[bid];
|
const uint32_t local_seq_id = qid - cu_seqlens_q[bid];
|
||||||
const int seq_len_q = seq_lens_q[bid];
|
const int seq_len_q = seq_lens_q[bid];
|
||||||
if (seq_len_q == 0) continue;
|
if (seq_len_q == 0) continue;
|
||||||
@@ -2486,8 +2294,6 @@ __global__ void merge_multi_chunks_v2_kernel(
|
|||||||
const int num_chunks_this_seq = div_up(seq_len_kv, chunk_size);
|
const int num_chunks_this_seq = div_up(seq_len_kv, chunk_size);
|
||||||
if (num_chunks_this_seq <= 1) {
|
if (num_chunks_this_seq <= 1) {
|
||||||
continue;
|
continue;
|
||||||
}else if (!ENABLE_PREFILL){
|
|
||||||
continue;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
using LoadT = AlignedVector<T, vec_size>;
|
using LoadT = AlignedVector<T, vec_size>;
|
||||||
@@ -2564,13 +2370,7 @@ __global__ void merge_multi_chunks_v2_kernel(
|
|||||||
const float m_tmp = md_smem[2 * i], d_tmp = md_smem[2 * i + 1];
|
const float m_tmp = md_smem[2 * i], d_tmp = md_smem[2 * i + 1];
|
||||||
st.merge(load_vec, m_tmp, d_tmp);
|
st.merge(load_vec, m_tmp, d_tmp);
|
||||||
}
|
}
|
||||||
|
st.normalize();
|
||||||
if (sinks) {
|
|
||||||
float current_sink = static_cast<float>(sinks[hid]);
|
|
||||||
st.normalize(current_sink);
|
|
||||||
} else {
|
|
||||||
st.normalize();
|
|
||||||
}
|
|
||||||
|
|
||||||
const uint32_t shift_smooth_offset = hid * head_dim + vid * vec_size;
|
const uint32_t shift_smooth_offset = hid * head_dim + vid * vec_size;
|
||||||
AlignedVector<T, vec_size> shift_bias_vec;
|
AlignedVector<T, vec_size> shift_bias_vec;
|
||||||
|
|||||||
@@ -15,9 +15,141 @@
|
|||||||
|
|
||||||
#include "helper.h"
|
#include "helper.h"
|
||||||
#include "utils.cuh"
|
#include "utils.cuh"
|
||||||
#include "append_attention_c16_impl.cuh"
|
|
||||||
#include "append_attention_c8_impl.cuh"
|
template <typename T, typename OutT>
|
||||||
#include "append_attention_c4_impl.cuh"
|
void CascadeAppendAttentionC16Kernel(
|
||||||
|
const AppendAttnMetaData& meta_data,
|
||||||
|
const paddle::Tensor& qkv, // [token_num, num_heads, head_dim]
|
||||||
|
const paddle::Tensor&
|
||||||
|
cache_k, // [max_block_num, num_heads, block_size, head_dim]
|
||||||
|
const paddle::Tensor&
|
||||||
|
cache_v, // [max_block_num, num_heads, head_dim, block_size]
|
||||||
|
const paddle::optional<paddle::Tensor>& attn_mask,
|
||||||
|
const paddle::optional<paddle::Tensor>&
|
||||||
|
cache_k_scale, // [num_kv_heads, head_dim]
|
||||||
|
const paddle::optional<paddle::Tensor>&
|
||||||
|
cache_v_scale, // [num_kv_heads, head_dim]
|
||||||
|
const paddle::optional<paddle::Tensor>&
|
||||||
|
cache_k_zp, // [num_kv_heads, head_dim]
|
||||||
|
const paddle::optional<paddle::Tensor>&
|
||||||
|
cache_v_zp, // [num_kv_heads, head_dim]
|
||||||
|
const paddle::optional<paddle::Tensor>&
|
||||||
|
shift_bias, // [num_kv_heads, head_dim]
|
||||||
|
const paddle::optional<paddle::Tensor>&
|
||||||
|
smooth_weight, // [num_kv_heads, head_dim]
|
||||||
|
const paddle::Tensor& seq_lens_q,
|
||||||
|
const paddle::Tensor& seq_lens_kv,
|
||||||
|
const paddle::Tensor& seq_lens_encoder,
|
||||||
|
const paddle::Tensor& batch_id_per_token,
|
||||||
|
const paddle::Tensor& cu_seqlens_q,
|
||||||
|
const paddle::Tensor& block_table,
|
||||||
|
const paddle::Tensor& batch_ids,
|
||||||
|
const paddle::Tensor& tile_ids_per_batch,
|
||||||
|
const int num_blocks,
|
||||||
|
const int block_shape_q,
|
||||||
|
const int max_seq_len,
|
||||||
|
const int max_dec_len,
|
||||||
|
const float quant_max_bound,
|
||||||
|
const float quant_min_bound,
|
||||||
|
const float in_scale,
|
||||||
|
const int max_partition_size,
|
||||||
|
const int encoder_max_partition_size,
|
||||||
|
const int speculate_max_draft_token_num,
|
||||||
|
const bool causal,
|
||||||
|
const bool is_decoder,
|
||||||
|
const bool enable_prefill,
|
||||||
|
cudaStream_t& stream,
|
||||||
|
paddle::Tensor* out);
|
||||||
|
|
||||||
|
template <typename T, typename OutT, bool IsFP8 = false>
|
||||||
|
void CascadeAppendAttentionC8Kernel(
|
||||||
|
const AppendAttnMetaData& meta_data,
|
||||||
|
const paddle::Tensor& qkv, // [token_num, num_heads, head_dim]
|
||||||
|
const paddle::Tensor&
|
||||||
|
cache_k, // [max_block_num, num_heads, block_size, head_dim]
|
||||||
|
const paddle::Tensor&
|
||||||
|
cache_v, // [max_block_num, num_heads, head_dim, block_size]
|
||||||
|
const paddle::optional<paddle::Tensor>& attn_mask,
|
||||||
|
const paddle::optional<paddle::Tensor>&
|
||||||
|
cache_k_scale, // [num_kv_heads, head_dim]
|
||||||
|
const paddle::optional<paddle::Tensor>&
|
||||||
|
cache_v_scale, // [num_kv_heads, head_dim]
|
||||||
|
const paddle::optional<paddle::Tensor>&
|
||||||
|
cache_k_zp, // [num_kv_heads, head_dim]
|
||||||
|
const paddle::optional<paddle::Tensor>&
|
||||||
|
cache_v_zp, // [num_kv_heads, head_dim]
|
||||||
|
const paddle::optional<paddle::Tensor>&
|
||||||
|
shift_bias, // [num_kv_heads, head_dim]
|
||||||
|
const paddle::optional<paddle::Tensor>&
|
||||||
|
smooth_weight, // [num_kv_heads, head_dim]
|
||||||
|
const paddle::Tensor& seq_lens_q,
|
||||||
|
const paddle::Tensor& seq_lens_kv,
|
||||||
|
const paddle::Tensor& seq_lens_encoder,
|
||||||
|
const paddle::Tensor& batch_id_per_token,
|
||||||
|
const paddle::Tensor& cu_seqlens_q,
|
||||||
|
const paddle::Tensor& block_table,
|
||||||
|
const paddle::Tensor& batch_ids,
|
||||||
|
const paddle::Tensor& tile_ids_per_batch,
|
||||||
|
const int num_blocks,
|
||||||
|
const int block_shape_q,
|
||||||
|
const int max_seq_len,
|
||||||
|
const int max_dec_len,
|
||||||
|
const float quant_max_bound,
|
||||||
|
const float quant_min_bound,
|
||||||
|
const float in_scale,
|
||||||
|
const int max_partition_size,
|
||||||
|
const int encoder_max_partition_size,
|
||||||
|
const int speculate_max_draft_token_num,
|
||||||
|
const bool causal,
|
||||||
|
const bool is_decoder,
|
||||||
|
const bool enable_prefill,
|
||||||
|
cudaStream_t& stream,
|
||||||
|
paddle::Tensor* out);
|
||||||
|
|
||||||
|
template <typename T, typename OutT>
|
||||||
|
void CascadeAppendAttentionC4Kernel(
|
||||||
|
const AppendAttnMetaData& meta_data,
|
||||||
|
const paddle::Tensor& qkv, // [token_num, num_heads, head_dim]
|
||||||
|
const paddle::Tensor&
|
||||||
|
cache_k, // [max_block_num, num_heads, block_size, head_dim]
|
||||||
|
const paddle::Tensor&
|
||||||
|
cache_v, // [max_block_num, num_heads, head_dim, block_size]
|
||||||
|
const paddle::optional<paddle::Tensor>& attn_mask,
|
||||||
|
const paddle::optional<paddle::Tensor>&
|
||||||
|
cache_k_scale, // [num_kv_heads, head_dim]
|
||||||
|
const paddle::optional<paddle::Tensor>&
|
||||||
|
cache_v_scale, // [num_kv_heads, head_dim]
|
||||||
|
const paddle::optional<paddle::Tensor>&
|
||||||
|
cache_k_zp, // [num_kv_heads, head_dim]
|
||||||
|
const paddle::optional<paddle::Tensor>&
|
||||||
|
cache_v_zp, // [num_kv_heads, head_dim]
|
||||||
|
const paddle::optional<paddle::Tensor>&
|
||||||
|
shift_bias, // [num_kv_heads, head_dim]
|
||||||
|
const paddle::optional<paddle::Tensor>&
|
||||||
|
smooth_weight, // [num_kv_heads, head_dim]
|
||||||
|
const paddle::Tensor& seq_lens_q,
|
||||||
|
const paddle::Tensor& seq_lens_kv,
|
||||||
|
const paddle::Tensor& seq_lens_encoder,
|
||||||
|
const paddle::Tensor& batch_id_per_token,
|
||||||
|
const paddle::Tensor& cu_seqlens_q,
|
||||||
|
const paddle::Tensor& block_table,
|
||||||
|
const paddle::Tensor& batch_ids,
|
||||||
|
const paddle::Tensor& tile_ids_per_batch,
|
||||||
|
const int num_blocks,
|
||||||
|
const int block_shape_q,
|
||||||
|
const int max_seq_len,
|
||||||
|
const int max_dec_len,
|
||||||
|
const float quant_max_bound,
|
||||||
|
const float quant_min_bound,
|
||||||
|
const float in_scale,
|
||||||
|
const int max_partition_size,
|
||||||
|
const int encoder_max_partition_size,
|
||||||
|
const int speculate_max_draft_token_num,
|
||||||
|
const bool causal,
|
||||||
|
const bool is_decoder,
|
||||||
|
const bool enable_prefill,
|
||||||
|
cudaStream_t& stream,
|
||||||
|
paddle::Tensor* out);
|
||||||
|
|
||||||
template <typename T, typename OutT>
|
template <typename T, typename OutT>
|
||||||
void CascadeAppendAttentionKernel(
|
void CascadeAppendAttentionKernel(
|
||||||
@@ -40,8 +172,6 @@ void CascadeAppendAttentionKernel(
|
|||||||
shift_bias, // [num_kv_heads, head_dim]
|
shift_bias, // [num_kv_heads, head_dim]
|
||||||
const paddle::optional<paddle::Tensor>&
|
const paddle::optional<paddle::Tensor>&
|
||||||
smooth_weight, // [num_kv_heads, head_dim]
|
smooth_weight, // [num_kv_heads, head_dim]
|
||||||
const paddle::optional<paddle::Tensor>&
|
|
||||||
sinks, // [num_heads]
|
|
||||||
const paddle::Tensor& seq_lens_q,
|
const paddle::Tensor& seq_lens_q,
|
||||||
const paddle::Tensor& seq_lens_kv,
|
const paddle::Tensor& seq_lens_kv,
|
||||||
const paddle::Tensor& seq_lens_encoder,
|
const paddle::Tensor& seq_lens_encoder,
|
||||||
@@ -65,8 +195,7 @@ void CascadeAppendAttentionKernel(
|
|||||||
const bool is_decoder,
|
const bool is_decoder,
|
||||||
const bool enable_prefill,
|
const bool enable_prefill,
|
||||||
cudaStream_t& stream,
|
cudaStream_t& stream,
|
||||||
paddle::Tensor* out,
|
paddle::Tensor* out) {
|
||||||
const int sliding_window) {
|
|
||||||
if (cache_quant_type_str == "none") {
|
if (cache_quant_type_str == "none") {
|
||||||
CascadeAppendAttentionC16Kernel<T, OutT>(meta_data,
|
CascadeAppendAttentionC16Kernel<T, OutT>(meta_data,
|
||||||
qkv,
|
qkv,
|
||||||
@@ -79,7 +208,6 @@ void CascadeAppendAttentionKernel(
|
|||||||
cache_v_zp,
|
cache_v_zp,
|
||||||
shift_bias,
|
shift_bias,
|
||||||
smooth_weight,
|
smooth_weight,
|
||||||
sinks,
|
|
||||||
seq_lens_q,
|
seq_lens_q,
|
||||||
seq_lens_kv,
|
seq_lens_kv,
|
||||||
seq_lens_encoder,
|
seq_lens_encoder,
|
||||||
@@ -102,10 +230,9 @@ void CascadeAppendAttentionKernel(
|
|||||||
is_decoder,
|
is_decoder,
|
||||||
enable_prefill,
|
enable_prefill,
|
||||||
stream,
|
stream,
|
||||||
out,
|
out);
|
||||||
sliding_window);
|
|
||||||
} else if (cache_quant_type_str == "cache_int8") {
|
} else if (cache_quant_type_str == "cache_int8") {
|
||||||
CascadeAppendAttentionC8Kernel<T, OutT, false>(meta_data,
|
CascadeAppendAttentionC8Kernel<T, OutT>(meta_data,
|
||||||
qkv,
|
qkv,
|
||||||
cache_k,
|
cache_k,
|
||||||
cache_v,
|
cache_v,
|
||||||
@@ -116,7 +243,6 @@ void CascadeAppendAttentionKernel(
|
|||||||
cache_v_zp,
|
cache_v_zp,
|
||||||
shift_bias,
|
shift_bias,
|
||||||
smooth_weight,
|
smooth_weight,
|
||||||
sinks,
|
|
||||||
seq_lens_q,
|
seq_lens_q,
|
||||||
seq_lens_kv,
|
seq_lens_kv,
|
||||||
seq_lens_encoder,
|
seq_lens_encoder,
|
||||||
@@ -138,11 +264,9 @@ void CascadeAppendAttentionKernel(
|
|||||||
causal,
|
causal,
|
||||||
is_decoder,
|
is_decoder,
|
||||||
enable_prefill,
|
enable_prefill,
|
||||||
cache_quant_type_str,
|
|
||||||
stream,
|
stream,
|
||||||
out,
|
out);
|
||||||
sliding_window);
|
} else if (cache_quant_type_str == "cache_fp8") {
|
||||||
} else if (cache_quant_type_str == "cache_fp8" or cache_quant_type_str == "block_wise_fp8") {
|
|
||||||
CascadeAppendAttentionC8Kernel<T, OutT, true>(meta_data,
|
CascadeAppendAttentionC8Kernel<T, OutT, true>(meta_data,
|
||||||
qkv,
|
qkv,
|
||||||
cache_k,
|
cache_k,
|
||||||
@@ -154,7 +278,6 @@ void CascadeAppendAttentionKernel(
|
|||||||
cache_v_zp,
|
cache_v_zp,
|
||||||
shift_bias,
|
shift_bias,
|
||||||
smooth_weight,
|
smooth_weight,
|
||||||
sinks,
|
|
||||||
seq_lens_q,
|
seq_lens_q,
|
||||||
seq_lens_kv,
|
seq_lens_kv,
|
||||||
seq_lens_encoder,
|
seq_lens_encoder,
|
||||||
@@ -176,10 +299,8 @@ void CascadeAppendAttentionKernel(
|
|||||||
causal,
|
causal,
|
||||||
is_decoder,
|
is_decoder,
|
||||||
enable_prefill,
|
enable_prefill,
|
||||||
cache_quant_type_str,
|
|
||||||
stream,
|
stream,
|
||||||
out,
|
out);
|
||||||
sliding_window);
|
|
||||||
} else if (cache_quant_type_str == "cache_int4_zp") {
|
} else if (cache_quant_type_str == "cache_int4_zp") {
|
||||||
CascadeAppendAttentionC4Kernel<T, OutT>(meta_data,
|
CascadeAppendAttentionC4Kernel<T, OutT>(meta_data,
|
||||||
qkv,
|
qkv,
|
||||||
@@ -192,7 +313,6 @@ void CascadeAppendAttentionKernel(
|
|||||||
cache_v_zp,
|
cache_v_zp,
|
||||||
shift_bias,
|
shift_bias,
|
||||||
smooth_weight,
|
smooth_weight,
|
||||||
sinks,
|
|
||||||
seq_lens_q,
|
seq_lens_q,
|
||||||
seq_lens_kv,
|
seq_lens_kv,
|
||||||
seq_lens_encoder,
|
seq_lens_encoder,
|
||||||
@@ -215,8 +335,7 @@ void CascadeAppendAttentionKernel(
|
|||||||
is_decoder,
|
is_decoder,
|
||||||
enable_prefill,
|
enable_prefill,
|
||||||
stream,
|
stream,
|
||||||
out,
|
out);
|
||||||
sliding_window);
|
|
||||||
} else {
|
} else {
|
||||||
PD_THROW(
|
PD_THROW(
|
||||||
"cache_quant_type_str should be one of [none, cache_int8, "
|
"cache_quant_type_str should be one of [none, cache_int8, "
|
||||||
|
|||||||
@@ -1,243 +0,0 @@
|
|||||||
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
|
|
||||||
#
|
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
# you may not use this file except in compliance with the License.
|
|
||||||
# You may obtain a copy of the License at
|
|
||||||
#
|
|
||||||
# http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
#
|
|
||||||
# Unless required by applicable law or agreed to in writing, software
|
|
||||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
# See the License for the specific language governing permissions and
|
|
||||||
# limitations under the License.
|
|
||||||
"""Universal template instantiation generator - fully based on configuration file template instantiation generation."""
|
|
||||||
|
|
||||||
import argparse
|
|
||||||
import json
|
|
||||||
from dataclasses import dataclass
|
|
||||||
from pathlib import Path
|
|
||||||
from typing import Any, Dict, List, Optional, Tuple
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
|
||||||
class TemplateConfig:
|
|
||||||
"""Template configuration class."""
|
|
||||||
|
|
||||||
name: str # Function name
|
|
||||||
function_name: str # Actual function name
|
|
||||||
impl_file: str # Implementation file path
|
|
||||||
template_params: List[str] # Template parameter list (in order)
|
|
||||||
dispatch_params: Dict[str, List[Any]] # Dispatch parameters
|
|
||||||
data_types: Optional[List[Tuple[str, str, str]]] = None # Data type combinations (input_type, output_type, suffix)
|
|
||||||
max_instances_per_file: int = 60 # Maximum instances per file
|
|
||||||
file_prefix: str = "" # File prefix
|
|
||||||
function_signature: str = "" # Function signature template
|
|
||||||
|
|
||||||
|
|
||||||
class UniversalTemplateInstantiator:
|
|
||||||
"""Universal template instantiator - fully based on configuration file."""
|
|
||||||
|
|
||||||
def __init__(self, config_file: str):
|
|
||||||
"""Initialize the instantiator."""
|
|
||||||
self.config_file = config_file
|
|
||||||
self.configs = self._load_configs()
|
|
||||||
|
|
||||||
def _load_configs(self) -> Dict[str, TemplateConfig]:
|
|
||||||
"""Load configuration file."""
|
|
||||||
with open(self.config_file, "r", encoding="utf-8") as f:
|
|
||||||
config_data = json.load(f)
|
|
||||||
|
|
||||||
configs = {}
|
|
||||||
for name, config_dict in config_data.items():
|
|
||||||
config = TemplateConfig(**config_dict)
|
|
||||||
self._validate_config(config)
|
|
||||||
configs[name] = config
|
|
||||||
return configs
|
|
||||||
|
|
||||||
def _validate_config(self, config: TemplateConfig):
|
|
||||||
"""Validate configuration completeness."""
|
|
||||||
has_t = "T" in config.template_params
|
|
||||||
has_out_t = "OutT" in config.template_params
|
|
||||||
|
|
||||||
if (has_t or has_out_t) and not config.data_types:
|
|
||||||
raise ValueError(
|
|
||||||
f"Configuration '{config.name}' has T or OutT in template_params but no data_types configured"
|
|
||||||
)
|
|
||||||
|
|
||||||
special_params = {"T", "OutT", "NUM_WARP_Q"}
|
|
||||||
for param_name in config.template_params:
|
|
||||||
if param_name not in special_params and param_name not in config.dispatch_params:
|
|
||||||
raise ValueError(f"Template parameter '{param_name}' in '{config.name}' not found in dispatch_params")
|
|
||||||
|
|
||||||
if "NUM_WARP_Q" in config.template_params and "BLOCK_SHAPE_Q" not in config.dispatch_params:
|
|
||||||
raise ValueError(
|
|
||||||
f"Template parameter 'NUM_WARP_Q' in '{config.name}' requires 'BLOCK_SHAPE_Q' in dispatch_params"
|
|
||||||
)
|
|
||||||
|
|
||||||
def _calculate_num_warp_q(self, block_shape_q: int) -> int:
|
|
||||||
"""Calculate number of warps."""
|
|
||||||
if block_shape_q <= 32:
|
|
||||||
return 1
|
|
||||||
else:
|
|
||||||
return 4
|
|
||||||
|
|
||||||
def _build_template_args(self, config: TemplateConfig, t_in: str, t_out: str, params: Dict[str, Any]) -> str:
|
|
||||||
"""Build template arguments."""
|
|
||||||
template_args_parts = []
|
|
||||||
|
|
||||||
for param_name in config.template_params:
|
|
||||||
if param_name == "T":
|
|
||||||
if t_in:
|
|
||||||
template_args_parts.append(t_in)
|
|
||||||
else:
|
|
||||||
raise ValueError("Template parameter 'T' requires input type, but data_types is empty or invalid")
|
|
||||||
elif param_name == "OutT":
|
|
||||||
if t_out:
|
|
||||||
template_args_parts.append(t_out)
|
|
||||||
else:
|
|
||||||
raise ValueError(
|
|
||||||
"Template parameter 'OutT' requires output type, but data_types is empty or invalid"
|
|
||||||
)
|
|
||||||
elif param_name == "NUM_WARP_Q":
|
|
||||||
if "BLOCK_SHAPE_Q" in params:
|
|
||||||
num_warp_q = self._calculate_num_warp_q(params["BLOCK_SHAPE_Q"])
|
|
||||||
template_args_parts.append(str(num_warp_q))
|
|
||||||
else:
|
|
||||||
raise ValueError("Template parameter 'NUM_WARP_Q' requires 'BLOCK_SHAPE_Q' in dispatch_params")
|
|
||||||
elif param_name in params:
|
|
||||||
template_args_parts.append(str(params[param_name]))
|
|
||||||
else:
|
|
||||||
raise ValueError(f"Template parameter '{param_name}' not found in dispatch_params")
|
|
||||||
|
|
||||||
return f"<{', '.join(template_args_parts)}>"
|
|
||||||
|
|
||||||
def _generate_function_signature(self, config: TemplateConfig, template_args: str) -> str:
|
|
||||||
"""Generate function signature."""
|
|
||||||
if config.function_signature:
|
|
||||||
return config.function_signature.format(function_name=config.function_name, template_args=template_args)
|
|
||||||
else:
|
|
||||||
raise ValueError(f"Function signature not found for {config.name}")
|
|
||||||
|
|
||||||
def _generate_file_header(self, config: TemplateConfig) -> str:
|
|
||||||
"""Generate file header."""
|
|
||||||
return f"""// Generated by autogen_template_instantiation.py - Do not edit.
|
|
||||||
|
|
||||||
#pragma once
|
|
||||||
|
|
||||||
#include "../../{config.impl_file}"
|
|
||||||
"""
|
|
||||||
|
|
||||||
def _generate_template_instantiation(
|
|
||||||
self, config: TemplateConfig, t_in: str, t_out: str, params: Dict[str, Any]
|
|
||||||
) -> str:
|
|
||||||
"""Generate template instantiation."""
|
|
||||||
template_args = self._build_template_args(config, t_in, t_out, params)
|
|
||||||
return self._generate_function_signature(config, template_args)
|
|
||||||
|
|
||||||
def generate_combinations_for_type(self, config: TemplateConfig, t_in: str, t_out: str) -> List[Dict[str, Any]]:
|
|
||||||
"""Generate parameter combinations for specific type."""
|
|
||||||
combinations = []
|
|
||||||
|
|
||||||
def _generate_recursive(
|
|
||||||
params_dict: Dict[str, List[Any]], current_params: Dict[str, Any], param_names: List[str]
|
|
||||||
):
|
|
||||||
if not param_names:
|
|
||||||
combinations.append(current_params.copy())
|
|
||||||
return
|
|
||||||
|
|
||||||
param_name = param_names[0]
|
|
||||||
for value in params_dict[param_name]:
|
|
||||||
current_params[param_name] = value
|
|
||||||
_generate_recursive(params_dict, current_params, param_names[1:])
|
|
||||||
|
|
||||||
_generate_recursive(config.dispatch_params, {}, list(config.dispatch_params.keys()))
|
|
||||||
return combinations
|
|
||||||
|
|
||||||
def split_combinations(self, combinations: List[Dict[str, Any]], max_per_file: int) -> List[List[Dict[str, Any]]]:
|
|
||||||
"""Split combinations into multiple files."""
|
|
||||||
chunks = []
|
|
||||||
for i in range(0, len(combinations), max_per_file):
|
|
||||||
chunk = combinations[i : i + max_per_file]
|
|
||||||
chunks.append(chunk)
|
|
||||||
return chunks
|
|
||||||
|
|
||||||
def generate_file_content(
|
|
||||||
self,
|
|
||||||
config: TemplateConfig,
|
|
||||||
t_in: str,
|
|
||||||
t_out: str,
|
|
||||||
t_out_name: str,
|
|
||||||
file_index: int,
|
|
||||||
combinations: List[Dict[str, Any]],
|
|
||||||
) -> str:
|
|
||||||
"""Generate file content."""
|
|
||||||
content = self._generate_file_header(config)
|
|
||||||
|
|
||||||
for params in combinations:
|
|
||||||
content += self._generate_template_instantiation(config, t_in, t_out, params)
|
|
||||||
|
|
||||||
return content
|
|
||||||
|
|
||||||
def generate_for_function_type(self, function_name: str, output_dir: str):
|
|
||||||
"""Generate template instantiation files for specific function type."""
|
|
||||||
if function_name not in self.configs:
|
|
||||||
raise ValueError(f"Function type '{function_name}' not found in config")
|
|
||||||
|
|
||||||
config = self.configs[function_name]
|
|
||||||
output_path = Path(output_dir)
|
|
||||||
output_path.mkdir(exist_ok=True)
|
|
||||||
|
|
||||||
if not config.data_types:
|
|
||||||
data_types = [("", "", "")]
|
|
||||||
else:
|
|
||||||
data_types = config.data_types
|
|
||||||
|
|
||||||
for t_in, t_out, t_out_name in data_types:
|
|
||||||
combinations = self.generate_combinations_for_type(config, t_in, t_out)
|
|
||||||
if combinations:
|
|
||||||
chunks = self.split_combinations(combinations, config.max_instances_per_file)
|
|
||||||
for i, chunk in enumerate(chunks):
|
|
||||||
filename = f"{config.file_prefix}{t_out_name}_part_{i:02d}.cu"
|
|
||||||
filepath = output_path / filename
|
|
||||||
content = self.generate_file_content(config, t_in, t_out, t_out_name, i, chunk)
|
|
||||||
with open(filepath, "w", encoding="utf-8") as f:
|
|
||||||
f.write(content)
|
|
||||||
|
|
||||||
def generate_all(self, output_dir: str):
|
|
||||||
"""Generate all configured function types."""
|
|
||||||
for function_name in self.configs.keys():
|
|
||||||
print(f"Generating template instantiations for {function_name}...")
|
|
||||||
self.generate_for_function_type(function_name, output_dir)
|
|
||||||
print(f"Completed generating {function_name} template instantiations.")
|
|
||||||
|
|
||||||
|
|
||||||
def main():
|
|
||||||
"""Main function."""
|
|
||||||
parser = argparse.ArgumentParser(description="Universal template instantiation generator")
|
|
||||||
parser.add_argument(
|
|
||||||
"--config",
|
|
||||||
"-c",
|
|
||||||
type=str,
|
|
||||||
default="gpu_ops/append_attn/template_config.json",
|
|
||||||
help="Configuration file path (JSON format)",
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
|
||||||
"--output",
|
|
||||||
"-o",
|
|
||||||
type=str,
|
|
||||||
default="gpu_ops/append_attn/template_instantiation/autogen",
|
|
||||||
help="Output directory",
|
|
||||||
)
|
|
||||||
|
|
||||||
args = parser.parse_args()
|
|
||||||
|
|
||||||
try:
|
|
||||||
instantiator = UniversalTemplateInstantiator(args.config)
|
|
||||||
instantiator.generate_all(args.output)
|
|
||||||
except Exception as e:
|
|
||||||
print(f"Error: {e}")
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
main()
|
|
||||||
@@ -13,8 +13,8 @@
|
|||||||
// limitations under the License.
|
// limitations under the License.
|
||||||
#pragma once
|
#pragma once
|
||||||
|
|
||||||
#include "helper.h"
|
|
||||||
#include "utils.cuh"
|
#include "multi_head_latent_attention_kernel.h"
|
||||||
|
|
||||||
template <size_t vec_size, typename T>
|
template <size_t vec_size, typename T>
|
||||||
struct softmax_state_t {
|
struct softmax_state_t {
|
||||||
|
|||||||
@@ -11,10 +11,8 @@
|
|||||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
// See the License for the specific language governing permissions and
|
// See the License for the specific language governing permissions and
|
||||||
// limitations under the License.
|
// limitations under the License.
|
||||||
#pragma once
|
|
||||||
|
|
||||||
#include "decode_attention_func.cuh"
|
#include "decode_attention_func.cuh"
|
||||||
#include "multiquery_decoder_attention_kernel.h"
|
|
||||||
|
|
||||||
#define CHECK(call) \
|
#define CHECK(call) \
|
||||||
do \
|
do \
|
||||||
@@ -473,3 +471,90 @@ void MultiQueryDecoderAttention(
|
|||||||
// CHECK(cudaGetLastError());
|
// CHECK(cudaGetLastError());
|
||||||
// CHECK(cudaDeviceSynchronize());
|
// CHECK(cudaDeviceSynchronize());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
void DecodeMLAAttentionKernel(
|
||||||
|
const AppendAttnMetaData& meta_data,
|
||||||
|
const paddle::Tensor &q, // [token_num, num_heads, head_dim]
|
||||||
|
const paddle::Tensor &cache_k,
|
||||||
|
const paddle::Tensor &cache_v,
|
||||||
|
const paddle::optional<paddle::Tensor>& attn_mask,
|
||||||
|
const paddle::optional<paddle::Tensor>& shift_bias,
|
||||||
|
const paddle::optional<paddle::Tensor>& smooth_weight,
|
||||||
|
const paddle::Tensor &seq_lens_q, // q_seq_len is 1
|
||||||
|
const paddle::Tensor &seq_lens_kv,
|
||||||
|
const paddle::Tensor &batch_id_per_token,
|
||||||
|
const paddle::Tensor &cu_seqlens_q,
|
||||||
|
const paddle::Tensor &block_table,
|
||||||
|
int max_seq_len,
|
||||||
|
int max_dec_len,
|
||||||
|
float softmax_scale,
|
||||||
|
float in_scale,
|
||||||
|
bool causal,
|
||||||
|
cudaStream_t &stream,
|
||||||
|
paddle::Tensor *out) {
|
||||||
|
const auto token_num = meta_data.token_nums;
|
||||||
|
const auto block_size = meta_data.block_size;
|
||||||
|
const auto bsz = meta_data.batch_size;
|
||||||
|
const auto num_heads = meta_data.q_num_heads;
|
||||||
|
const auto group_size = meta_data.q_num_heads / meta_data.kv_num_heads;
|
||||||
|
const auto head_dim_qk = meta_data.head_dims;
|
||||||
|
const auto head_dim_v = meta_data.head_dims_v;
|
||||||
|
const float rope_scale = 0.0;
|
||||||
|
const float rope_theta = 0.0;
|
||||||
|
const uint32_t deal_each_time = get_cascade_attention_deal_each_time();
|
||||||
|
const uint32_t num_stage = get_cascade_attention_num_stages();
|
||||||
|
const uint32_t num_threads = get_cascade_attention_num_threads();
|
||||||
|
|
||||||
|
DISPATCH_CAUSAL(causal, CAUSAL,
|
||||||
|
{DISPATCH_MLA_GROUP_SIZE(group_size, GROUP_SIZE,
|
||||||
|
{DISPATCH_MLA_HEAD_DIM(head_dim_qk, HEAD_DIM_QK,
|
||||||
|
{DISPATCH_MLA_HEAD_DIM(head_dim_v, HEAD_DIM_V,
|
||||||
|
{DISPATCH_BLOCK_SIZE(block_size, BLOCK_SIZE,
|
||||||
|
{DISPATCH_DEAL_EACH_TIME(deal_each_time, DEAL_EACH_TIME,
|
||||||
|
{MultiQueryDecoderAttention<T, GROUP_SIZE, HEAD_DIM_QK, HEAD_DIM_V, BLOCK_SIZE, CAUSAL, 2, 16, DEAL_EACH_TIME>(
|
||||||
|
meta_data, stream, q, cache_k, cache_v, attn_mask, shift_bias, smooth_weight, seq_lens_q, seq_lens_kv, batch_id_per_token, cu_seqlens_q,
|
||||||
|
block_table, max_seq_len, max_dec_len, rope_scale, rope_theta, softmax_scale, in_scale, out);})})})})})});
|
||||||
|
}
|
||||||
|
|
||||||
|
template void DecodeMLAAttentionKernel<paddle::bfloat16>(
|
||||||
|
const AppendAttnMetaData& meta_data,
|
||||||
|
const paddle::Tensor &q, // [token_num, num_heads, head_dim]
|
||||||
|
const paddle::Tensor &cache_k,
|
||||||
|
const paddle::Tensor &cache_v,
|
||||||
|
const paddle::optional<paddle::Tensor>& attn_mask,
|
||||||
|
const paddle::optional<paddle::Tensor>& shift_bias,
|
||||||
|
const paddle::optional<paddle::Tensor>& smooth_weight,
|
||||||
|
const paddle::Tensor &seq_lens_q, // q_seq_len is 1
|
||||||
|
const paddle::Tensor &seq_lens_kv,
|
||||||
|
const paddle::Tensor &batch_id_per_token,
|
||||||
|
const paddle::Tensor &cu_seqlens_q,
|
||||||
|
const paddle::Tensor &block_table,
|
||||||
|
int max_seq_len,
|
||||||
|
int max_dec_len,
|
||||||
|
float softmax_scale,
|
||||||
|
float in_scale,
|
||||||
|
bool causal,
|
||||||
|
cudaStream_t &stream,
|
||||||
|
paddle::Tensor *out);
|
||||||
|
|
||||||
|
template void DecodeMLAAttentionKernel<paddle::float16>(
|
||||||
|
const AppendAttnMetaData& meta_data,
|
||||||
|
const paddle::Tensor &q, // [token_num, num_heads, head_dim]
|
||||||
|
const paddle::Tensor &cache_k,
|
||||||
|
const paddle::Tensor &cache_v,
|
||||||
|
const paddle::optional<paddle::Tensor>& attn_mask,
|
||||||
|
const paddle::optional<paddle::Tensor>& shift_bias,
|
||||||
|
const paddle::optional<paddle::Tensor>& smooth_weight,
|
||||||
|
const paddle::Tensor &seq_lens_q, // q_seq_len is 1
|
||||||
|
const paddle::Tensor &seq_lens_kv,
|
||||||
|
const paddle::Tensor &batch_id_per_token,
|
||||||
|
const paddle::Tensor &cu_seqlens_q,
|
||||||
|
const paddle::Tensor &block_table,
|
||||||
|
int max_seq_len,
|
||||||
|
int max_dec_len,
|
||||||
|
float softmax_scale,
|
||||||
|
float in_scale,
|
||||||
|
bool causal,
|
||||||
|
cudaStream_t &stream,
|
||||||
|
paddle::Tensor *out);
|
||||||
File diff suppressed because it is too large
Load Diff
@@ -17,30 +17,31 @@
|
|||||||
|
|
||||||
template <typename T, typename QKV_TYPE>
|
template <typename T, typename QKV_TYPE>
|
||||||
void append_decode_cache_rope_qk_norm(const QKV_TYPE* qkv,
|
void append_decode_cache_rope_qk_norm(const QKV_TYPE* qkv,
|
||||||
T* key_cache,
|
T* key_cache,
|
||||||
T* value_cache,
|
T* value_cache,
|
||||||
T* qkv_out,
|
T* qkv_out,
|
||||||
const int* block_tables,
|
const int* block_tables,
|
||||||
const int* cu_seqlens_q,
|
const int* batch_id_per_token,
|
||||||
const int* seq_lens,
|
const int* cu_seqlens_q,
|
||||||
const int* seq_lens_encoder,
|
const int* seq_lens,
|
||||||
const float* cos_emb,
|
const int* seq_lens_encoder,
|
||||||
const float* sin_emb,
|
const float* cos_emb,
|
||||||
const float* qkv_out_scales,
|
const float* sin_emb,
|
||||||
const T* qkv_biases,
|
const float* qkv_out_scales,
|
||||||
const int max_seq_len,
|
const T* qkv_biases,
|
||||||
const int max_blocks_per_seq,
|
const int max_seq_len,
|
||||||
const int num_heads,
|
const int max_blocks_per_seq,
|
||||||
const int kv_num_heads,
|
const int num_heads,
|
||||||
const int dim_head,
|
const int kv_num_heads,
|
||||||
const int block_size,
|
const int dim_head,
|
||||||
const int bsz,
|
const int block_size,
|
||||||
const cudaStream_t& stream,
|
const int bsz,
|
||||||
const bool use_neox_style,
|
const cudaStream_t& stream,
|
||||||
const bool rope_3d,
|
const bool use_neox_style,
|
||||||
const float* q_norm_weight,
|
const bool rope_3d,
|
||||||
const float* k_norm_weight,
|
const float* q_norm_weight,
|
||||||
const float rms_norm_eps) {
|
const float* k_norm_weight,
|
||||||
|
const float rms_norm_eps) {
|
||||||
const uint32_t elem_nums =
|
const uint32_t elem_nums =
|
||||||
use_neox_style ? bsz * (num_heads + 2 * kv_num_heads) * dim_head / 2
|
use_neox_style ? bsz * (num_heads + 2 * kv_num_heads) * dim_head / 2
|
||||||
: bsz * (num_heads + 2 * kv_num_heads) * dim_head;
|
: bsz * (num_heads + 2 * kv_num_heads) * dim_head;
|
||||||
@@ -58,6 +59,7 @@ void append_decode_cache_rope_qk_norm(const QKV_TYPE* qkv,
|
|||||||
value_cache,
|
value_cache,
|
||||||
qkv_out,
|
qkv_out,
|
||||||
block_tables,
|
block_tables,
|
||||||
|
batch_id_per_token,
|
||||||
cu_seqlens_q,
|
cu_seqlens_q,
|
||||||
seq_lens,
|
seq_lens,
|
||||||
seq_lens_encoder,
|
seq_lens_encoder,
|
||||||
@@ -82,6 +84,7 @@ void append_decode_cache_rope(const QKV_TYPE* qkv,
|
|||||||
T* value_cache,
|
T* value_cache,
|
||||||
T* qkv_out,
|
T* qkv_out,
|
||||||
const int* block_tables,
|
const int* block_tables,
|
||||||
|
const int* batch_id_per_token,
|
||||||
const int* cu_seqlens_q,
|
const int* cu_seqlens_q,
|
||||||
const int* seq_lens,
|
const int* seq_lens,
|
||||||
const int* seq_lens_encoder,
|
const int* seq_lens_encoder,
|
||||||
@@ -118,6 +121,7 @@ void append_decode_cache_rope(const QKV_TYPE* qkv,
|
|||||||
value_cache,
|
value_cache,
|
||||||
qkv_out,
|
qkv_out,
|
||||||
block_tables,
|
block_tables,
|
||||||
|
batch_id_per_token,
|
||||||
cu_seqlens_q,
|
cu_seqlens_q,
|
||||||
seq_lens,
|
seq_lens,
|
||||||
seq_lens_encoder,
|
seq_lens_encoder,
|
||||||
@@ -134,49 +138,48 @@ void append_decode_cache_rope(const QKV_TYPE* qkv,
|
|||||||
kv_num_heads,
|
kv_num_heads,
|
||||||
rope_3d);
|
rope_3d);
|
||||||
} else {
|
} else {
|
||||||
if (rotary_dim < dim_head) {
|
if (rotary_dim < dim_head){
|
||||||
append_decode_cache_T_neox_partial_rope_kernel<T, PackSize>
|
append_decode_cache_T_neox_partial_rope_kernel<T, PackSize>
|
||||||
<<<grid_size, blocksize, 0, stream>>>(
|
<<<grid_size, blocksize, 0, stream>>>(reinterpret_cast<const T*>(qkv),
|
||||||
reinterpret_cast<const T*>(qkv),
|
key_cache,
|
||||||
key_cache,
|
value_cache,
|
||||||
value_cache,
|
qkv_out,
|
||||||
qkv_out,
|
block_tables,
|
||||||
block_tables,
|
cu_seqlens_q,
|
||||||
cu_seqlens_q,
|
seq_lens,
|
||||||
seq_lens,
|
seq_lens_encoder,
|
||||||
seq_lens_encoder,
|
cos_emb,
|
||||||
cos_emb,
|
sin_emb,
|
||||||
sin_emb,
|
max_seq_len,
|
||||||
max_seq_len,
|
max_blocks_per_seq,
|
||||||
max_blocks_per_seq,
|
num_heads,
|
||||||
num_heads,
|
dim_head,
|
||||||
dim_head,
|
rotary_dim,
|
||||||
rotary_dim,
|
block_size,
|
||||||
block_size,
|
elem_nums,
|
||||||
elem_nums,
|
kv_num_heads,
|
||||||
kv_num_heads,
|
rope_3d);
|
||||||
rope_3d);
|
}else{
|
||||||
} else {
|
|
||||||
append_decode_cache_T_neox_rope_kernel<T, PackSize>
|
append_decode_cache_T_neox_rope_kernel<T, PackSize>
|
||||||
<<<grid_size, blocksize, 0, stream>>>(
|
<<<grid_size, blocksize, 0, stream>>>(reinterpret_cast<const T*>(qkv),
|
||||||
reinterpret_cast<const T*>(qkv),
|
key_cache,
|
||||||
key_cache,
|
value_cache,
|
||||||
value_cache,
|
qkv_out,
|
||||||
qkv_out,
|
block_tables,
|
||||||
block_tables,
|
batch_id_per_token,
|
||||||
cu_seqlens_q,
|
cu_seqlens_q,
|
||||||
seq_lens,
|
seq_lens,
|
||||||
seq_lens_encoder,
|
seq_lens_encoder,
|
||||||
cos_emb,
|
cos_emb,
|
||||||
sin_emb,
|
sin_emb,
|
||||||
max_seq_len,
|
max_seq_len,
|
||||||
max_blocks_per_seq,
|
max_blocks_per_seq,
|
||||||
num_heads,
|
num_heads,
|
||||||
dim_head,
|
dim_head,
|
||||||
block_size,
|
block_size,
|
||||||
elem_nums,
|
elem_nums,
|
||||||
kv_num_heads,
|
kv_num_heads,
|
||||||
rope_3d);
|
rope_3d);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
@@ -188,6 +191,7 @@ void append_decode_cache_rope(const QKV_TYPE* qkv,
|
|||||||
value_cache,
|
value_cache,
|
||||||
qkv_out,
|
qkv_out,
|
||||||
block_tables,
|
block_tables,
|
||||||
|
batch_id_per_token,
|
||||||
cu_seqlens_q,
|
cu_seqlens_q,
|
||||||
seq_lens,
|
seq_lens,
|
||||||
seq_lens_encoder,
|
seq_lens_encoder,
|
||||||
@@ -210,6 +214,7 @@ void append_decode_cache_rope(const QKV_TYPE* qkv,
|
|||||||
value_cache,
|
value_cache,
|
||||||
qkv_out,
|
qkv_out,
|
||||||
block_tables,
|
block_tables,
|
||||||
|
batch_id_per_token,
|
||||||
cu_seqlens_q,
|
cu_seqlens_q,
|
||||||
seq_lens,
|
seq_lens,
|
||||||
seq_lens_encoder,
|
seq_lens_encoder,
|
||||||
@@ -227,15 +232,13 @@ void append_decode_cache_rope(const QKV_TYPE* qkv,
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
template <typename T,
|
template <typename T, typename QKV_TYPE, bool is_scale_channel_wise = false, bool IsFP8=false>
|
||||||
typename QKV_TYPE,
|
|
||||||
bool is_scale_channel_wise = false,
|
|
||||||
bool IsFP8 = false>
|
|
||||||
void append_decode_cache_int8_rope(const QKV_TYPE* qkv,
|
void append_decode_cache_int8_rope(const QKV_TYPE* qkv,
|
||||||
uint8_t* key_cache,
|
uint8_t* key_cache,
|
||||||
uint8_t* value_cache,
|
uint8_t* value_cache,
|
||||||
T* qkv_out,
|
T* qkv_out,
|
||||||
const int* block_tables,
|
const int* block_tables,
|
||||||
|
const int* batch_id_per_token,
|
||||||
const int* cu_seqlens_q,
|
const int* cu_seqlens_q,
|
||||||
const int* seq_lens,
|
const int* seq_lens,
|
||||||
const int* seq_lens_encoder,
|
const int* seq_lens_encoder,
|
||||||
@@ -268,6 +271,7 @@ void append_decode_cache_int8_rope(const QKV_TYPE* qkv,
|
|||||||
value_cache,
|
value_cache,
|
||||||
qkv_out,
|
qkv_out,
|
||||||
block_tables,
|
block_tables,
|
||||||
|
batch_id_per_token,
|
||||||
cu_seqlens_q,
|
cu_seqlens_q,
|
||||||
seq_lens,
|
seq_lens,
|
||||||
seq_lens_encoder,
|
seq_lens_encoder,
|
||||||
@@ -293,6 +297,7 @@ void append_decode_cache_int8_rope(const QKV_TYPE* qkv,
|
|||||||
value_cache,
|
value_cache,
|
||||||
qkv_out,
|
qkv_out,
|
||||||
block_tables,
|
block_tables,
|
||||||
|
batch_id_per_token,
|
||||||
cu_seqlens_q,
|
cu_seqlens_q,
|
||||||
seq_lens,
|
seq_lens,
|
||||||
seq_lens_encoder,
|
seq_lens_encoder,
|
||||||
@@ -311,18 +316,14 @@ void append_decode_cache_int8_rope(const QKV_TYPE* qkv,
|
|||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
if (qkv_out_scales) {
|
if (qkv_out_scales) {
|
||||||
append_decode_cache_int8_rope_kernel<T,
|
append_decode_cache_int8_rope_kernel<T, 4, 0, 128, is_scale_channel_wise, IsFP8>
|
||||||
4,
|
|
||||||
0,
|
|
||||||
128,
|
|
||||||
is_scale_channel_wise,
|
|
||||||
IsFP8>
|
|
||||||
<<<grids, num_warps * 32, 0, stream>>>(
|
<<<grids, num_warps * 32, 0, stream>>>(
|
||||||
reinterpret_cast<const int*>(qkv),
|
reinterpret_cast<const int*>(qkv),
|
||||||
key_cache,
|
key_cache,
|
||||||
value_cache,
|
value_cache,
|
||||||
qkv_out,
|
qkv_out,
|
||||||
block_tables,
|
block_tables,
|
||||||
|
batch_id_per_token,
|
||||||
cu_seqlens_q,
|
cu_seqlens_q,
|
||||||
seq_lens,
|
seq_lens,
|
||||||
seq_lens_encoder,
|
seq_lens_encoder,
|
||||||
@@ -341,18 +342,14 @@ void append_decode_cache_int8_rope(const QKV_TYPE* qkv,
|
|||||||
kv_num_heads,
|
kv_num_heads,
|
||||||
rope_3d);
|
rope_3d);
|
||||||
} else {
|
} else {
|
||||||
append_decode_cache_int8_rope_kernel<T,
|
append_decode_cache_int8_rope_kernel<T, 4, 0, 128, is_scale_channel_wise, IsFP8>
|
||||||
4,
|
|
||||||
0,
|
|
||||||
128,
|
|
||||||
is_scale_channel_wise,
|
|
||||||
IsFP8>
|
|
||||||
<<<grids, num_warps * 32, 0, stream>>>(
|
<<<grids, num_warps * 32, 0, stream>>>(
|
||||||
reinterpret_cast<const T*>(qkv),
|
reinterpret_cast<const T*>(qkv),
|
||||||
key_cache,
|
key_cache,
|
||||||
value_cache,
|
value_cache,
|
||||||
qkv_out,
|
qkv_out,
|
||||||
block_tables,
|
block_tables,
|
||||||
|
batch_id_per_token,
|
||||||
cu_seqlens_q,
|
cu_seqlens_q,
|
||||||
seq_lens,
|
seq_lens,
|
||||||
seq_lens_encoder,
|
seq_lens_encoder,
|
||||||
@@ -378,6 +375,7 @@ void append_decode_cache_int4_rope(const QKV_TYPE* qkv,
|
|||||||
uint8_t* value_cache,
|
uint8_t* value_cache,
|
||||||
T* qkv_out,
|
T* qkv_out,
|
||||||
const int* block_tables,
|
const int* block_tables,
|
||||||
|
const int* batch_id_per_token,
|
||||||
const int* cu_seqlens_q,
|
const int* cu_seqlens_q,
|
||||||
const int* seq_lens,
|
const int* seq_lens,
|
||||||
const int* seq_lens_encoder,
|
const int* seq_lens_encoder,
|
||||||
@@ -412,6 +410,7 @@ void append_decode_cache_int4_rope(const QKV_TYPE* qkv,
|
|||||||
value_cache,
|
value_cache,
|
||||||
qkv_out,
|
qkv_out,
|
||||||
block_tables,
|
block_tables,
|
||||||
|
batch_id_per_token,
|
||||||
cu_seqlens_q,
|
cu_seqlens_q,
|
||||||
seq_lens,
|
seq_lens,
|
||||||
seq_lens_encoder,
|
seq_lens_encoder,
|
||||||
@@ -439,6 +438,7 @@ void append_decode_cache_int4_rope(const QKV_TYPE* qkv,
|
|||||||
value_cache,
|
value_cache,
|
||||||
qkv_out,
|
qkv_out,
|
||||||
block_tables,
|
block_tables,
|
||||||
|
batch_id_per_token,
|
||||||
cu_seqlens_q,
|
cu_seqlens_q,
|
||||||
seq_lens,
|
seq_lens,
|
||||||
seq_lens_encoder,
|
seq_lens_encoder,
|
||||||
@@ -466,6 +466,7 @@ void append_decode_cache_int4_rope(const QKV_TYPE* qkv,
|
|||||||
value_cache,
|
value_cache,
|
||||||
qkv_out,
|
qkv_out,
|
||||||
block_tables,
|
block_tables,
|
||||||
|
batch_id_per_token,
|
||||||
cu_seqlens_q,
|
cu_seqlens_q,
|
||||||
seq_lens,
|
seq_lens,
|
||||||
seq_lens_encoder,
|
seq_lens_encoder,
|
||||||
@@ -493,6 +494,7 @@ void append_decode_cache_int4_rope(const QKV_TYPE* qkv,
|
|||||||
value_cache,
|
value_cache,
|
||||||
qkv_out,
|
qkv_out,
|
||||||
block_tables,
|
block_tables,
|
||||||
|
batch_id_per_token,
|
||||||
cu_seqlens_q,
|
cu_seqlens_q,
|
||||||
seq_lens,
|
seq_lens,
|
||||||
seq_lens_encoder,
|
seq_lens_encoder,
|
||||||
@@ -519,6 +521,7 @@ void DecoderWriteCacheWithRoPEKernel(
|
|||||||
const paddle::Tensor& qkv,
|
const paddle::Tensor& qkv,
|
||||||
const paddle::Tensor& seq_lens,
|
const paddle::Tensor& seq_lens,
|
||||||
const paddle::Tensor& seq_lens_encoder,
|
const paddle::Tensor& seq_lens_encoder,
|
||||||
|
const paddle::Tensor& batch_id_per_token,
|
||||||
const paddle::Tensor& cu_seqlens_q,
|
const paddle::Tensor& cu_seqlens_q,
|
||||||
const paddle::Tensor& block_tables,
|
const paddle::Tensor& block_tables,
|
||||||
const paddle::optional<paddle::Tensor>& rotary_embs,
|
const paddle::optional<paddle::Tensor>& rotary_embs,
|
||||||
@@ -561,15 +564,11 @@ void DecoderWriteCacheWithRoPEKernel(
|
|||||||
use_neox_rotary_style
|
use_neox_rotary_style
|
||||||
? rotary_embs.get().data<float>() + max_seq_len * dim_head
|
? rotary_embs.get().data<float>() + max_seq_len * dim_head
|
||||||
: rotary_embs.get().data<float>() + max_seq_len * dim_head / 2;
|
: rotary_embs.get().data<float>() + max_seq_len * dim_head / 2;
|
||||||
rotary_dim =
|
rotary_dim = rotary_embs.get().dims()[rotary_embs.get().dims().size()-1] * 2;
|
||||||
rotary_embs.get().dims()[rotary_embs.get().dims().size() - 1] * 2;
|
if(rotary_dim < dim_head){
|
||||||
if (rotary_dim < dim_head) {
|
if (!use_neox_rotary_style || qkv_out_scales || q_norm_weight || k_norm_weight|| cache_quant_type_str != "none"){
|
||||||
if (!use_neox_rotary_style || qkv_out_scales || q_norm_weight ||
|
|
||||||
k_norm_weight || cache_quant_type_str != "none") {
|
|
||||||
PADDLE_THROW(phi::errors::Fatal(
|
PADDLE_THROW(phi::errors::Fatal(
|
||||||
"partial_rotary_factor < 1.0 only supports neox_rotary_style=True, "
|
"partial_rotary_factor < 1.0 only supports neox_rotary_style=True, qkv_out_scales is None, q_norm_weight/k_norm_weight) is None, and cache_quant_type_str is 'none'."));
|
||||||
"qkv_out_scales is None, q_norm_weight/k_norm_weight) is None, and "
|
|
||||||
"cache_quant_type_str is 'none'."));
|
|
||||||
}
|
}
|
||||||
sin_emb = rotary_embs.get().data<float>() + max_seq_len * rotary_dim / 2;
|
sin_emb = rotary_embs.get().data<float>() + max_seq_len * rotary_dim / 2;
|
||||||
}
|
}
|
||||||
@@ -583,6 +582,7 @@ void DecoderWriteCacheWithRoPEKernel(
|
|||||||
reinterpret_cast<DataType_*>(value_cache_out->data<T>()),
|
reinterpret_cast<DataType_*>(value_cache_out->data<T>()),
|
||||||
reinterpret_cast<DataType_*>(qkv_out->data<T>()),
|
reinterpret_cast<DataType_*>(qkv_out->data<T>()),
|
||||||
block_tables.data<int>(),
|
block_tables.data<int>(),
|
||||||
|
batch_id_per_token.data<int>(),
|
||||||
cu_seqlens_q.data<int>(),
|
cu_seqlens_q.data<int>(),
|
||||||
seq_lens.data<int>(),
|
seq_lens.data<int>(),
|
||||||
seq_lens_encoder.data<int>(),
|
seq_lens_encoder.data<int>(),
|
||||||
@@ -590,8 +590,8 @@ void DecoderWriteCacheWithRoPEKernel(
|
|||||||
sin_emb,
|
sin_emb,
|
||||||
qkv_out_scales ? qkv_out_scales.get().data<float>() : nullptr,
|
qkv_out_scales ? qkv_out_scales.get().data<float>() : nullptr,
|
||||||
qkv_biases ? reinterpret_cast<DataType_*>(
|
qkv_biases ? reinterpret_cast<DataType_*>(
|
||||||
const_cast<T*>(qkv_biases.get().data<T>()))
|
const_cast<T*>(qkv_biases.get().data<T>()))
|
||||||
: nullptr,
|
: nullptr,
|
||||||
max_seq_len,
|
max_seq_len,
|
||||||
max_blocks_per_seq,
|
max_blocks_per_seq,
|
||||||
num_heads,
|
num_heads,
|
||||||
@@ -605,86 +605,9 @@ void DecoderWriteCacheWithRoPEKernel(
|
|||||||
q_norm_weight ? q_norm_weight.get().data<float>() : nullptr,
|
q_norm_weight ? q_norm_weight.get().data<float>() : nullptr,
|
||||||
k_norm_weight ? k_norm_weight.get().data<float>() : nullptr,
|
k_norm_weight ? k_norm_weight.get().data<float>() : nullptr,
|
||||||
rms_norm_eps);
|
rms_norm_eps);
|
||||||
} else if (cache_quant_type_str == "block_wise_fp8") {
|
|
||||||
constexpr int num_warps = 4;
|
|
||||||
const int all_warps = ((num_heads + 2 * kv_num_heads) + num_warps - 1) /
|
|
||||||
num_warps * num_warps;
|
|
||||||
dim3 grids(bsz, all_warps / num_warps);
|
|
||||||
append_decode_cache_int8_rope_qk_norm_kernel<DataType_,
|
|
||||||
4,
|
|
||||||
0,
|
|
||||||
128,
|
|
||||||
false,
|
|
||||||
true,
|
|
||||||
true>
|
|
||||||
<<<grids, num_warps * 32, 0, stream>>>(
|
|
||||||
reinterpret_cast<const DataType_*>(qkv_ptr),
|
|
||||||
key_cache_out->data<uint8_t>(),
|
|
||||||
value_cache_out->data<uint8_t>(),
|
|
||||||
reinterpret_cast<DataType_*>(qkv_out->data<T>()),
|
|
||||||
block_tables.data<int>(),
|
|
||||||
cu_seqlens_q.data<int>(),
|
|
||||||
seq_lens.data<int>(),
|
|
||||||
seq_lens_encoder.data<int>(),
|
|
||||||
cos_emb,
|
|
||||||
sin_emb,
|
|
||||||
const_cast<DataType_*>(reinterpret_cast<const DataType_*>(
|
|
||||||
cache_k_scale.get().data<T>())),
|
|
||||||
const_cast<DataType_*>(reinterpret_cast<const DataType_*>(
|
|
||||||
(cache_v_scale.get().data<T>()))),
|
|
||||||
q_norm_weight.get().data<float>(),
|
|
||||||
k_norm_weight.get().data<float>(),
|
|
||||||
max_seq_len,
|
|
||||||
max_blocks_per_seq,
|
|
||||||
num_heads,
|
|
||||||
block_size,
|
|
||||||
127.0f,
|
|
||||||
-127.0f,
|
|
||||||
kv_num_heads,
|
|
||||||
rope_3d,
|
|
||||||
rms_norm_eps);
|
|
||||||
} else if ((cache_quant_type_str == "cache_fp8")) {
|
|
||||||
constexpr int num_warps = 4;
|
|
||||||
const int all_warps = ((num_heads + 2 * kv_num_heads) + num_warps - 1) /
|
|
||||||
num_warps * num_warps;
|
|
||||||
dim3 grids(bsz, all_warps / num_warps);
|
|
||||||
append_decode_cache_int8_rope_qk_norm_kernel<DataType_,
|
|
||||||
4,
|
|
||||||
0,
|
|
||||||
128,
|
|
||||||
false,
|
|
||||||
true,
|
|
||||||
false>
|
|
||||||
<<<grids, num_warps * 32, 0, stream>>>(
|
|
||||||
reinterpret_cast<const DataType_*>(qkv_ptr),
|
|
||||||
key_cache_out->data<uint8_t>(),
|
|
||||||
value_cache_out->data<uint8_t>(),
|
|
||||||
reinterpret_cast<DataType_*>(qkv_out->data<T>()),
|
|
||||||
block_tables.data<int>(),
|
|
||||||
cu_seqlens_q.data<int>(),
|
|
||||||
seq_lens.data<int>(),
|
|
||||||
seq_lens_encoder.data<int>(),
|
|
||||||
cos_emb,
|
|
||||||
sin_emb,
|
|
||||||
const_cast<DataType_*>(reinterpret_cast<const DataType_*>(
|
|
||||||
cache_k_scale.get().data<T>())),
|
|
||||||
const_cast<DataType_*>(reinterpret_cast<const DataType_*>(
|
|
||||||
(cache_v_scale.get().data<T>()))),
|
|
||||||
q_norm_weight.get().data<float>(),
|
|
||||||
k_norm_weight.get().data<float>(),
|
|
||||||
max_seq_len,
|
|
||||||
max_blocks_per_seq,
|
|
||||||
num_heads,
|
|
||||||
block_size,
|
|
||||||
127.0f,
|
|
||||||
-127.0f,
|
|
||||||
kv_num_heads,
|
|
||||||
rope_3d,
|
|
||||||
rms_norm_eps);
|
|
||||||
} else {
|
} else {
|
||||||
PD_THROW(
|
PD_THROW(
|
||||||
"append_decode_cache_rope_qk_norm just supports cache_quant_type "
|
"append_decode_cache_rope_qk_norm not support cachekv quant yet");
|
||||||
"none/block_wise_fp8/cache_fp8");
|
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
if (cache_quant_type_str == "none") {
|
if (cache_quant_type_str == "none") {
|
||||||
@@ -694,6 +617,7 @@ void DecoderWriteCacheWithRoPEKernel(
|
|||||||
reinterpret_cast<DataType_*>(value_cache_out->data<T>()),
|
reinterpret_cast<DataType_*>(value_cache_out->data<T>()),
|
||||||
reinterpret_cast<DataType_*>(qkv_out->data<T>()),
|
reinterpret_cast<DataType_*>(qkv_out->data<T>()),
|
||||||
block_tables.data<int>(),
|
block_tables.data<int>(),
|
||||||
|
batch_id_per_token.data<int>(),
|
||||||
cu_seqlens_q.data<int>(),
|
cu_seqlens_q.data<int>(),
|
||||||
seq_lens.data<int>(),
|
seq_lens.data<int>(),
|
||||||
seq_lens_encoder.data<int>(),
|
seq_lens_encoder.data<int>(),
|
||||||
@@ -701,8 +625,8 @@ void DecoderWriteCacheWithRoPEKernel(
|
|||||||
sin_emb,
|
sin_emb,
|
||||||
qkv_out_scales ? qkv_out_scales.get().data<float>() : nullptr,
|
qkv_out_scales ? qkv_out_scales.get().data<float>() : nullptr,
|
||||||
qkv_biases ? reinterpret_cast<DataType_*>(
|
qkv_biases ? reinterpret_cast<DataType_*>(
|
||||||
const_cast<T*>(qkv_biases.get().data<T>()))
|
const_cast<T*>(qkv_biases.get().data<T>()))
|
||||||
: nullptr,
|
: nullptr,
|
||||||
max_seq_len,
|
max_seq_len,
|
||||||
max_blocks_per_seq,
|
max_blocks_per_seq,
|
||||||
num_heads,
|
num_heads,
|
||||||
@@ -716,82 +640,17 @@ void DecoderWriteCacheWithRoPEKernel(
|
|||||||
rope_3d);
|
rope_3d);
|
||||||
} else if (cache_quant_type_str == "cache_int8") {
|
} else if (cache_quant_type_str == "cache_int8") {
|
||||||
bool is_scale_channel_wise = false;
|
bool is_scale_channel_wise = false;
|
||||||
if (cache_k_scale &&
|
if (cache_k_scale && cache_k_scale.get().dims()[0] == dim_head * kv_num_heads) {
|
||||||
cache_k_scale.get().dims()[0] == dim_head * kv_num_heads) {
|
|
||||||
is_scale_channel_wise = true;
|
is_scale_channel_wise = true;
|
||||||
}
|
}
|
||||||
if (is_scale_channel_wise) {
|
if (is_scale_channel_wise) {
|
||||||
append_decode_cache_int8_rope<DataType_, QKV_TYPE, true>(
|
append_decode_cache_int8_rope<DataType_, QKV_TYPE, true>(
|
||||||
reinterpret_cast<const QKV_TYPE*>(qkv_ptr),
|
|
||||||
key_cache_out->data<uint8_t>(),
|
|
||||||
value_cache_out->data<uint8_t>(),
|
|
||||||
reinterpret_cast<DataType_*>(qkv_out->data<T>()),
|
|
||||||
block_tables.data<int>(),
|
|
||||||
cu_seqlens_q.data<int>(),
|
|
||||||
seq_lens.data<int>(),
|
|
||||||
seq_lens_encoder.data<int>(),
|
|
||||||
cos_emb,
|
|
||||||
sin_emb,
|
|
||||||
qkv_out_scales ? qkv_out_scales.get().data<float>() : nullptr,
|
|
||||||
qkv_biases ? reinterpret_cast<DataType_*>(
|
|
||||||
const_cast<T*>(qkv_biases.get().data<T>()))
|
|
||||||
: nullptr,
|
|
||||||
cache_k_scale ? reinterpret_cast<DataType_*>(
|
|
||||||
const_cast<T*>(cache_k_scale.get().data<T>()))
|
|
||||||
: nullptr,
|
|
||||||
cache_v_scale ? reinterpret_cast<DataType_*>(
|
|
||||||
const_cast<T*>(cache_v_scale.get().data<T>()))
|
|
||||||
: nullptr,
|
|
||||||
max_seq_len,
|
|
||||||
max_blocks_per_seq,
|
|
||||||
num_heads,
|
|
||||||
kv_num_heads,
|
|
||||||
dim_head,
|
|
||||||
block_size,
|
|
||||||
bsz,
|
|
||||||
stream,
|
|
||||||
use_neox_rotary_style,
|
|
||||||
rope_3d);
|
|
||||||
} else {
|
|
||||||
append_decode_cache_int8_rope<DataType_, QKV_TYPE, false>(
|
|
||||||
reinterpret_cast<const QKV_TYPE*>(qkv_ptr),
|
|
||||||
key_cache_out->data<uint8_t>(),
|
|
||||||
value_cache_out->data<uint8_t>(),
|
|
||||||
reinterpret_cast<DataType_*>(qkv_out->data<T>()),
|
|
||||||
block_tables.data<int>(),
|
|
||||||
cu_seqlens_q.data<int>(),
|
|
||||||
seq_lens.data<int>(),
|
|
||||||
seq_lens_encoder.data<int>(),
|
|
||||||
cos_emb,
|
|
||||||
sin_emb,
|
|
||||||
qkv_out_scales ? qkv_out_scales.get().data<float>() : nullptr,
|
|
||||||
qkv_biases ? reinterpret_cast<DataType_*>(
|
|
||||||
const_cast<T*>(qkv_biases.get().data<T>()))
|
|
||||||
: nullptr,
|
|
||||||
cache_k_scale ? reinterpret_cast<DataType_*>(
|
|
||||||
const_cast<T*>(cache_k_scale.get().data<T>()))
|
|
||||||
: nullptr,
|
|
||||||
cache_v_scale ? reinterpret_cast<DataType_*>(
|
|
||||||
const_cast<T*>(cache_v_scale.get().data<T>()))
|
|
||||||
: nullptr,
|
|
||||||
max_seq_len,
|
|
||||||
max_blocks_per_seq,
|
|
||||||
num_heads,
|
|
||||||
kv_num_heads,
|
|
||||||
dim_head,
|
|
||||||
block_size,
|
|
||||||
bsz,
|
|
||||||
stream,
|
|
||||||
use_neox_rotary_style,
|
|
||||||
rope_3d);
|
|
||||||
}
|
|
||||||
} else if (cache_quant_type_str == "cache_fp8") {
|
|
||||||
append_decode_cache_int8_rope<DataType_, QKV_TYPE, false, true>(
|
|
||||||
reinterpret_cast<const QKV_TYPE*>(qkv_ptr),
|
reinterpret_cast<const QKV_TYPE*>(qkv_ptr),
|
||||||
key_cache_out->data<uint8_t>(),
|
key_cache_out->data<uint8_t>(),
|
||||||
value_cache_out->data<uint8_t>(),
|
value_cache_out->data<uint8_t>(),
|
||||||
reinterpret_cast<DataType_*>(qkv_out->data<T>()),
|
reinterpret_cast<DataType_*>(qkv_out->data<T>()),
|
||||||
block_tables.data<int>(),
|
block_tables.data<int>(),
|
||||||
|
batch_id_per_token.data<int>(),
|
||||||
cu_seqlens_q.data<int>(),
|
cu_seqlens_q.data<int>(),
|
||||||
seq_lens.data<int>(),
|
seq_lens.data<int>(),
|
||||||
seq_lens_encoder.data<int>(),
|
seq_lens_encoder.data<int>(),
|
||||||
@@ -799,8 +658,8 @@ void DecoderWriteCacheWithRoPEKernel(
|
|||||||
sin_emb,
|
sin_emb,
|
||||||
qkv_out_scales ? qkv_out_scales.get().data<float>() : nullptr,
|
qkv_out_scales ? qkv_out_scales.get().data<float>() : nullptr,
|
||||||
qkv_biases ? reinterpret_cast<DataType_*>(
|
qkv_biases ? reinterpret_cast<DataType_*>(
|
||||||
const_cast<T*>(qkv_biases.get().data<T>()))
|
const_cast<T*>(qkv_biases.get().data<T>()))
|
||||||
: nullptr,
|
: nullptr,
|
||||||
cache_k_scale ? reinterpret_cast<DataType_*>(
|
cache_k_scale ? reinterpret_cast<DataType_*>(
|
||||||
const_cast<T*>(cache_k_scale.get().data<T>()))
|
const_cast<T*>(cache_k_scale.get().data<T>()))
|
||||||
: nullptr,
|
: nullptr,
|
||||||
@@ -817,43 +676,73 @@ void DecoderWriteCacheWithRoPEKernel(
|
|||||||
stream,
|
stream,
|
||||||
use_neox_rotary_style,
|
use_neox_rotary_style,
|
||||||
rope_3d);
|
rope_3d);
|
||||||
} else if (cache_quant_type_str == "block_wise_fp8") {
|
} else {
|
||||||
constexpr int num_warps = 4;
|
append_decode_cache_int8_rope<DataType_, QKV_TYPE, false>(
|
||||||
const int all_warps = ((num_heads + 2 * kv_num_heads) + num_warps - 1) /
|
reinterpret_cast<const QKV_TYPE*>(qkv_ptr),
|
||||||
num_warps * num_warps;
|
key_cache_out->data<uint8_t>(),
|
||||||
dim3 grids(bsz, all_warps / num_warps);
|
value_cache_out->data<uint8_t>(),
|
||||||
append_decode_cache_int8_rope_qk_norm_kernel<DataType_,
|
reinterpret_cast<DataType_*>(qkv_out->data<T>()),
|
||||||
4,
|
block_tables.data<int>(),
|
||||||
0,
|
batch_id_per_token.data<int>(),
|
||||||
128,
|
cu_seqlens_q.data<int>(),
|
||||||
false,
|
seq_lens.data<int>(),
|
||||||
true>
|
seq_lens_encoder.data<int>(),
|
||||||
<<<grids, num_warps * 32, 0, stream>>>(
|
cos_emb,
|
||||||
reinterpret_cast<const DataType_*>(qkv_ptr),
|
sin_emb,
|
||||||
key_cache_out->data<uint8_t>(),
|
qkv_out_scales ? qkv_out_scales.get().data<float>() : nullptr,
|
||||||
value_cache_out->data<uint8_t>(),
|
qkv_biases ? reinterpret_cast<DataType_*>(
|
||||||
reinterpret_cast<DataType_*>(qkv_out->data<T>()),
|
const_cast<T*>(qkv_biases.get().data<T>()))
|
||||||
block_tables.data<int>(),
|
: nullptr,
|
||||||
cu_seqlens_q.data<int>(),
|
cache_k_scale ? reinterpret_cast<DataType_*>(
|
||||||
seq_lens.data<int>(),
|
const_cast<T*>(cache_k_scale.get().data<T>()))
|
||||||
seq_lens_encoder.data<int>(),
|
: nullptr,
|
||||||
cos_emb,
|
cache_v_scale ? reinterpret_cast<DataType_*>(
|
||||||
sin_emb,
|
const_cast<T*>(cache_v_scale.get().data<T>()))
|
||||||
const_cast<DataType_*>(reinterpret_cast<const DataType_*>(
|
: nullptr,
|
||||||
cache_k_scale.get().data<T>())),
|
max_seq_len,
|
||||||
const_cast<DataType_*>(reinterpret_cast<const DataType_*>(
|
max_blocks_per_seq,
|
||||||
(cache_v_scale.get().data<T>()))),
|
num_heads,
|
||||||
nullptr,
|
kv_num_heads,
|
||||||
nullptr,
|
dim_head,
|
||||||
max_seq_len,
|
block_size,
|
||||||
max_blocks_per_seq,
|
bsz,
|
||||||
num_heads,
|
stream,
|
||||||
block_size,
|
use_neox_rotary_style,
|
||||||
127.0f,
|
rope_3d);
|
||||||
-127.0f,
|
}
|
||||||
kv_num_heads,
|
} else if (cache_quant_type_str == "cache_fp8") {
|
||||||
rope_3d,
|
append_decode_cache_int8_rope<DataType_, QKV_TYPE, false, true>(
|
||||||
rms_norm_eps);
|
reinterpret_cast<const QKV_TYPE*>(qkv_ptr),
|
||||||
|
key_cache_out->data<uint8_t>(),
|
||||||
|
value_cache_out->data<uint8_t>(),
|
||||||
|
reinterpret_cast<DataType_*>(qkv_out->data<T>()),
|
||||||
|
block_tables.data<int>(),
|
||||||
|
batch_id_per_token.data<int>(),
|
||||||
|
cu_seqlens_q.data<int>(),
|
||||||
|
seq_lens.data<int>(),
|
||||||
|
seq_lens_encoder.data<int>(),
|
||||||
|
cos_emb,
|
||||||
|
sin_emb,
|
||||||
|
qkv_out_scales ? qkv_out_scales.get().data<float>() : nullptr,
|
||||||
|
qkv_biases ? reinterpret_cast<DataType_*>(
|
||||||
|
const_cast<T*>(qkv_biases.get().data<T>()))
|
||||||
|
: nullptr,
|
||||||
|
cache_k_scale ? reinterpret_cast<DataType_*>(
|
||||||
|
const_cast<T*>(cache_k_scale.get().data<T>()))
|
||||||
|
: nullptr,
|
||||||
|
cache_v_scale ? reinterpret_cast<DataType_*>(
|
||||||
|
const_cast<T*>(cache_v_scale.get().data<T>()))
|
||||||
|
: nullptr,
|
||||||
|
max_seq_len,
|
||||||
|
max_blocks_per_seq,
|
||||||
|
num_heads,
|
||||||
|
kv_num_heads,
|
||||||
|
dim_head,
|
||||||
|
block_size,
|
||||||
|
bsz,
|
||||||
|
stream,
|
||||||
|
use_neox_rotary_style,
|
||||||
|
rope_3d);
|
||||||
} else if (cache_quant_type_str == "cache_int4_zp") {
|
} else if (cache_quant_type_str == "cache_int4_zp") {
|
||||||
append_decode_cache_int4_rope(
|
append_decode_cache_int4_rope(
|
||||||
reinterpret_cast<const QKV_TYPE*>(qkv_ptr),
|
reinterpret_cast<const QKV_TYPE*>(qkv_ptr),
|
||||||
@@ -861,6 +750,7 @@ void DecoderWriteCacheWithRoPEKernel(
|
|||||||
value_cache_out->data<uint8_t>(),
|
value_cache_out->data<uint8_t>(),
|
||||||
reinterpret_cast<DataType_*>(const_cast<T*>(qkv_out->data<T>())),
|
reinterpret_cast<DataType_*>(const_cast<T*>(qkv_out->data<T>())),
|
||||||
block_tables.data<int>(),
|
block_tables.data<int>(),
|
||||||
|
batch_id_per_token.data<int>(),
|
||||||
cu_seqlens_q.data<int>(),
|
cu_seqlens_q.data<int>(),
|
||||||
seq_lens.data<int>(),
|
seq_lens.data<int>(),
|
||||||
seq_lens_encoder.data<int>(),
|
seq_lens_encoder.data<int>(),
|
||||||
@@ -868,8 +758,8 @@ void DecoderWriteCacheWithRoPEKernel(
|
|||||||
sin_emb,
|
sin_emb,
|
||||||
qkv_out_scales ? qkv_out_scales.get().data<float>() : nullptr,
|
qkv_out_scales ? qkv_out_scales.get().data<float>() : nullptr,
|
||||||
qkv_biases ? reinterpret_cast<DataType_*>(
|
qkv_biases ? reinterpret_cast<DataType_*>(
|
||||||
const_cast<T*>(qkv_biases.get().data<T>()))
|
const_cast<T*>(qkv_biases.get().data<T>()))
|
||||||
: nullptr,
|
: nullptr,
|
||||||
cache_k_scale ? reinterpret_cast<DataType_*>(
|
cache_k_scale ? reinterpret_cast<DataType_*>(
|
||||||
const_cast<T*>(cache_k_scale.get().data<T>()))
|
const_cast<T*>(cache_k_scale.get().data<T>()))
|
||||||
: nullptr,
|
: nullptr,
|
||||||
@@ -877,11 +767,11 @@ void DecoderWriteCacheWithRoPEKernel(
|
|||||||
const_cast<T*>(cache_v_scale.get().data<T>()))
|
const_cast<T*>(cache_v_scale.get().data<T>()))
|
||||||
: nullptr,
|
: nullptr,
|
||||||
cache_k_zp ? reinterpret_cast<DataType_*>(
|
cache_k_zp ? reinterpret_cast<DataType_*>(
|
||||||
const_cast<T*>(cache_k_zp.get().data<T>()))
|
const_cast<T*>(cache_k_zp.get().data<T>()))
|
||||||
: nullptr,
|
: nullptr,
|
||||||
cache_v_zp ? reinterpret_cast<DataType_*>(
|
cache_v_zp ? reinterpret_cast<DataType_*>(
|
||||||
const_cast<T*>(cache_v_zp.get().data<T>()))
|
const_cast<T*>(cache_v_zp.get().data<T>()))
|
||||||
: nullptr,
|
: nullptr,
|
||||||
max_seq_len,
|
max_seq_len,
|
||||||
max_blocks_per_seq,
|
max_blocks_per_seq,
|
||||||
num_heads,
|
num_heads,
|
||||||
@@ -900,6 +790,7 @@ void DecoderWriteCacheWithRoPEKernel(
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
template void DecoderWriteCacheWithRoPEKernel<paddle::bfloat16, int>(
|
template void DecoderWriteCacheWithRoPEKernel<paddle::bfloat16, int>(
|
||||||
const AppendAttnMetaData& meta_data,
|
const AppendAttnMetaData& meta_data,
|
||||||
const paddle::Tensor&
|
const paddle::Tensor&
|
||||||
@@ -907,6 +798,7 @@ template void DecoderWriteCacheWithRoPEKernel<paddle::bfloat16, int>(
|
|||||||
// kv_num_heads, head_dim] if GQA)
|
// kv_num_heads, head_dim] if GQA)
|
||||||
const paddle::Tensor& seq_lens,
|
const paddle::Tensor& seq_lens,
|
||||||
const paddle::Tensor& seq_lens_encoder,
|
const paddle::Tensor& seq_lens_encoder,
|
||||||
|
const paddle::Tensor& batch_id_per_token,
|
||||||
const paddle::Tensor& cu_seqlens_q,
|
const paddle::Tensor& cu_seqlens_q,
|
||||||
const paddle::Tensor& block_tables,
|
const paddle::Tensor& block_tables,
|
||||||
const paddle::optional<paddle::Tensor>& rotary_embs,
|
const paddle::optional<paddle::Tensor>& rotary_embs,
|
||||||
@@ -936,6 +828,7 @@ DecoderWriteCacheWithRoPEKernel<paddle::bfloat16, paddle::bfloat16>(
|
|||||||
// kv_num_heads, head_dim] if GQA)
|
// kv_num_heads, head_dim] if GQA)
|
||||||
const paddle::Tensor& seq_lens,
|
const paddle::Tensor& seq_lens,
|
||||||
const paddle::Tensor& seq_lens_encoder,
|
const paddle::Tensor& seq_lens_encoder,
|
||||||
|
const paddle::Tensor& batch_id_per_token,
|
||||||
const paddle::Tensor& cu_seqlens_q,
|
const paddle::Tensor& cu_seqlens_q,
|
||||||
const paddle::Tensor& block_tables,
|
const paddle::Tensor& block_tables,
|
||||||
const paddle::optional<paddle::Tensor>& rotary_embs,
|
const paddle::optional<paddle::Tensor>& rotary_embs,
|
||||||
@@ -964,6 +857,7 @@ template void DecoderWriteCacheWithRoPEKernel<paddle::float16, int>(
|
|||||||
// kv_num_heads, head_dim] if GQA)
|
// kv_num_heads, head_dim] if GQA)
|
||||||
const paddle::Tensor& seq_lens,
|
const paddle::Tensor& seq_lens,
|
||||||
const paddle::Tensor& seq_lens_encoder,
|
const paddle::Tensor& seq_lens_encoder,
|
||||||
|
const paddle::Tensor& batch_id_per_token,
|
||||||
const paddle::Tensor& cu_seqlens_q,
|
const paddle::Tensor& cu_seqlens_q,
|
||||||
const paddle::Tensor& block_tables,
|
const paddle::Tensor& block_tables,
|
||||||
const paddle::optional<paddle::Tensor>& rotary_embs,
|
const paddle::optional<paddle::Tensor>& rotary_embs,
|
||||||
@@ -992,6 +886,7 @@ template void DecoderWriteCacheWithRoPEKernel<paddle::float16, paddle::float16>(
|
|||||||
// kv_num_heads, head_dim] if GQA)
|
// kv_num_heads, head_dim] if GQA)
|
||||||
const paddle::Tensor& seq_lens,
|
const paddle::Tensor& seq_lens,
|
||||||
const paddle::Tensor& seq_lens_encoder,
|
const paddle::Tensor& seq_lens_encoder,
|
||||||
|
const paddle::Tensor& batch_id_per_token,
|
||||||
const paddle::Tensor& cu_seqlens_q,
|
const paddle::Tensor& cu_seqlens_q,
|
||||||
const paddle::Tensor& block_tables,
|
const paddle::Tensor& block_tables,
|
||||||
const paddle::optional<paddle::Tensor>& rotary_embs,
|
const paddle::optional<paddle::Tensor>& rotary_embs,
|
||||||
|
|||||||
@@ -23,6 +23,7 @@ void DecoderWriteCacheWithRoPEKernel(
|
|||||||
// kv_num_heads, head_dim] if GQA)
|
// kv_num_heads, head_dim] if GQA)
|
||||||
const paddle::Tensor& seq_lens,
|
const paddle::Tensor& seq_lens,
|
||||||
const paddle::Tensor& seq_lens_encoder,
|
const paddle::Tensor& seq_lens_encoder,
|
||||||
|
const paddle::Tensor& batch_id_per_token,
|
||||||
const paddle::Tensor& cu_seqlens_q,
|
const paddle::Tensor& cu_seqlens_q,
|
||||||
const paddle::Tensor& block_tables,
|
const paddle::Tensor& block_tables,
|
||||||
const paddle::optional<paddle::Tensor>& rotary_embs,
|
const paddle::optional<paddle::Tensor>& rotary_embs,
|
||||||
|
|||||||
@@ -449,8 +449,8 @@ __global__ void GQAVariableLengthRotaryQKNormKernel(
|
|||||||
const int half_lastdim = last_dim / 2;
|
const int half_lastdim = last_dim / 2;
|
||||||
const int offset = (q_num_head + kv_num_head) * last_dim;
|
const int offset = (q_num_head + kv_num_head) * last_dim;
|
||||||
const int all_head_num = elem_cnt / last_dim;
|
const int all_head_num = elem_cnt / last_dim;
|
||||||
for (int global_hi = global_warp_idx; global_hi < all_head_num; global_hi += all_warp_num) {
|
for (int gloabl_hi = global_warp_idx; gloabl_hi < all_head_num; gloabl_hi += all_warp_num) {
|
||||||
int64_t linear_index = global_hi * last_dim + threadIdx.x * VecSize;
|
int64_t linear_index = gloabl_hi * last_dim + threadIdx.x * VecSize;
|
||||||
const int token_idx = linear_index / offset;
|
const int token_idx = linear_index / offset;
|
||||||
const int ori_bi = batch_id_per_token[token_idx];
|
const int ori_bi = batch_id_per_token[token_idx];
|
||||||
if (seq_lens[ori_bi] == 0) continue;
|
if (seq_lens[ori_bi] == 0) continue;
|
||||||
@@ -1004,8 +1004,7 @@ __global__ void cache_kernel(
|
|||||||
const uint32_t qkv_bias = bias % hidden_size;
|
const uint32_t qkv_bias = bias % hidden_size;
|
||||||
const uint32_t hi = qkv_bias / head_size;
|
const uint32_t hi = qkv_bias / head_size;
|
||||||
const uint32_t h_bias = qkv_bias % head_size;
|
const uint32_t h_bias = qkv_bias % head_size;
|
||||||
const int32_t ori_bi = batch_id_per_token[token_idx];
|
const uint32_t ori_bi = batch_id_per_token[token_idx];
|
||||||
if (ori_bi == -1) continue; // skip batch_id_per_token[token_idx]=-1
|
|
||||||
if (seq_lens[ori_bi] == 0) continue;
|
if (seq_lens[ori_bi] == 0) continue;
|
||||||
const uint32_t ori_seq_id = (token_idx - cu_seqlens_q[ori_bi]) + seq_lens_decoder[ori_bi];
|
const uint32_t ori_seq_id = (token_idx - cu_seqlens_q[ori_bi]) + seq_lens_decoder[ori_bi];
|
||||||
|
|
||||||
@@ -1301,411 +1300,6 @@ __global__ void append_write_cache_kv_c8_qkv(
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
template <typename T,
|
|
||||||
uint32_t num_frags_y,
|
|
||||||
uint32_t num_frags_z,
|
|
||||||
uint32_t HEAD_DIM,
|
|
||||||
uint32_t BLOCK_SIZE,
|
|
||||||
uint32_t NUM_WARPS,
|
|
||||||
bool is_need_kv_quant,
|
|
||||||
bool IsFP8 = true>
|
|
||||||
__global__ void append_write_cache_kv_c8_qkv_dynamic(
|
|
||||||
uint8_t *__restrict__ cache_k,
|
|
||||||
uint8_t *__restrict__ cache_v,
|
|
||||||
const T *__restrict__ qkv_input,
|
|
||||||
T *__restrict__ cache_k_scales, // [block_num, num_heads, block_size]
|
|
||||||
T *__restrict__ cache_v_scales, // [block_num, num_heads, block_size]
|
|
||||||
const int *__restrict__ batch_ids,
|
|
||||||
const int *__restrict__ tile_ids,
|
|
||||||
const int *__restrict__ seq_lens_this_time,
|
|
||||||
const int *__restrict__ seq_lens_decoder,
|
|
||||||
const int *__restrict__ batch_id_per_token,
|
|
||||||
const int *__restrict__ cu_seqlens_q,
|
|
||||||
const int *__restrict__ block_tables,
|
|
||||||
const int max_seq_len,
|
|
||||||
const int max_blocks_per_seq,
|
|
||||||
const int num_heads,
|
|
||||||
const int kv_num_heads) {
|
|
||||||
constexpr uint32_t num_vecs_per_head = HEAD_DIM / num_elems_per_128b<T>();
|
|
||||||
constexpr uint32_t pad_len = BLOCK_SIZE;
|
|
||||||
const uint32_t btid = blockIdx.x, kv_head_idx = blockIdx.z;
|
|
||||||
const T cache_k_scale = cache_k_scales[kv_head_idx];
|
|
||||||
const T cache_v_scale = cache_v_scales[kv_head_idx];
|
|
||||||
const uint32_t tid = threadIdx.x, wid = threadIdx.y;
|
|
||||||
const uint32_t batch_id = batch_ids[btid];
|
|
||||||
const uint32_t tile_id = tile_ids[btid];
|
|
||||||
const uint32_t seq_len_this_time = seq_lens_this_time[batch_id];
|
|
||||||
if (seq_len_this_time <= 0) {
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
const int *block_table_now = nullptr;
|
|
||||||
|
|
||||||
block_table_now = block_tables + batch_id * max_blocks_per_seq;
|
|
||||||
|
|
||||||
const uint32_t num_rows_per_block =
|
|
||||||
NUM_WARPS * num_frags_z * 16; // BLOCK_SIZE
|
|
||||||
const uint32_t start_len = seq_lens_decoder[batch_id];
|
|
||||||
const uint32_t bf_pad_len = start_len % pad_len;
|
|
||||||
const uint32_t start_len_pad = start_len - bf_pad_len;
|
|
||||||
const uint32_t end_len = start_len + seq_len_this_time;
|
|
||||||
|
|
||||||
const uint32_t tile_start = start_len_pad + tile_id * num_rows_per_block;
|
|
||||||
int block_id = __ldg(&block_table_now[tile_start / BLOCK_SIZE]);
|
|
||||||
uint32_t chunk_start = tile_start + wid * num_frags_z * 16 + tid / 8;
|
|
||||||
|
|
||||||
const uint32_t start_token_idx = cu_seqlens_q[batch_id];
|
|
||||||
const uint32_t kv_batch_stride = (num_heads + 2 * kv_num_heads) * HEAD_DIM;
|
|
||||||
const uint32_t kv_h_stride = HEAD_DIM;
|
|
||||||
__shared__ T k_smem_ori[num_rows_per_block * HEAD_DIM];
|
|
||||||
__shared__ T v_smem_ori[num_rows_per_block * HEAD_DIM];
|
|
||||||
__shared__ T v_scale_smem[BLOCK_SIZE];
|
|
||||||
if (tile_start >= start_len) {
|
|
||||||
constexpr int KV_VEC_SIZE = 16 / sizeof(uint8_t); // 16
|
|
||||||
using LoadPadKVT = AlignedVector<uint8_t, KV_VEC_SIZE>;
|
|
||||||
// pad zero for this kv_head_idx for this block
|
|
||||||
LoadPadKVT pad_cache_vec;
|
|
||||||
*(reinterpret_cast<uint4*>(pad_cache_vec.val)) = make_uint4(0, 0, 0, 0);
|
|
||||||
// reset k
|
|
||||||
constexpr int num_vecs_per_head_k = HEAD_DIM / KV_VEC_SIZE;
|
|
||||||
constexpr int num_token_each_time_k = 32 / num_vecs_per_head_k;
|
|
||||||
uint32_t tgt_idx =
|
|
||||||
(block_id * kv_num_heads + kv_head_idx) * BLOCK_SIZE * HEAD_DIM +
|
|
||||||
tid % num_vecs_per_head_k * KV_VEC_SIZE;
|
|
||||||
for (int block_i = tid / num_vecs_per_head_k;
|
|
||||||
block_i < BLOCK_SIZE;
|
|
||||||
block_i += num_token_each_time_k) {
|
|
||||||
Store<uint8_t, KV_VEC_SIZE>(pad_cache_vec,
|
|
||||||
&cache_k[tgt_idx + block_i * HEAD_DIM]);
|
|
||||||
}
|
|
||||||
|
|
||||||
// reset v
|
|
||||||
const int num_vecs_per_head_v = BLOCK_SIZE / KV_VEC_SIZE;
|
|
||||||
const int num_token_each_time_v = 32 / num_vecs_per_head_v;
|
|
||||||
tgt_idx =
|
|
||||||
(block_id * kv_num_heads + kv_head_idx) * HEAD_DIM * BLOCK_SIZE +
|
|
||||||
tid % num_vecs_per_head_v * KV_VEC_SIZE;
|
|
||||||
for (int block_i = tid / num_vecs_per_head_v; block_i < HEAD_DIM;
|
|
||||||
block_i += num_token_each_time_v) {
|
|
||||||
Store<uint8_t, KV_VEC_SIZE>(
|
|
||||||
pad_cache_vec, &cache_v[tgt_idx + block_i * BLOCK_SIZE]);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
smem_t k_smem(k_smem_ori);
|
|
||||||
smem_t v_smem(v_smem_ori);
|
|
||||||
|
|
||||||
uint32_t kv_smem_offset_w = smem_t::get_permuted_offset<num_vecs_per_head>(
|
|
||||||
wid * num_frags_z * 16 + tid / 8, tid % 8); // 4 * 8 per warp
|
|
||||||
|
|
||||||
/*
|
|
||||||
0 | 1
|
|
||||||
2 | 3
|
|
||||||
*/
|
|
||||||
uint32_t k_smem_offset_r = smem_t::get_permuted_offset<num_vecs_per_head>(
|
|
||||||
wid * num_frags_z * 16 + 8 * (tid / 16) + tid % 8, (tid % 16) / 8);
|
|
||||||
|
|
||||||
constexpr uint32_t num_frags_v = num_frags_y / NUM_WARPS;
|
|
||||||
/*
|
|
||||||
0 | 2
|
|
||||||
1 | 3
|
|
||||||
*/
|
|
||||||
uint32_t v_smem_offset_r = smem_t::get_permuted_offset<num_vecs_per_head>(
|
|
||||||
tid % 16, wid * num_frags_v * 2 + tid / 16);
|
|
||||||
|
|
||||||
// load kv gmem to smem
|
|
||||||
const uint32_t real_start_token_idx = start_token_idx - bf_pad_len +
|
|
||||||
tile_id * num_rows_per_block +
|
|
||||||
wid * num_frags_z * 16 + tid / 8;
|
|
||||||
uint32_t k_read_idx = real_start_token_idx * kv_batch_stride +
|
|
||||||
(num_heads + kv_head_idx) * kv_h_stride +
|
|
||||||
tid % 8 * num_elems_per_128b<T>();
|
|
||||||
uint32_t v_read_idx = real_start_token_idx * kv_batch_stride +
|
|
||||||
(num_heads + kv_num_heads + kv_head_idx) * kv_h_stride +
|
|
||||||
tid % 8 * num_elems_per_128b<T>();
|
|
||||||
#pragma unroll
|
|
||||||
for (uint32_t fz = 0; fz < num_frags_z; ++fz) {
|
|
||||||
#pragma unroll
|
|
||||||
for (uint32_t j = 0; j < 4; ++j) {
|
|
||||||
#pragma unroll
|
|
||||||
for (uint32_t fy = 0; fy < num_frags_y / 4;
|
|
||||||
++fy) { // (num_frags_y * 16) / (8 * num_elems_per_128b<T>())
|
|
||||||
if (chunk_start >= start_len && chunk_start < end_len) {
|
|
||||||
k_smem.load_128b_async<SharedMemFillMode::kNoFill>(
|
|
||||||
kv_smem_offset_w, qkv_input + k_read_idx, chunk_start < end_len);
|
|
||||||
v_smem.load_128b_async<SharedMemFillMode::kNoFill>(
|
|
||||||
kv_smem_offset_w, qkv_input + v_read_idx, chunk_start < end_len);
|
|
||||||
}
|
|
||||||
kv_smem_offset_w =
|
|
||||||
k_smem.advance_offset_by_column<8>(kv_smem_offset_w, fy);
|
|
||||||
k_read_idx += 8 * num_elems_per_128b<T>();
|
|
||||||
v_read_idx += 8 * num_elems_per_128b<T>();
|
|
||||||
}
|
|
||||||
kv_smem_offset_w =
|
|
||||||
k_smem.advance_offset_by_row<4, num_vecs_per_head>(kv_smem_offset_w) -
|
|
||||||
2 * num_frags_y;
|
|
||||||
chunk_start += 4;
|
|
||||||
k_read_idx +=
|
|
||||||
4 * kv_batch_stride - 2 * num_frags_y * num_elems_per_128b<T>();
|
|
||||||
v_read_idx +=
|
|
||||||
4 * kv_batch_stride - 2 * num_frags_y * num_elems_per_128b<T>();
|
|
||||||
}
|
|
||||||
}
|
|
||||||
commit_group();
|
|
||||||
wait_group<0>();
|
|
||||||
__syncthreads();
|
|
||||||
|
|
||||||
// reduce scale
|
|
||||||
// 16 rows per warp
|
|
||||||
uint32_t kv_reduce_frag[4];
|
|
||||||
T *kv_reduce_frag_T = reinterpret_cast<T*>(kv_reduce_frag);
|
|
||||||
|
|
||||||
T k_local_max_value[num_frags_z * 2];
|
|
||||||
T v_local_max_value[num_frags_z * 2];
|
|
||||||
#pragma unroll
|
|
||||||
for (int i = 0; i < num_frags_z * 2; i++) {
|
|
||||||
k_local_max_value[i] = -INFINITY;
|
|
||||||
}
|
|
||||||
#pragma unroll
|
|
||||||
for (int i = 0; i < num_frags_z * 2; i++) {
|
|
||||||
v_local_max_value[i] = -INFINITY;
|
|
||||||
}
|
|
||||||
const int num_kv_heads = gridDim.z;
|
|
||||||
const int scale_offset = block_id * num_kv_heads * BLOCK_SIZE + kv_head_idx * BLOCK_SIZE;
|
|
||||||
T *cache_k_scale_now = cache_k_scales + scale_offset;
|
|
||||||
T *cache_v_scale_now = cache_v_scales + scale_offset;
|
|
||||||
// k scale
|
|
||||||
#pragma unroll
|
|
||||||
for (uint32_t fz = 0; fz < num_frags_z; ++fz) {
|
|
||||||
#pragma unroll
|
|
||||||
for (uint32_t fy = 0; fy < num_frags_y; ++fy) {
|
|
||||||
// reduce per thread, 4 threads each row
|
|
||||||
k_smem.ldmatrix_m8n8x4(k_smem_offset_r, kv_reduce_frag);
|
|
||||||
#pragma unroll
|
|
||||||
for (int i = 0; i < 4; i++) {
|
|
||||||
k_local_max_value[fz * 2] = __hmax(__habs(kv_reduce_frag_T[i]), k_local_max_value[fz * 2]);
|
|
||||||
}
|
|
||||||
#pragma unroll
|
|
||||||
for (int i = 0; i < 4; i++) {
|
|
||||||
k_local_max_value[fz * 2 + 1] = __hmax(__habs(kv_reduce_frag_T[i + 4]), k_local_max_value[fz * 2 + 1]);
|
|
||||||
}
|
|
||||||
k_smem_offset_r = k_smem.advance_offset_by_column<2>(k_smem_offset_r, fy);
|
|
||||||
}
|
|
||||||
// reduce per row
|
|
||||||
for (int i = 0; i < 2; i++) {
|
|
||||||
T local_max_value = __habs(k_local_max_value[fz * 2 + i]);
|
|
||||||
local_max_value = __hmax(local_max_value, __shfl_xor_sync(0xffffffff, local_max_value, 2));
|
|
||||||
local_max_value = __hmax(local_max_value, __shfl_xor_sync(0xffffffff, local_max_value, 1));
|
|
||||||
// used for quant
|
|
||||||
k_local_max_value[fz * 2 + i] = __hdiv(448, local_max_value);
|
|
||||||
}
|
|
||||||
// store
|
|
||||||
if (tid % 4 == 0) {
|
|
||||||
const int offset_now = wid * num_frags_z * 16 + tid / 4;
|
|
||||||
// used for dequant
|
|
||||||
if (tile_start + offset_now >= start_len) {
|
|
||||||
if (tile_start + offset_now < end_len) {
|
|
||||||
cache_k_scale_now[offset_now] = __hdiv(1, k_local_max_value[fz * 2]);
|
|
||||||
} else {
|
|
||||||
cache_k_scale_now[offset_now] = 0;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if (tile_start + offset_now + 8 >= start_len) {
|
|
||||||
if (tile_start + offset_now + 8 < end_len) {
|
|
||||||
cache_k_scale_now[offset_now + 8] = __hdiv(1, k_local_max_value[fz * 2 + 1]);
|
|
||||||
} else {
|
|
||||||
cache_k_scale_now[offset_now + 8] = 0;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
__syncthreads();
|
|
||||||
k_smem_offset_r -= 2 * num_frags_y; // num_frags_z = 1
|
|
||||||
}
|
|
||||||
// v scale
|
|
||||||
#pragma unroll
|
|
||||||
for (uint32_t fz = 0; fz < num_frags_z; ++fz) {
|
|
||||||
#pragma unroll
|
|
||||||
for (uint32_t fy = 0; fy < num_frags_y; ++fy) {
|
|
||||||
// reduce per thread, 4 threads each row
|
|
||||||
v_smem.ldmatrix_m8n8x4(k_smem_offset_r, kv_reduce_frag);
|
|
||||||
#pragma unroll
|
|
||||||
for (int i = 0; i < 4; i++) {
|
|
||||||
v_local_max_value[fz * 2] = __hmax(__habs(kv_reduce_frag_T[i]), v_local_max_value[fz * 2]);
|
|
||||||
}
|
|
||||||
#pragma unroll
|
|
||||||
for (int i = 0; i < 4; i++) {
|
|
||||||
v_local_max_value[fz * 2 + 1] = __hmax(__habs(kv_reduce_frag_T[i + 4]), v_local_max_value[fz * 2 + 1]);
|
|
||||||
}
|
|
||||||
k_smem_offset_r = v_smem.advance_offset_by_column<2>(k_smem_offset_r, fy);
|
|
||||||
}
|
|
||||||
// reduce per row
|
|
||||||
for (int i = 0; i < 2; i++) {
|
|
||||||
T local_max_value = __habs(v_local_max_value[fz * 2 + i]);
|
|
||||||
local_max_value = __hmax(local_max_value, __shfl_xor_sync(0xffffffff, local_max_value, 2));
|
|
||||||
local_max_value = __hmax(local_max_value, __shfl_xor_sync(0xffffffff, local_max_value, 1));
|
|
||||||
v_local_max_value[fz * 2 + i] = __hdiv(448, local_max_value);
|
|
||||||
}
|
|
||||||
// store
|
|
||||||
if (tid % 4 == 0) {
|
|
||||||
const int offset_now = wid * num_frags_z * 16 + tid / 4;
|
|
||||||
// used for dequant
|
|
||||||
if (tile_start + offset_now >= start_len) {
|
|
||||||
if (tile_start + offset_now < end_len) {
|
|
||||||
cache_v_scale_now[offset_now] = __hdiv(1, v_local_max_value[fz * 2]);
|
|
||||||
v_scale_smem[offset_now] = v_local_max_value[fz * 2];
|
|
||||||
} else {
|
|
||||||
cache_v_scale_now[offset_now] = 0;
|
|
||||||
v_scale_smem[offset_now] = 0;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if (tile_start + offset_now + 8 >= start_len) {
|
|
||||||
if (tile_start + offset_now + 8 < end_len) {
|
|
||||||
cache_v_scale_now[offset_now + 8] = __hdiv(1, v_local_max_value[fz * 2 + 1]);
|
|
||||||
v_scale_smem[offset_now + 8] = v_local_max_value[fz * 2 + 1];
|
|
||||||
} else {
|
|
||||||
cache_v_scale_now[offset_now + 8] = 0;
|
|
||||||
v_scale_smem[offset_now + 8] = 0;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
__syncthreads();
|
|
||||||
k_smem_offset_r -= 2 * num_frags_y; // num_frags_z = 1
|
|
||||||
}
|
|
||||||
__syncthreads();
|
|
||||||
|
|
||||||
// mask, quant, store
|
|
||||||
using LoadKVT = AlignedVector<uint8_t, 4>;
|
|
||||||
LoadKVT cache_vec1;
|
|
||||||
LoadKVT cache_vec2;
|
|
||||||
|
|
||||||
uint32_t chunk_start_k = tile_start + wid * num_frags_z * 16 + tid / 4;
|
|
||||||
uint32_t kv_frag[4];
|
|
||||||
const uint32_t write_n_stride = kv_num_heads * BLOCK_SIZE * HEAD_DIM;
|
|
||||||
const uint32_t write_h_stride = BLOCK_SIZE * HEAD_DIM;
|
|
||||||
const uint32_t write_b_stride = HEAD_DIM;
|
|
||||||
const uint32_t write_d_stride = BLOCK_SIZE;
|
|
||||||
uint32_t k_write_idx = block_id * write_n_stride +
|
|
||||||
kv_head_idx * write_h_stride +
|
|
||||||
(wid * num_frags_z * 16 + tid / 4) * write_b_stride +
|
|
||||||
tid % 4 * 4; // 4 * int8 = 8 * int4 = 32bit
|
|
||||||
#pragma unroll
|
|
||||||
for (uint32_t fz = 0; fz < num_frags_z; ++fz) {
|
|
||||||
uint32_t k_write_idx_now_z = k_write_idx + fz * 16 * write_b_stride;
|
|
||||||
#pragma unroll
|
|
||||||
for (uint32_t fy = 0; fy < num_frags_y; ++fy) {
|
|
||||||
uint32_t k_write_idx_now = k_write_idx_now_z +
|
|
||||||
fy % 2 * 8 * write_b_stride +
|
|
||||||
fy / 2 * 32; // + fy % 2 * 16;
|
|
||||||
// load
|
|
||||||
k_smem.ldmatrix_m8n8x4(k_smem_offset_r, kv_frag);
|
|
||||||
// quant
|
|
||||||
T *k_frag_T = reinterpret_cast<T *>(kv_frag);
|
|
||||||
if (bf_pad_len != 0) {
|
|
||||||
Load<uint8_t, 4>(cache_k + k_write_idx_now, &cache_vec1);
|
|
||||||
Load<uint8_t, 4>(cache_k + k_write_idx_now + 16, &cache_vec2);
|
|
||||||
}
|
|
||||||
#pragma unroll
|
|
||||||
for (uint32_t v_id = 0; v_id < 8; ++v_id) {
|
|
||||||
uint8_t uint_quant_value;
|
|
||||||
if (chunk_start_k + (v_id / 4) * 8 >= start_len &&
|
|
||||||
chunk_start_k + (v_id / 4) * 8 < end_len) {
|
|
||||||
uint_quant_value = QuantToC8<T, is_need_kv_quant, IsFP8>(k_local_max_value[fz * 2 + v_id / 4], k_frag_T[v_id], 127.0f, -127.0f);
|
|
||||||
} else {
|
|
||||||
uint_quant_value = 0;
|
|
||||||
}
|
|
||||||
if (bf_pad_len != 0) {
|
|
||||||
if (v_id < 4) {
|
|
||||||
cache_vec1[v_id] |= uint_quant_value;
|
|
||||||
} else {
|
|
||||||
cache_vec2[v_id % 4] |= uint_quant_value;
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
if (v_id < 4) {
|
|
||||||
cache_vec1[v_id] = uint_quant_value;
|
|
||||||
} else {
|
|
||||||
cache_vec2[v_id - 4] = uint_quant_value;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
// store
|
|
||||||
Store<uint8_t, 4>(cache_vec1, cache_k + k_write_idx_now);
|
|
||||||
Store<uint8_t, 4>(cache_vec2, cache_k + k_write_idx_now + 16);
|
|
||||||
k_smem_offset_r = k_smem.advance_offset_by_column<2>(k_smem_offset_r, fy);
|
|
||||||
}
|
|
||||||
k_smem_offset_r =
|
|
||||||
k_smem.advance_offset_by_row<16, num_vecs_per_head>(k_smem_offset_r) -
|
|
||||||
2 * num_frags_y;
|
|
||||||
chunk_start_k += 16;
|
|
||||||
}
|
|
||||||
|
|
||||||
uint32_t chunk_start_v = tile_start + tid % 4 * 2;
|
|
||||||
uint32_t v_write_idx = block_id * write_n_stride +
|
|
||||||
kv_head_idx * write_h_stride +
|
|
||||||
(wid * num_frags_v * 16 + tid / 4) * write_d_stride +
|
|
||||||
tid % 4 * 4; // 4 * int8 = 8 * int4 = 32bit
|
|
||||||
const uint32_t num_frags_z_v = num_frags_z * NUM_WARPS;
|
|
||||||
T v_scales[num_frags_z_v * 4];
|
|
||||||
for (int v_i = 0; v_i < num_frags_z_v; v_i++) {
|
|
||||||
const int offset = v_i * 16;
|
|
||||||
const int t_offset = tid % 4 * 2;
|
|
||||||
v_scales[v_i * 4] = v_scale_smem[offset + t_offset];
|
|
||||||
v_scales[v_i * 4 + 1] = v_scale_smem[offset + t_offset + 1];
|
|
||||||
v_scales[v_i * 4 + 2] = v_scale_smem[offset + t_offset + 8];
|
|
||||||
v_scales[v_i * 4 + 3] = v_scale_smem[offset + t_offset + 9];
|
|
||||||
}
|
|
||||||
|
|
||||||
#pragma unroll
|
|
||||||
for (uint32_t fy = 0; fy < num_frags_v; ++fy) {
|
|
||||||
uint32_t v_write_idx_now_v = v_write_idx + fy * 16 * write_d_stride;
|
|
||||||
#pragma unroll
|
|
||||||
for (uint32_t fz = 0; fz < num_frags_z_v; ++fz) {
|
|
||||||
uint32_t v_write_idx_now = v_write_idx_now_v +
|
|
||||||
fz % 2 * 8 * write_d_stride +
|
|
||||||
fz / 2 * 32; // + fz % 2 * 16;
|
|
||||||
// load
|
|
||||||
v_smem.ldmatrix_m8n8x4_trans(v_smem_offset_r, kv_frag);
|
|
||||||
// quant
|
|
||||||
T *v_frag_T = reinterpret_cast<T *>(kv_frag);
|
|
||||||
if (bf_pad_len != 0) {
|
|
||||||
Load<uint8_t, 4>(cache_v + v_write_idx_now, &cache_vec1);
|
|
||||||
Load<uint8_t, 4>(cache_v + v_write_idx_now + 16, &cache_vec2);
|
|
||||||
}
|
|
||||||
#pragma unroll
|
|
||||||
for (uint32_t v_id = 0; v_id < 8; ++v_id) {
|
|
||||||
uint8_t uint_quant_value;
|
|
||||||
if (chunk_start_v + v_id % 2 + (v_id % 4) / 2 * 8 >= start_len &&
|
|
||||||
chunk_start_v + v_id % 2 + (v_id % 4) / 2 * 8 < end_len) {
|
|
||||||
uint_quant_value = QuantToC8<T, is_need_kv_quant, IsFP8>(v_scales[fz * 4 + v_id % 4], v_frag_T[v_id], 127.0f, -127.0f);
|
|
||||||
// store now
|
|
||||||
} else {
|
|
||||||
uint_quant_value = 0;
|
|
||||||
}
|
|
||||||
if (bf_pad_len != 0) {
|
|
||||||
if (v_id < 4) {
|
|
||||||
cache_vec1[v_id] |= uint_quant_value;
|
|
||||||
} else {
|
|
||||||
cache_vec2[v_id % 4] |= uint_quant_value;
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
if (v_id < 4) {
|
|
||||||
cache_vec1[v_id] = uint_quant_value;
|
|
||||||
} else {
|
|
||||||
cache_vec2[v_id % 4] = uint_quant_value;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
// store
|
|
||||||
Store<uint8_t, 4>(cache_vec1, cache_v + v_write_idx_now);
|
|
||||||
Store<uint8_t, 4>(cache_vec2, cache_v + v_write_idx_now + 16);
|
|
||||||
chunk_start_v += 16;
|
|
||||||
v_smem_offset_r =
|
|
||||||
k_smem.advance_offset_by_row<16, num_vecs_per_head>(v_smem_offset_r);
|
|
||||||
}
|
|
||||||
v_smem_offset_r = k_smem.advance_offset_by_column<2>(
|
|
||||||
v_smem_offset_r, wid * num_frags_v + fy) -
|
|
||||||
16 * num_frags_z_v * num_vecs_per_head;
|
|
||||||
chunk_start_v -= 16 * num_frags_z_v;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Write Cache KV in Append
|
// Write Cache KV in Append
|
||||||
template <typename T,
|
template <typename T,
|
||||||
uint32_t num_frags_y,
|
uint32_t num_frags_y,
|
||||||
@@ -2179,9 +1773,7 @@ void gqa_rotary_qk_norm_variable(
|
|||||||
qkv_out_scales
|
qkv_out_scales
|
||||||
? token_num * (num_heads + 2 * kv_num_heads) * dim_head
|
? token_num * (num_heads + 2 * kv_num_heads) * dim_head
|
||||||
: token_num * (num_heads + kv_num_heads) * dim_head; // for all q k v
|
: token_num * (num_heads + kv_num_heads) * dim_head; // for all q k v
|
||||||
if (dim_head != 128) {
|
assert(dim_head == 128 && "dim_head must be 128");
|
||||||
PADDLE_THROW("gqa rotary with qk norm only support head_dim=128, but got %d.", dim_head);
|
|
||||||
}
|
|
||||||
constexpr int HEAD_DIM = 128;
|
constexpr int HEAD_DIM = 128;
|
||||||
constexpr int PackSize = HEAD_DIM / kWarpSize;
|
constexpr int PackSize = HEAD_DIM / kWarpSize;
|
||||||
const int pack_num = elem_nums / PackSize;
|
const int pack_num = elem_nums / PackSize;
|
||||||
@@ -2515,11 +2107,10 @@ void CascadeAppendWriteCacheKVC8QKV(
|
|||||||
int num_blocks_x_cpu,
|
int num_blocks_x_cpu,
|
||||||
int max_seq_len,
|
int max_seq_len,
|
||||||
bool is_scale_channel_wise,
|
bool is_scale_channel_wise,
|
||||||
const std::string& cache_quant_type,
|
const bool is_fp8,
|
||||||
cudaStream_t &stream,
|
cudaStream_t &stream,
|
||||||
paddle::Tensor *cache_k_out,
|
paddle::Tensor *cache_k_out,
|
||||||
paddle::Tensor *cache_v_out) {
|
paddle::Tensor *cache_v_out) {
|
||||||
using NV_TYPE = typename cascade_attn_type_traits<T>::type;
|
|
||||||
auto max_blocks_per_seq = meta_data.max_blocks_per_seq;
|
auto max_blocks_per_seq = meta_data.max_blocks_per_seq;
|
||||||
auto num_tokens = meta_data.token_nums;
|
auto num_tokens = meta_data.token_nums;
|
||||||
auto num_heads = meta_data.q_num_heads;
|
auto num_heads = meta_data.q_num_heads;
|
||||||
@@ -2537,77 +2128,49 @@ void CascadeAppendWriteCacheKVC8QKV(
|
|||||||
dim3 blocks(32, num_warps);
|
dim3 blocks(32, num_warps);
|
||||||
|
|
||||||
const uint32_t smem_size = (BLOCK_SIZE * HEAD_DIM) * sizeof(T) * 2;
|
const uint32_t smem_size = (BLOCK_SIZE * HEAD_DIM) * sizeof(T) * 2;
|
||||||
if (cache_quant_type != "block_wise_fp8") {
|
auto kernel_fn = append_write_cache_kv_c8_qkv<T,
|
||||||
auto kernel_fn = append_write_cache_kv_c8_qkv<T,
|
num_frags_y,
|
||||||
num_frags_y,
|
num_frags_z,
|
||||||
num_frags_z,
|
HEAD_DIM,
|
||||||
HEAD_DIM,
|
BLOCK_SIZE,
|
||||||
BLOCK_SIZE,
|
num_warps,
|
||||||
num_warps,
|
true, false>;
|
||||||
true, false>;
|
if (is_fp8) {
|
||||||
if (cache_quant_type == "cache_fp8") {
|
kernel_fn = append_write_cache_kv_c8_qkv<T,
|
||||||
kernel_fn = append_write_cache_kv_c8_qkv<T,
|
num_frags_y,
|
||||||
num_frags_y,
|
num_frags_z,
|
||||||
num_frags_z,
|
HEAD_DIM,
|
||||||
HEAD_DIM,
|
BLOCK_SIZE,
|
||||||
BLOCK_SIZE,
|
num_warps,
|
||||||
num_warps,
|
true, true>;
|
||||||
true, true>;
|
|
||||||
}
|
|
||||||
if (is_scale_channel_wise) {
|
|
||||||
kernel_fn = append_write_cache_kv_c8_qkv<T,
|
|
||||||
num_frags_y,
|
|
||||||
num_frags_z,
|
|
||||||
HEAD_DIM,
|
|
||||||
BLOCK_SIZE,
|
|
||||||
num_warps,
|
|
||||||
false>;
|
|
||||||
}
|
|
||||||
cudaFuncSetAttribute(
|
|
||||||
kernel_fn, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size);
|
|
||||||
kernel_fn<<<grids, blocks, 0, stream>>>(cache_k_out->data<uint8_t>(),
|
|
||||||
cache_v_out->data<uint8_t>(),
|
|
||||||
qkv.data<T>(),
|
|
||||||
cache_k_scale.data<T>(),
|
|
||||||
cache_v_scale.data<T>(),
|
|
||||||
batch_ids.data<int>(),
|
|
||||||
tile_ids_per_batch.data<int>(),
|
|
||||||
seq_lens_this_time.data<int>(),
|
|
||||||
seq_lens_decoder.data<int>(),
|
|
||||||
batch_id_per_token.data<int>(),
|
|
||||||
cu_seqlens_q.data<int>(),
|
|
||||||
block_table.data<int>(),
|
|
||||||
max_seq_len,
|
|
||||||
max_blocks_per_seq,
|
|
||||||
num_heads,
|
|
||||||
kv_num_heads);
|
|
||||||
} else {
|
|
||||||
auto kernel_fn = append_write_cache_kv_c8_qkv_dynamic<NV_TYPE,
|
|
||||||
num_frags_y,
|
|
||||||
num_frags_z,
|
|
||||||
HEAD_DIM,
|
|
||||||
BLOCK_SIZE,
|
|
||||||
num_warps,
|
|
||||||
true, true>;
|
|
||||||
cudaFuncSetAttribute(
|
|
||||||
kernel_fn, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size);
|
|
||||||
kernel_fn<<<grids, blocks, 0, stream>>>(cache_k_out->data<uint8_t>(),
|
|
||||||
cache_v_out->data<uint8_t>(),
|
|
||||||
reinterpret_cast<const NV_TYPE*>(qkv.data<T>()),
|
|
||||||
const_cast<NV_TYPE*>(reinterpret_cast<const NV_TYPE*>(cache_k_scale.data<T>())),
|
|
||||||
const_cast<NV_TYPE*>(reinterpret_cast<const NV_TYPE*>(cache_v_scale.data<T>())),
|
|
||||||
batch_ids.data<int>(),
|
|
||||||
tile_ids_per_batch.data<int>(),
|
|
||||||
seq_lens_this_time.data<int>(),
|
|
||||||
seq_lens_decoder.data<int>(),
|
|
||||||
batch_id_per_token.data<int>(),
|
|
||||||
cu_seqlens_q.data<int>(),
|
|
||||||
block_table.data<int>(),
|
|
||||||
max_seq_len,
|
|
||||||
max_blocks_per_seq,
|
|
||||||
num_heads,
|
|
||||||
kv_num_heads);
|
|
||||||
}
|
}
|
||||||
|
if (is_scale_channel_wise) {
|
||||||
|
kernel_fn = append_write_cache_kv_c8_qkv<T,
|
||||||
|
num_frags_y,
|
||||||
|
num_frags_z,
|
||||||
|
HEAD_DIM,
|
||||||
|
BLOCK_SIZE,
|
||||||
|
num_warps,
|
||||||
|
false>;
|
||||||
|
}
|
||||||
|
cudaFuncSetAttribute(
|
||||||
|
kernel_fn, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size);
|
||||||
|
kernel_fn<<<grids, blocks, 0, stream>>>(cache_k_out->data<uint8_t>(),
|
||||||
|
cache_v_out->data<uint8_t>(),
|
||||||
|
qkv.data<T>(),
|
||||||
|
cache_k_scale.data<T>(),
|
||||||
|
cache_v_scale.data<T>(),
|
||||||
|
batch_ids.data<int>(),
|
||||||
|
tile_ids_per_batch.data<int>(),
|
||||||
|
seq_lens_this_time.data<int>(),
|
||||||
|
seq_lens_decoder.data<int>(),
|
||||||
|
batch_id_per_token.data<int>(),
|
||||||
|
cu_seqlens_q.data<int>(),
|
||||||
|
block_table.data<int>(),
|
||||||
|
max_seq_len,
|
||||||
|
max_blocks_per_seq,
|
||||||
|
num_heads,
|
||||||
|
kv_num_heads);
|
||||||
}
|
}
|
||||||
|
|
||||||
template <typename T, uint32_t HEAD_DIM, uint32_t BLOCK_SIZE>
|
template <typename T, uint32_t HEAD_DIM, uint32_t BLOCK_SIZE>
|
||||||
|
|||||||
@@ -178,7 +178,7 @@ void EncoderWriteCacheWithRopeKernel(
|
|||||||
stream,
|
stream,
|
||||||
key_cache_out,
|
key_cache_out,
|
||||||
value_cache_out);
|
value_cache_out);
|
||||||
} else if (cache_quant_type_str == "cache_int8" or cache_quant_type_str == "cache_fp8" or cache_quant_type_str == "block_wise_fp8") {
|
} else if (cache_quant_type_str == "cache_int8" or cache_quant_type_str == "cache_fp8") {
|
||||||
DISPATCH_HEAD_DIM(
|
DISPATCH_HEAD_DIM(
|
||||||
head_dim, HEAD_DIM, {DISPATCH_BLOCK_SIZE(block_size, BLOCK_SIZE, {
|
head_dim, HEAD_DIM, {DISPATCH_BLOCK_SIZE(block_size, BLOCK_SIZE, {
|
||||||
CascadeAppendWriteCacheKVC8QKV<T, HEAD_DIM, BLOCK_SIZE>(
|
CascadeAppendWriteCacheKVC8QKV<T, HEAD_DIM, BLOCK_SIZE>(
|
||||||
@@ -198,7 +198,7 @@ void EncoderWriteCacheWithRopeKernel(
|
|||||||
num_blocks,
|
num_blocks,
|
||||||
max_seq_len,
|
max_seq_len,
|
||||||
is_scale_channel_wise,
|
is_scale_channel_wise,
|
||||||
cache_quant_type_str,
|
cache_quant_type_str == "cache_fp8",
|
||||||
stream,
|
stream,
|
||||||
key_cache_out,
|
key_cache_out,
|
||||||
value_cache_out);
|
value_cache_out);
|
||||||
|
|||||||
@@ -11,17 +11,14 @@
|
|||||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
// See the License for the specific language governing permissions and
|
// See the License for the specific language governing permissions and
|
||||||
// limitations under the License.
|
// limitations under the License.
|
||||||
#include "cute/tensor.hpp"
|
|
||||||
#include "helper.h"
|
#include "helper.h"
|
||||||
#include "paddle/extension.h"
|
#include "paddle/extension.h"
|
||||||
#ifndef PADDLE_WITH_CUSTOM_DEVICE_METAX_GPU
|
|
||||||
#include "paddle/phi/core/memory/memcpy.h"
|
#include "paddle/phi/core/memory/memcpy.h"
|
||||||
#endif
|
|
||||||
#include "utils.cuh"
|
|
||||||
|
|
||||||
template <int THREADBLOCK_SIZE>
|
template <int THREADBLOCK_SIZE>
|
||||||
__global__ void
|
__global__ void
|
||||||
GetMaxLenKernel(const int *seq_lens_decoder, const int *seq_lens_this_time,
|
GetMaxLenKernel(const int *seq_lens, const int *seq_lens_this_time,
|
||||||
const int *seq_lens_encoder,
|
const int *seq_lens_encoder,
|
||||||
const int *seq_lens_this_time_merged,
|
const int *seq_lens_this_time_merged,
|
||||||
const int *seq_lens_encoder_merged, const int *seq_mapping,
|
const int *seq_lens_encoder_merged, const int *seq_mapping,
|
||||||
@@ -39,27 +36,41 @@ GetMaxLenKernel(const int *seq_lens_decoder, const int *seq_lens_this_time,
|
|||||||
int max_just_dec_merged_len_this_time_this_thread = 0;
|
int max_just_dec_merged_len_this_time_this_thread = 0;
|
||||||
int max_system_len_this_thread = 0;
|
int max_system_len_this_thread = 0;
|
||||||
int max_dec_len_without_system_this_thread = 0;
|
int max_dec_len_without_system_this_thread = 0;
|
||||||
int max_len_kv_this_thread = 0;
|
|
||||||
for (int i = tid; i < batch_size; i += blockDim.x) {
|
for (int i = tid; i < batch_size; i += blockDim.x) {
|
||||||
const int seq_len_this_time = seq_lens_this_time[i];
|
const int seq_len_this_time = seq_lens_this_time[i];
|
||||||
const int seq_len_decoder = seq_lens_decoder[i];
|
|
||||||
max_len_this_time_this_thread =
|
max_len_this_time_this_thread =
|
||||||
max(seq_len_this_time, max_len_this_time_this_thread);
|
max(seq_len_this_time, max_len_this_time_this_thread);
|
||||||
max_len_encoder_this_thread =
|
max_len_encoder_this_thread =
|
||||||
max(seq_lens_encoder[i], max_len_encoder_this_thread);
|
max(seq_lens_encoder[i], max_len_encoder_this_thread);
|
||||||
max_len_decoder_this_thread = max(seq_len_decoder, max_len_decoder_this_thread);
|
max_len_decoder_this_thread = max(seq_lens[i], max_len_decoder_this_thread);
|
||||||
if (seq_len_this_time <= 0)
|
if (seq_len_this_time <= 0)
|
||||||
continue;
|
continue;
|
||||||
const int max_just_dec_len_now = seq_lens_encoder[i] > 0 ? 0 : seq_len_decoder;
|
const int max_just_dec_len_now = seq_lens_encoder[i] > 0 ? 0 : seq_lens[i];
|
||||||
max_len_this_thread =
|
max_len_this_thread =
|
||||||
max(seq_len_decoder + seq_len_this_time, max_len_this_thread);
|
max(seq_lens[i] + seq_len_this_time, max_len_this_thread);
|
||||||
max_just_dec_len_this_thread =
|
max_just_dec_len_this_thread =
|
||||||
max(max_just_dec_len_this_thread, max_just_dec_len_now);
|
max(max_just_dec_len_this_thread, max_just_dec_len_now);
|
||||||
|
if (system_lens) {
|
||||||
if (seq_len_decoder == 0)
|
const int real_bid = seq_mapping[i];
|
||||||
continue;
|
const int system_len_now = system_lens[real_bid];
|
||||||
max_len_kv_this_thread =
|
max_system_len_this_thread =
|
||||||
max(seq_len_this_time + seq_len_decoder, max_len_kv_this_thread);
|
max(max_system_len_this_thread, system_len_now);
|
||||||
|
max_dec_len_without_system_this_thread =
|
||||||
|
max(max_dec_len_without_system_this_thread,
|
||||||
|
max_just_dec_len_now - system_len_now);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if (system_lens) {
|
||||||
|
for (int i = tid; i < batch_size; i += blockDim.x) {
|
||||||
|
const int ori_seq_len_this_time = seq_lens_this_time_merged[i];
|
||||||
|
if (ori_seq_len_this_time <= 0)
|
||||||
|
continue;
|
||||||
|
const int max_just_dec_merged_len_this_time_now =
|
||||||
|
seq_lens_encoder_merged[i] > 0 ? 0 : ori_seq_len_this_time;
|
||||||
|
max_just_dec_merged_len_this_time_this_thread =
|
||||||
|
max(max_just_dec_merged_len_this_time_this_thread,
|
||||||
|
max_just_dec_merged_len_this_time_now);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
int total_max_len_this_time =
|
int total_max_len_this_time =
|
||||||
BlockReduce(temp_storage)
|
BlockReduce(temp_storage)
|
||||||
@@ -82,8 +93,6 @@ GetMaxLenKernel(const int *seq_lens_decoder, const int *seq_lens_this_time,
|
|||||||
int total_dec_len_without_system =
|
int total_dec_len_without_system =
|
||||||
BlockReduce(temp_storage)
|
BlockReduce(temp_storage)
|
||||||
.Reduce(max_dec_len_without_system_this_thread, MaxOp<int>());
|
.Reduce(max_dec_len_without_system_this_thread, MaxOp<int>());
|
||||||
int total_max_len_kv =
|
|
||||||
BlockReduce(temp_storage).Reduce(max_len_kv_this_thread, MaxOp<int>());
|
|
||||||
if (tid == 0) {
|
if (tid == 0) {
|
||||||
max_lens[0] = total_max_len_this_time;
|
max_lens[0] = total_max_len_this_time;
|
||||||
max_lens[1] = total_max_len_encoder;
|
max_lens[1] = total_max_len_encoder;
|
||||||
@@ -93,7 +102,6 @@ GetMaxLenKernel(const int *seq_lens_decoder, const int *seq_lens_this_time,
|
|||||||
max_lens[5] = total_just_dec_merged;
|
max_lens[5] = total_just_dec_merged;
|
||||||
max_lens[6] = total_system_len;
|
max_lens[6] = total_system_len;
|
||||||
max_lens[7] = total_dec_len_without_system;
|
max_lens[7] = total_dec_len_without_system;
|
||||||
max_lens[8] = total_max_len_kv;
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -108,146 +116,29 @@ void GetMaxLen(const paddle::Tensor &seq_lens_tensor,
|
|||||||
max_len_tensor.data<int>(), batch_size);
|
max_len_tensor.data<int>(), batch_size);
|
||||||
}
|
}
|
||||||
|
|
||||||
template <uint32_t config_size>
|
|
||||||
__global__ void search_chunk_size_for_mla(
|
|
||||||
const int *__restrict__ seq_lens_q,
|
|
||||||
const int *__restrict__ seq_lens_encoder,
|
|
||||||
const int *__restrict__ seq_lens_decoder,
|
|
||||||
int *__restrict__ num_blocks_x,
|
|
||||||
int *__restrict__ res_chunk_size,
|
|
||||||
const int bsz,
|
|
||||||
const int set_chunk_size,
|
|
||||||
const int block_size,
|
|
||||||
const int sm_cout) {
|
|
||||||
const uint32_t conf_id = threadIdx.x;
|
|
||||||
int gridx = 0;
|
|
||||||
if (set_chunk_size > 0 && conf_id == 0) {
|
|
||||||
for (uint32_t bid = 0; bid < bsz; bid++) {
|
|
||||||
int seq_len = seq_lens_q[bid];
|
|
||||||
int seq_len_encoder = seq_lens_encoder[bid];
|
|
||||||
int seq_len_decoder = seq_lens_decoder[bid] + seq_len;
|
|
||||||
if (seq_len == 0 || seq_len_encoder > 0) continue;
|
|
||||||
|
|
||||||
int loop_times;
|
|
||||||
loop_times = cute::ceil_div(seq_len_decoder, set_chunk_size);
|
|
||||||
gridx += loop_times;
|
|
||||||
}
|
|
||||||
*num_blocks_x = gridx;
|
|
||||||
*res_chunk_size = set_chunk_size;
|
|
||||||
} else if (conf_id < config_size) {
|
|
||||||
__shared__ int gridx_shared[config_size];
|
|
||||||
// chunk_size is a multiple of 64
|
|
||||||
const int chunk_size = block_size << conf_id;
|
|
||||||
for (uint32_t bid = 0; bid < bsz; bid++) {
|
|
||||||
int seq_len = seq_lens_q[bid];
|
|
||||||
int seq_len_encoder = seq_lens_encoder[bid];
|
|
||||||
int seq_len_decoder = seq_lens_decoder[bid] + seq_len;
|
|
||||||
if (seq_len == 0 || seq_len_encoder > 0) continue;
|
|
||||||
|
|
||||||
int loop_times;
|
|
||||||
loop_times = cute::ceil_div(seq_len_decoder, chunk_size);
|
|
||||||
gridx += loop_times;
|
|
||||||
}
|
|
||||||
gridx_shared[conf_id] = gridx;
|
|
||||||
__syncthreads();
|
|
||||||
if (threadIdx.x == 0) {
|
|
||||||
uint32_t res_id = 0;
|
|
||||||
uint32_t max_last_wave_block = 0;
|
|
||||||
for (uint32_t i = 1; i < config_size; i++) {
|
|
||||||
uint32_t last_wave_block = gridx_shared[i] % sm_cout;
|
|
||||||
if (last_wave_block >= max_last_wave_block) {
|
|
||||||
res_id = i;
|
|
||||||
max_last_wave_block = last_wave_block;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
*num_blocks_x = gridx_shared[res_id];
|
|
||||||
*res_chunk_size = block_size << res_id;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
__global__ void split_block_for_mla(const int *__restrict__ seq_lens_q,
|
|
||||||
const int *__restrict__ seq_lens_encoder,
|
|
||||||
const int *__restrict__ seq_lens_decoder,
|
|
||||||
int *__restrict__ batch_ids,
|
|
||||||
int *__restrict__ tile_ids_per_batch,
|
|
||||||
const int bsz,
|
|
||||||
const int chunk_size) {
|
|
||||||
if (threadIdx.x == 0) {
|
|
||||||
int index = 0;
|
|
||||||
for (uint32_t bid = 0; bid < bsz; bid++) {
|
|
||||||
int seq_len = seq_lens_q[bid];
|
|
||||||
int seq_len_encoder = seq_lens_encoder[bid];
|
|
||||||
int seq_len_decoder = seq_lens_decoder[bid] + seq_len;
|
|
||||||
|
|
||||||
if (seq_len == 0) continue;
|
|
||||||
|
|
||||||
int loop_times;
|
|
||||||
loop_times = cute::ceil_div(seq_len_decoder, chunk_size);
|
|
||||||
if (seq_len_encoder > 0) {
|
|
||||||
loop_times = 0;
|
|
||||||
}
|
|
||||||
for (uint32_t tile_id = 0; tile_id < loop_times; tile_id++) {
|
|
||||||
batch_ids[index] = bid;
|
|
||||||
tile_ids_per_batch[index++] = tile_id;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
__global__ void split_q_block(const int *__restrict__ seq_lens_q,
|
__global__ void split_q_block(const int *__restrict__ seq_lens_q,
|
||||||
const int *__restrict__ seq_lens_encoder,
|
const int *__restrict__ seq_lens_encoder,
|
||||||
int *__restrict__ batch_ids,
|
int *__restrict__ batch_ids,
|
||||||
int *__restrict__ tile_ids_per_batch,
|
int *__restrict__ tile_ids_per_batch,
|
||||||
int *__restrict__ num_blocks_x,
|
int *__restrict__ num_blocks_x, const int bsz,
|
||||||
const int bsz,
|
|
||||||
const int num_rows_per_block,
|
const int num_rows_per_block,
|
||||||
const int group_size) {
|
const int group_size) {
|
||||||
// one block one warp
|
if (threadIdx.x == 0) {
|
||||||
const int lane_id = threadIdx.x % warpSize;
|
int gridx = 0;
|
||||||
int prev_offset = 0;
|
int index = 0;
|
||||||
|
for (uint32_t bid = 0; bid < bsz; bid++) {
|
||||||
// loop on warp tile:[base, base+32)
|
|
||||||
for (int base = 0; base < bsz; base += warpSize) {
|
|
||||||
const int bid = base + lane_id;
|
|
||||||
|
|
||||||
// calculate loop_times for bid
|
|
||||||
int loop_times = 0;
|
|
||||||
if (bid < bsz) {
|
|
||||||
int seq_len = seq_lens_q[bid];
|
int seq_len = seq_lens_q[bid];
|
||||||
if (seq_lens_encoder && seq_lens_encoder[bid] > 0) {
|
if (seq_lens_encoder && seq_lens_encoder[bid] > 0) {
|
||||||
seq_len = 0;
|
seq_len = 0;
|
||||||
}
|
}
|
||||||
loop_times = div_up(seq_len * group_size, num_rows_per_block);
|
const int loop_times = div_up(seq_len * group_size, num_rows_per_block);
|
||||||
}
|
for (uint32_t tile_id = 0; tile_id < loop_times; tile_id++) {
|
||||||
|
batch_ids[index] = bid;
|
||||||
// prefix sum for each lane, get the start offset in this tile
|
tile_ids_per_batch[index++] = tile_id;
|
||||||
// inclusive scan
|
|
||||||
int x = loop_times;
|
|
||||||
for (int offset = 1; offset < warpSize; offset <<= 1) {
|
|
||||||
int y = __shfl_up_sync(0xffffffff, x, offset);
|
|
||||||
if (lane_id >= offset) x += y;
|
|
||||||
}
|
|
||||||
// exclusive prefix sum
|
|
||||||
int bid_offset = x - loop_times;
|
|
||||||
int tile_sum = __shfl_sync(0xffffffff, x, warpSize - 1);
|
|
||||||
|
|
||||||
// write batch_ids and tile_ids_per_batch
|
|
||||||
if (bid < bsz && loop_times > 0) {
|
|
||||||
int write_base = prev_offset + bid_offset;
|
|
||||||
for (int t = 0; t < loop_times; ++t) {
|
|
||||||
int pos = write_base + t;
|
|
||||||
batch_ids[pos] = bid;
|
|
||||||
tile_ids_per_batch[pos] = t;
|
|
||||||
}
|
}
|
||||||
|
gridx += loop_times;
|
||||||
}
|
}
|
||||||
|
*num_blocks_x = gridx;
|
||||||
// for next warp tile
|
|
||||||
prev_offset += tile_sum;
|
|
||||||
}
|
|
||||||
|
|
||||||
if (threadIdx.x == 0) {
|
|
||||||
*num_blocks_x = prev_offset;
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -277,22 +168,37 @@ __global__ void split_kv_block(const int *__restrict__ seq_lens_decoder,
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
void GetBlockShapeAndSplitKVBlock(
|
template <int THREADBLOCK_SIZE>
|
||||||
|
__global__ void
|
||||||
|
get_max_len_kv_ernel(int *max_seq_lens_out, const int *seq_lens_this_time,
|
||||||
|
const int *seq_lens_decoder, const int batch_size) {
|
||||||
|
const int tid = threadIdx.x;
|
||||||
|
|
||||||
|
typedef cub::BlockReduce<int, THREADBLOCK_SIZE> BlockReduce;
|
||||||
|
__shared__ typename BlockReduce::TempStorage temp_storage;
|
||||||
|
|
||||||
|
int max_len_this_thread = 0;
|
||||||
|
for (int i = tid; i < batch_size; i += blockDim.x) {
|
||||||
|
if (seq_lens_decoder[i] == 0)
|
||||||
|
continue;
|
||||||
|
max_len_this_thread =
|
||||||
|
max(seq_lens_this_time[i] + seq_lens_decoder[i], max_len_this_thread);
|
||||||
|
}
|
||||||
|
int total =
|
||||||
|
BlockReduce(temp_storage).Reduce(max_len_this_thread, MaxOp<int>());
|
||||||
|
if (tid == 0) {
|
||||||
|
*max_seq_lens_out = total;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
std::vector<paddle::Tensor> GetBlockShapeAndSplitKVBlock(
|
||||||
const paddle::Tensor &seq_lens_encoder,
|
const paddle::Tensor &seq_lens_encoder,
|
||||||
const paddle::Tensor &seq_lens_decoder,
|
const paddle::Tensor &seq_lens_decoder,
|
||||||
const paddle::Tensor &seq_lens_this_time,
|
const paddle::Tensor &seq_lens_this_time,
|
||||||
paddle::Tensor &decoder_batch_ids, // Inplace
|
paddle::Tensor &decoder_batch_ids, // Inplace
|
||||||
paddle::Tensor &decoder_tile_ids_per_batch, // Inplace
|
paddle::Tensor &decoder_tile_ids_per_batch, // Inplace
|
||||||
paddle::Tensor &decoder_num_blocks_cpu, // Inplace, Pinned Memory
|
paddle::Tensor &decoder_num_blocks_x_cpu, // Inplace, Pinned Memory
|
||||||
paddle::Tensor &decoder_num_blocks_device, // Inplace
|
paddle::Tensor &max_len_tensor_cpu, // Inplace, Pinned Memory
|
||||||
paddle::Tensor &decoder_chunk_size_device, // Inplace
|
|
||||||
paddle::Tensor &max_len_tensor_cpu, // Inplace, CPU
|
|
||||||
paddle::Tensor &encoder_batch_ids, // Inplace
|
|
||||||
paddle::Tensor &encoder_tile_ids_per_batch, // Inplace
|
|
||||||
paddle::Tensor &encoder_num_blocks_x_cpu, // Inplace, CPU
|
|
||||||
paddle::Tensor &kv_batch_ids, // Inplace
|
|
||||||
paddle::Tensor &kv_tile_ids_per_batch, // Inplace
|
|
||||||
paddle::Tensor &kv_num_blocks_x_cpu, // Inplace, CPU
|
|
||||||
const int encoder_block_shape_q,
|
const int encoder_block_shape_q,
|
||||||
const int decoder_block_shape_q,
|
const int decoder_block_shape_q,
|
||||||
const int group_size,
|
const int group_size,
|
||||||
@@ -316,126 +222,32 @@ void GetBlockShapeAndSplitKVBlock(
|
|||||||
int max_just_dec_merged_len_this_time = max_len_cpu_ptr[5];
|
int max_just_dec_merged_len_this_time = max_len_cpu_ptr[5];
|
||||||
int max_system_len = max_len_cpu_ptr[6];
|
int max_system_len = max_len_cpu_ptr[6];
|
||||||
int max_just_dec_len_without_system = max_len_cpu_ptr[7];
|
int max_just_dec_len_without_system = max_len_cpu_ptr[7];
|
||||||
int max_kv_len_this_time = max_len_cpu_ptr[8];
|
|
||||||
|
|
||||||
// decoder
|
paddle::Tensor encoder_batch_ids;
|
||||||
if (max_dec_len_this_time > 0) {
|
paddle::Tensor encoder_tile_ids_per_batch;
|
||||||
|
paddle::Tensor encoder_num_blocks_x_cpu; /*cpu*/
|
||||||
|
paddle::Tensor kv_batch_ids;
|
||||||
|
paddle::Tensor kv_tile_ids_per_batch;
|
||||||
|
paddle::Tensor kv_num_blocks_x_cpu; /*cpu*/
|
||||||
|
paddle::Tensor max_len_kv_cpu; /*cpu*/
|
||||||
|
|
||||||
const bool mla_backend = checkAttentionBackend();
|
auto max_len_kv =
|
||||||
if (mla_backend && group_size <= 64) {
|
GetEmptyTensor({1}, paddle::DataType::INT32, seq_lens_decoder.place());
|
||||||
const int set_chunk_size = get_mla_dec_chunk_size(bsz);
|
get_max_len_kv_ernel<128><<<1, 128, 0, stream>>>(
|
||||||
|
max_len_kv.data<int>(), seq_lens_this_time.data<int>(),
|
||||||
|
seq_lens_decoder.data<int>(), bsz);
|
||||||
|
|
||||||
PADDLE_ENFORCE_GPU_SUCCESS(cudaMemsetAsync(
|
max_len_kv_cpu = max_len_kv.copy_to(paddle::CPUPlace(), false);
|
||||||
decoder_chunk_size_device.data<int>(), 64, sizeof(int32_t), stream));
|
|
||||||
|
|
||||||
PADDLE_ENFORCE_GPU_SUCCESS(cudaMemsetAsync(
|
|
||||||
decoder_num_blocks_device.data<int>(), 0, sizeof(int32_t), stream));
|
|
||||||
|
|
||||||
int device;
|
|
||||||
cudaGetDevice(&device);
|
|
||||||
int sm_cout;
|
|
||||||
cudaDeviceGetAttribute(&sm_cout, cudaDevAttrMultiProcessorCount, device);
|
|
||||||
constexpr int config_size =
|
|
||||||
12; // search space for chunk size:[64, 128, 256, ... 131072]
|
|
||||||
|
|
||||||
search_chunk_size_for_mla<config_size>
|
|
||||||
<<<1, 32, 0, stream>>>(seq_lens_this_time.data<int>(),
|
|
||||||
seq_lens_encoder.data<int>(),
|
|
||||||
seq_lens_decoder.data<int>(),
|
|
||||||
decoder_num_blocks_device.data<int>(),
|
|
||||||
decoder_chunk_size_device.data<int>(),
|
|
||||||
bsz,
|
|
||||||
set_chunk_size,
|
|
||||||
block_size,
|
|
||||||
sm_cout);
|
|
||||||
|
|
||||||
decoder_num_blocks_cpu.copy_(
|
|
||||||
decoder_num_blocks_device, decoder_num_blocks_cpu.place(), false);
|
|
||||||
auto decoder_chunk_size_cpu =
|
|
||||||
decoder_chunk_size_device.copy_to(paddle::CPUPlace(), false);
|
|
||||||
const int chunk_size = decoder_chunk_size_cpu.data<int>()[0];
|
|
||||||
|
|
||||||
// NOTE: (changwenbin) When using auto_chunk,
|
|
||||||
// decode_max_tile_size must take into account the maximum case, where * 1024 can cover 128K.
|
|
||||||
// const uint32_t decoder_batch_shape = seq_lens_decoder.dims()[0] * 1024;
|
|
||||||
|
|
||||||
const uint32_t decoder_max_tile_size_per_bs_q =
|
|
||||||
div_up((decoder_step_token_num * group_size), decoder_block_shape_q);
|
|
||||||
const uint32_t decoder_batch_shape =
|
|
||||||
bsz * 1024 * decoder_max_tile_size_per_bs_q;
|
|
||||||
|
|
||||||
PADDLE_ENFORCE_GPU_SUCCESS(
|
|
||||||
cudaMemsetAsync(decoder_batch_ids.data<int>(),
|
|
||||||
0,
|
|
||||||
decoder_batch_shape * sizeof(int32_t),
|
|
||||||
stream));
|
|
||||||
PADDLE_ENFORCE_GPU_SUCCESS(
|
|
||||||
cudaMemsetAsync(decoder_tile_ids_per_batch.data<int>(),
|
|
||||||
0,
|
|
||||||
decoder_batch_shape * sizeof(int32_t),
|
|
||||||
stream));
|
|
||||||
|
|
||||||
|
|
||||||
split_block_for_mla<<<1, 32, 0, stream>>>(
|
|
||||||
seq_lens_this_time.data<int>(),
|
|
||||||
seq_lens_encoder.data<int>(),
|
|
||||||
seq_lens_decoder.data<int>(),
|
|
||||||
decoder_batch_ids.data<int>(),
|
|
||||||
decoder_tile_ids_per_batch.data<int>(),
|
|
||||||
bsz,
|
|
||||||
chunk_size);
|
|
||||||
|
|
||||||
} else {
|
|
||||||
// Note:(changwenbin)In order to adapt to cudagraph, the maximum value
|
|
||||||
// should be taken here
|
|
||||||
const uint32_t decoder_max_tile_size_per_bs_q =
|
|
||||||
div_up((decoder_step_token_num * group_size), decoder_block_shape_q);
|
|
||||||
const uint32_t decoder_batch_shape =
|
|
||||||
bsz * 1024 * decoder_max_tile_size_per_bs_q;
|
|
||||||
|
|
||||||
PADDLE_ENFORCE_GPU_SUCCESS(
|
|
||||||
cudaMemsetAsync(decoder_batch_ids.data<int>(),
|
|
||||||
0,
|
|
||||||
decoder_batch_shape * sizeof(int32_t),
|
|
||||||
stream));
|
|
||||||
PADDLE_ENFORCE_GPU_SUCCESS(
|
|
||||||
cudaMemsetAsync(decoder_tile_ids_per_batch.data<int>(),
|
|
||||||
0,
|
|
||||||
decoder_batch_shape * sizeof(int32_t),
|
|
||||||
stream));
|
|
||||||
PADDLE_ENFORCE_GPU_SUCCESS(cudaMemsetAsync(
|
|
||||||
decoder_num_blocks_device.data<int>(), 0, sizeof(int32_t), stream));
|
|
||||||
|
|
||||||
split_q_block<<<1, 32, 0, stream>>>(
|
|
||||||
seq_lens_this_time.data<int>(),
|
|
||||||
seq_lens_encoder.data<int>(),
|
|
||||||
decoder_batch_ids.data<int>(),
|
|
||||||
decoder_tile_ids_per_batch.data<int>(),
|
|
||||||
decoder_num_blocks_device.data<int>(),
|
|
||||||
bsz,
|
|
||||||
decoder_block_shape_q,
|
|
||||||
group_size);
|
|
||||||
|
|
||||||
decoder_num_blocks_cpu.copy_(
|
|
||||||
decoder_num_blocks_device, decoder_num_blocks_cpu.place(), false);
|
|
||||||
PADDLE_ENFORCE_GPU_SUCCESS(cudaMemsetAsync(
|
|
||||||
decoder_chunk_size_device.data<int>(), 64, sizeof(int32_t), stream));
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
PADDLE_ENFORCE_GPU_SUCCESS(cudaMemsetAsync(
|
|
||||||
decoder_chunk_size_device.data<int>(), 64, sizeof(int32_t), stream));
|
|
||||||
PADDLE_ENFORCE_GPU_SUCCESS(cudaMemsetAsync(
|
|
||||||
decoder_num_blocks_device.data<int>(), 0, sizeof(int32_t), stream));
|
|
||||||
decoder_num_blocks_cpu.copy_(
|
|
||||||
decoder_num_blocks_device, decoder_num_blocks_cpu.place(), false);
|
|
||||||
}
|
|
||||||
|
|
||||||
// encoder
|
|
||||||
if (max_enc_len_this_time > 0) {
|
if (max_enc_len_this_time > 0) {
|
||||||
const uint32_t max_tile_size_per_bs_kv = div_up(max_enc_dec_len_this_time, block_size);
|
const uint32_t max_tile_size_per_bs_kv =
|
||||||
const uint32_t kv_batch_shape = bsz * max_tile_size_per_bs_kv;
|
div_up(max_enc_dec_len_this_time, block_size);
|
||||||
PADDLE_ENFORCE_GPU_SUCCESS(cudaMemsetAsync(kv_batch_ids.data<int>(), 0, kv_batch_shape * sizeof(int32_t), stream));
|
kv_batch_ids =
|
||||||
PADDLE_ENFORCE_GPU_SUCCESS(cudaMemsetAsync(kv_tile_ids_per_batch.data<int>(), 0, kv_batch_shape * sizeof(int32_t), stream));
|
GetEmptyTensor({bsz * max_tile_size_per_bs_kv}, paddle::DataType::INT32,
|
||||||
|
seq_lens_encoder.place());
|
||||||
|
kv_tile_ids_per_batch =
|
||||||
|
GetEmptyTensor({bsz * max_tile_size_per_bs_kv}, paddle::DataType::INT32,
|
||||||
|
seq_lens_encoder.place());
|
||||||
auto kv_num_blocks_x =
|
auto kv_num_blocks_x =
|
||||||
GetEmptyTensor({1}, paddle::DataType::INT32, seq_lens_encoder.place());
|
GetEmptyTensor({1}, paddle::DataType::INT32, seq_lens_encoder.place());
|
||||||
|
|
||||||
@@ -446,12 +258,16 @@ void GetBlockShapeAndSplitKVBlock(
|
|||||||
kv_tile_ids_per_batch.data<int>(), kv_num_blocks_x.data<int>(), bsz,
|
kv_tile_ids_per_batch.data<int>(), kv_num_blocks_x.data<int>(), bsz,
|
||||||
block_size, block_size);
|
block_size, block_size);
|
||||||
|
|
||||||
kv_num_blocks_x_cpu.copy_(kv_num_blocks_x, kv_num_blocks_x_cpu.place(), false);
|
kv_num_blocks_x_cpu = kv_num_blocks_x.copy_to(paddle::CPUPlace(), false);
|
||||||
// Clear buffer
|
|
||||||
const uint32_t encoder_max_tile_size_per_bs_q = div_up((max_enc_dec_len_this_time * group_size), encoder_block_shape_q);
|
const uint32_t encoder_max_tile_size_per_bs_q =
|
||||||
const uint32_t encoder_batch_shape = bsz * encoder_max_tile_size_per_bs_q;
|
div_up((max_enc_dec_len_this_time * group_size), encoder_block_shape_q);
|
||||||
PADDLE_ENFORCE_GPU_SUCCESS(cudaMemsetAsync(encoder_batch_ids.data<int>(), 0, encoder_batch_shape * sizeof(int32_t), stream));
|
encoder_batch_ids =
|
||||||
PADDLE_ENFORCE_GPU_SUCCESS(cudaMemsetAsync(encoder_tile_ids_per_batch.data<int>(), 0, encoder_batch_shape * sizeof(int32_t), stream));
|
GetEmptyTensor({bsz * encoder_max_tile_size_per_bs_q},
|
||||||
|
paddle::DataType::INT32, seq_lens_encoder.place());
|
||||||
|
encoder_tile_ids_per_batch =
|
||||||
|
GetEmptyTensor({bsz * encoder_max_tile_size_per_bs_q},
|
||||||
|
paddle::DataType::INT32, seq_lens_encoder.place());
|
||||||
auto encoder_num_blocks_x =
|
auto encoder_num_blocks_x =
|
||||||
GetEmptyTensor({1}, paddle::DataType::INT32, seq_lens_encoder.place());
|
GetEmptyTensor({1}, paddle::DataType::INT32, seq_lens_encoder.place());
|
||||||
split_q_block<<<1, 32, 0, stream>>>(seq_lens_encoder.data<int>(), nullptr,
|
split_q_block<<<1, 32, 0, stream>>>(seq_lens_encoder.data<int>(), nullptr,
|
||||||
@@ -459,35 +275,54 @@ void GetBlockShapeAndSplitKVBlock(
|
|||||||
encoder_tile_ids_per_batch.data<int>(),
|
encoder_tile_ids_per_batch.data<int>(),
|
||||||
encoder_num_blocks_x.data<int>(), bsz,
|
encoder_num_blocks_x.data<int>(), bsz,
|
||||||
encoder_block_shape_q, group_size);
|
encoder_block_shape_q, group_size);
|
||||||
encoder_num_blocks_x_cpu.copy_(encoder_num_blocks_x, encoder_num_blocks_x_cpu.place(), false);
|
encoder_num_blocks_x_cpu =
|
||||||
|
encoder_num_blocks_x.copy_to(paddle::CPUPlace(), false);
|
||||||
|
} else {
|
||||||
|
encoder_batch_ids =
|
||||||
|
GetEmptyTensor({0}, paddle::DataType::INT32, seq_lens_encoder.place());
|
||||||
|
encoder_tile_ids_per_batch =
|
||||||
|
GetEmptyTensor({0}, paddle::DataType::INT32, seq_lens_encoder.place());
|
||||||
|
encoder_num_blocks_x_cpu =
|
||||||
|
GetEmptyTensor({0}, paddle::DataType::INT32, paddle::CPUPlace());
|
||||||
|
kv_batch_ids =
|
||||||
|
GetEmptyTensor({0}, paddle::DataType::INT32, seq_lens_encoder.place());
|
||||||
|
kv_tile_ids_per_batch =
|
||||||
|
GetEmptyTensor({0}, paddle::DataType::INT32, seq_lens_encoder.place());
|
||||||
|
kv_num_blocks_x_cpu =
|
||||||
|
GetEmptyTensor({0}, paddle::DataType::INT32, paddle::CPUPlace());
|
||||||
}
|
}
|
||||||
|
|
||||||
}
|
if (max_just_dec_len_this_time > 0) {
|
||||||
|
// Clear buffer
|
||||||
|
const uint32_t decoder_max_tile_size_per_bs_q = div_up((decoder_step_token_num * group_size), decoder_block_shape_q);
|
||||||
|
const uint32_t decoder_batch_shape = bsz * decoder_max_tile_size_per_bs_q;
|
||||||
|
PADDLE_ENFORCE_GPU_SUCCESS(cudaMemsetAsync(decoder_batch_ids.data<int>(), 0, decoder_batch_shape * sizeof(int32_t), stream));
|
||||||
|
PADDLE_ENFORCE_GPU_SUCCESS(cudaMemsetAsync(decoder_tile_ids_per_batch.data<int>(), 0, decoder_batch_shape * sizeof(int32_t), stream));
|
||||||
|
PADDLE_ENFORCE_GPU_SUCCESS(cudaMemsetAsync(decoder_num_blocks_x_cpu.data<int>(), 0, sizeof(int32_t), stream));
|
||||||
|
|
||||||
std::vector<std::vector<int64_t>> GetBlockShapeAndSplitKVBlockInferShape(
|
auto decoder_num_blocks_x =
|
||||||
const std::vector<int64_t> &seq_lens_encoder,
|
GetEmptyTensor({1}, paddle::DataType::INT32, seq_lens_encoder.place());
|
||||||
const std::vector<int64_t> &seq_lens_decoder,
|
split_q_block<<<1, 32, 0, stream>>>(
|
||||||
const std::vector<int64_t> &seq_lens_this_time,
|
seq_lens_this_time.data<int>(),
|
||||||
const int encoder_block_shape_q,
|
seq_lens_encoder.data<int>(),
|
||||||
const int decoder_block_shape_q,
|
decoder_batch_ids.data<int>(),
|
||||||
const int group_size,
|
decoder_tile_ids_per_batch.data<int>(),
|
||||||
const int block_size,
|
decoder_num_blocks_x.data<int>(),
|
||||||
const int decoder_step_token_num
|
bsz,
|
||||||
) {
|
decoder_block_shape_q,
|
||||||
return {};
|
group_size);
|
||||||
}
|
decoder_num_blocks_x_cpu.copy_(decoder_num_blocks_x, decoder_num_blocks_x_cpu.place(), false);
|
||||||
|
}
|
||||||
|
|
||||||
std::vector<paddle::DataType> GetBlockShapeAndSplitKVBlockInferDtype(
|
return {
|
||||||
const paddle::DataType &seq_lens_encoder,
|
encoder_batch_ids,
|
||||||
const paddle::DataType &seq_lens_decoder,
|
encoder_tile_ids_per_batch,
|
||||||
const paddle::DataType &seq_lens_this_time,
|
encoder_num_blocks_x_cpu, /*cpu*/
|
||||||
const int encoder_block_shape_q,
|
kv_batch_ids,
|
||||||
const int decoder_block_shape_q,
|
kv_tile_ids_per_batch,
|
||||||
const int group_size,
|
kv_num_blocks_x_cpu, /*cpu*/
|
||||||
const int block_size,
|
max_len_kv_cpu, /*cpu*/
|
||||||
const int decoder_step_token_num
|
};
|
||||||
) {
|
|
||||||
return {};
|
|
||||||
}
|
}
|
||||||
|
|
||||||
PD_BUILD_STATIC_OP(get_block_shape_and_split_kv_block)
|
PD_BUILD_STATIC_OP(get_block_shape_and_split_kv_block)
|
||||||
@@ -497,19 +332,17 @@ PD_BUILD_STATIC_OP(get_block_shape_and_split_kv_block)
|
|||||||
"seq_lens_this_time",
|
"seq_lens_this_time",
|
||||||
"decoder_batch_ids",
|
"decoder_batch_ids",
|
||||||
"decoder_tile_ids_per_batch",
|
"decoder_tile_ids_per_batch",
|
||||||
"decoder_num_blocks_cpu",
|
"decoder_num_blocks_x_cpu",
|
||||||
"decoder_num_blocks_device",
|
"max_len_tensor_cpu"
|
||||||
"decoder_chunk_size_device",
|
|
||||||
"max_len_tensor_cpu",
|
|
||||||
"encoder_batch_ids",
|
|
||||||
"encoder_tile_ids_per_batch",
|
|
||||||
"encoder_num_blocks_x_cpu",
|
|
||||||
"kv_batch_ids",
|
|
||||||
"kv_tile_ids_per_batch",
|
|
||||||
"kv_num_blocks_x_cpu",
|
|
||||||
})
|
})
|
||||||
.Outputs({
|
.Outputs({
|
||||||
|
paddle::Optional("encoder_batch_ids"),
|
||||||
|
paddle::Optional("encoder_tile_ids_per_batch"),
|
||||||
|
paddle::Optional("encoder_num_blocks_x_cpu"),
|
||||||
|
paddle::Optional("kv_batch_ids"),
|
||||||
|
paddle::Optional("kv_tile_ids_per_batch"),
|
||||||
|
paddle::Optional("kv_num_blocks_x_cpu"),
|
||||||
|
"max_len_kv_cpu"
|
||||||
})
|
})
|
||||||
.Attrs({
|
.Attrs({
|
||||||
"encoder_block_shape_q: int",
|
"encoder_block_shape_q: int",
|
||||||
@@ -518,6 +351,4 @@ PD_BUILD_STATIC_OP(get_block_shape_and_split_kv_block)
|
|||||||
"block_size: int",
|
"block_size: int",
|
||||||
"decoder_step_token_num: int"
|
"decoder_step_token_num: int"
|
||||||
})
|
})
|
||||||
.SetKernelFn(PD_KERNEL(GetBlockShapeAndSplitKVBlock))
|
.SetKernelFn(PD_KERNEL(GetBlockShapeAndSplitKVBlock));
|
||||||
.SetInferShapeFn(PD_INFER_SHAPE(GetBlockShapeAndSplitKVBlockInferShape))
|
|
||||||
.SetInferDtypeFn(PD_INFER_DTYPE(GetBlockShapeAndSplitKVBlockInferDtype));
|
|
||||||
|
|||||||
@@ -217,7 +217,7 @@ __global__ void append_cache_kv_c16(
|
|||||||
|
|
||||||
// load k_smem 64 rows 128 cols
|
// load k_smem 64 rows 128 cols
|
||||||
for (int fz = 0; fz < 4; fz++) { // 4 rows pre warp once, 16 rows all 4 warps once, need 4 iter
|
for (int fz = 0; fz < 4; fz++) { // 4 rows pre warp once, 16 rows all 4 warps once, need 4 iter
|
||||||
for (int fy = 0; fy < 2; fy++) { // 8 * 128b = 64 * bf16 once, need 2 iter
|
for (int fy = 0; fy < 2; fy++) { // 8 * 128b = 64 * bf16 noce, need 2 iter
|
||||||
k_smem.load_128b_async<SharedMemFillMode::kNoFill>(
|
k_smem.load_128b_async<SharedMemFillMode::kNoFill>(
|
||||||
k_smem_offset_w, cur_cache_k + k_read_idx, end_idx > 0);
|
k_smem_offset_w, cur_cache_k + k_read_idx, end_idx > 0);
|
||||||
k_smem_offset_w =
|
k_smem_offset_w =
|
||||||
@@ -235,7 +235,7 @@ __global__ void append_cache_kv_c16(
|
|||||||
// deal k_smem 64 rows 128 cols
|
// deal k_smem 64 rows 128 cols
|
||||||
for (int fz = 0; fz < 1; fz++) { // 16 rows pre warp once, 64 rows all 4 warps once, need 1 iter
|
for (int fz = 0; fz < 1; fz++) { // 16 rows pre warp once, 64 rows all 4 warps once, need 1 iter
|
||||||
uint32_t row_idx = wid * 16 + tid / 4;
|
uint32_t row_idx = wid * 16 + tid / 4;
|
||||||
for (int fy = 0; fy < 8; fy++) { // 2 * 128b = 16 * bf16 once, need 8 iter
|
for (int fy = 0; fy < 8; fy++) { // 2 * 128b = 16 * bf16 noce, need 8 iter
|
||||||
uint32_t col_idx = fy * 16 + tid % 4 * 2;
|
uint32_t col_idx = fy * 16 + tid % 4 * 2;
|
||||||
k_smem.ldmatrix_m8n8x4(k_smem_offset_r, kv_frag);
|
k_smem.ldmatrix_m8n8x4(k_smem_offset_r, kv_frag);
|
||||||
// layout
|
// layout
|
||||||
@@ -278,7 +278,7 @@ __global__ void append_cache_kv_c16(
|
|||||||
|
|
||||||
// load v_smem 64 rows 128 cols
|
// load v_smem 64 rows 128 cols
|
||||||
for (int fz = 0; fz < 4; fz++) { // // 4 rows pre warp once, 16 rows all 4 warps once, need 4 iter
|
for (int fz = 0; fz < 4; fz++) { // // 4 rows pre warp once, 16 rows all 4 warps once, need 4 iter
|
||||||
for (int fy = 0; fy < 2; fy++) { // 8 * 128b = 64 * bf16 once, need 2 iter
|
for (int fy = 0; fy < 2; fy++) { // 8 * 128b = 64 * bf16 noce, need 2 iter
|
||||||
v_smem.load_128b_async<SharedMemFillMode::kNoFill>(
|
v_smem.load_128b_async<SharedMemFillMode::kNoFill>(
|
||||||
v_smem_offset_w, cur_cache_v + v_read_idx, end_idx > 0);
|
v_smem_offset_w, cur_cache_v + v_read_idx, end_idx > 0);
|
||||||
v_smem_offset_w =
|
v_smem_offset_w =
|
||||||
@@ -296,7 +296,7 @@ __global__ void append_cache_kv_c16(
|
|||||||
// deal v_smem 64 rows 128 cols
|
// deal v_smem 64 rows 128 cols
|
||||||
for (int fz = 0; fz < 1; fz++) { // 16 rows pre warp once, 64 rows all 4 warps once, need 1 iter
|
for (int fz = 0; fz < 1; fz++) { // 16 rows pre warp once, 64 rows all 4 warps once, need 1 iter
|
||||||
uint32_t row_idx = wid * 16 + tid / 4;
|
uint32_t row_idx = wid * 16 + tid / 4;
|
||||||
for (int fy = 0; fy < 8; fy++) { // 2 * 128b = 16 * bf16 once, need 8 iter
|
for (int fy = 0; fy < 8; fy++) { // 2 * 128b = 16 * bf16 noce, need 8 iter
|
||||||
uint32_t col_idx = fy * 16 + tid % 4 * 2;
|
uint32_t col_idx = fy * 16 + tid % 4 * 2;
|
||||||
v_smem.ldmatrix_m8n8x4(v_smem_offset_r, kv_frag);
|
v_smem.ldmatrix_m8n8x4(v_smem_offset_r, kv_frag);
|
||||||
// layout
|
// layout
|
||||||
@@ -400,7 +400,7 @@ __global__ void append_cache_kv_c8(
|
|||||||
|
|
||||||
// load v_smem 64 rows, 128 cols
|
// load v_smem 64 rows, 128 cols
|
||||||
for (int fz = 0; fz < 4; fz++) { // 4 rows pre warp once, 16 rows all 4 warps once, need 4 iter
|
for (int fz = 0; fz < 4; fz++) { // 4 rows pre warp once, 16 rows all 4 warps once, need 4 iter
|
||||||
for (int fy = 0; fy < 1; fy++) { // 8 * 128b = 128 * uint8 once, need 1 iter
|
for (int fy = 0; fy < 1; fy++) { // 8 * 128b = 128 * uint8 noce, need 1 iter
|
||||||
k_smem.load_128b_async<SharedMemFillMode::kNoFill>(
|
k_smem.load_128b_async<SharedMemFillMode::kNoFill>(
|
||||||
k_smem_offset_w, cur_cache_k + k_read_idx, end_idx > 0);
|
k_smem_offset_w, cur_cache_k + k_read_idx, end_idx > 0);
|
||||||
k_smem_offset_w =
|
k_smem_offset_w =
|
||||||
@@ -418,7 +418,7 @@ __global__ void append_cache_kv_c8(
|
|||||||
// deal k_smem 64 rows, 128 cols
|
// deal k_smem 64 rows, 128 cols
|
||||||
for (int fz = 0; fz < 1; fz++) { // 16 rows pre warp once, 64 rows all 4 warps once, need 1 iter
|
for (int fz = 0; fz < 1; fz++) { // 16 rows pre warp once, 64 rows all 4 warps once, need 1 iter
|
||||||
uint32_t row_idx = wid * 16 + tid / 4;
|
uint32_t row_idx = wid * 16 + tid / 4;
|
||||||
for (int fy = 0; fy < 4; fy++) { // 2 * 128b = 32 * uint8 once, need 4 iter
|
for (int fy = 0; fy < 4; fy++) { // 2 * 128b = 32 * uint8 noce, need 4 iter
|
||||||
uint32_t col_idx = fy * 32 + tid % 4 * 2;
|
uint32_t col_idx = fy * 32 + tid % 4 * 2;
|
||||||
k_smem.ldmatrix_m8n8x4(k_smem_offset_r, k_frag);
|
k_smem.ldmatrix_m8n8x4(k_smem_offset_r, k_frag);
|
||||||
// layout
|
// layout
|
||||||
@@ -466,7 +466,7 @@ __global__ void append_cache_kv_c8(
|
|||||||
tid % 4 * num_elems_per_128b<CacheT>();
|
tid % 4 * num_elems_per_128b<CacheT>();
|
||||||
// load v_smem 128 rows 64 cols
|
// load v_smem 128 rows 64 cols
|
||||||
for (int fy = 0; fy < 4; fy++) { // 8 rows pre warp once, 32 rows all 4 warps once, need 4 iter
|
for (int fy = 0; fy < 4; fy++) { // 8 rows pre warp once, 32 rows all 4 warps once, need 4 iter
|
||||||
for (int fz = 0; fz < 1; fz++) { // 4 * 128b = 64 * uint8 once, need 1 iter
|
for (int fz = 0; fz < 1; fz++) { // 4 * 128b = 64 * uint8 noce, need 1 iter
|
||||||
v_smem.load_128b_async<SharedMemFillMode::kNoFill>(
|
v_smem.load_128b_async<SharedMemFillMode::kNoFill>(
|
||||||
v_smem_offset_w, cur_cache_v + v_read_idx, end_idx > 0);
|
v_smem_offset_w, cur_cache_v + v_read_idx, end_idx > 0);
|
||||||
v_smem_offset_w =
|
v_smem_offset_w =
|
||||||
@@ -485,7 +485,7 @@ __global__ void append_cache_kv_c8(
|
|||||||
// deal v_smem 128 rows 64 cols
|
// deal v_smem 128 rows 64 cols
|
||||||
for (int fy = 0; fy < 2; fy++) { // 16 rows pre warp once, 64 rows all 4 warps once, need 2 iter
|
for (int fy = 0; fy < 2; fy++) { // 16 rows pre warp once, 64 rows all 4 warps once, need 2 iter
|
||||||
uint32_t dim_idx = fy * NUM_WARPS * 16 + wid * 16 + tid / 4;
|
uint32_t dim_idx = fy * NUM_WARPS * 16 + wid * 16 + tid / 4;
|
||||||
for (int fz = 0; fz < 2; fz++) { // 2 * 128b = 32 * uint8 once, need 2 iter
|
for (int fz = 0; fz < 2; fz++) { // 2 * 128b = 32 * uint8 noce, need 2 iter
|
||||||
uint32_t kv_idx = fz * 32 + tid % 4 * 2;
|
uint32_t kv_idx = fz * 32 + tid % 4 * 2;
|
||||||
v_smem.ldmatrix_m8n8x4(v_smem_offset_r, v_frag);
|
v_smem.ldmatrix_m8n8x4(v_smem_offset_r, v_frag);
|
||||||
// layout
|
// layout
|
||||||
@@ -614,7 +614,7 @@ __global__ void append_cache_kv_c4(
|
|||||||
|
|
||||||
// load k_smem 64 rows 128 cols
|
// load k_smem 64 rows 128 cols
|
||||||
for (int fz = 0; fz < 2; fz++) { // 4 rows pre warp once, 16 rows all 4 warps once, need 4 iter
|
for (int fz = 0; fz < 2; fz++) { // 4 rows pre warp once, 16 rows all 4 warps once, need 4 iter
|
||||||
for (int fy = 0; fy < 1; fy++) { // 4 * 128b = 128 * int4 once, need 1 iter
|
for (int fy = 0; fy < 1; fy++) { // 4 * 128b = 128 * int4 noce, need 1 iter
|
||||||
k_smem.load_128b_async<SharedMemFillMode::kNoFill>(
|
k_smem.load_128b_async<SharedMemFillMode::kNoFill>(
|
||||||
k_smem_offset_w, cur_cache_k + k_read_idx, end_idx > 0);
|
k_smem_offset_w, cur_cache_k + k_read_idx, end_idx > 0);
|
||||||
k_smem_offset_w =
|
k_smem_offset_w =
|
||||||
@@ -632,7 +632,7 @@ __global__ void append_cache_kv_c4(
|
|||||||
// deal k_smem 64 rows 128 cols
|
// deal k_smem 64 rows 128 cols
|
||||||
for (int fz = 0; fz < 1; fz++) { // 16 rows pre warp once, 64 rows all 4 warps once, need 1 iter
|
for (int fz = 0; fz < 1; fz++) { // 16 rows pre warp once, 64 rows all 4 warps once, need 1 iter
|
||||||
uint32_t row_idx = wid * 16 + tid / 4;
|
uint32_t row_idx = wid * 16 + tid / 4;
|
||||||
for (int fy = 0; fy < 2; fy++) { // 2 * 128b = 64 * int4 once, need 2 iter
|
for (int fy = 0; fy < 2; fy++) { // 2 * 128b = 64 * int4 noce, need 2 iter
|
||||||
uint32_t col_idx = fy * 64 + tid % 4 * 2;
|
uint32_t col_idx = fy * 64 + tid % 4 * 2;
|
||||||
k_smem.ldmatrix_m8n8x4(k_smem_offset_r, k_frag);
|
k_smem.ldmatrix_m8n8x4(k_smem_offset_r, k_frag);
|
||||||
|
|
||||||
@@ -685,7 +685,7 @@ __global__ void append_cache_kv_c4(
|
|||||||
tid % 2 * num_elems_per_128b<CacheT>();
|
tid % 2 * num_elems_per_128b<CacheT>();
|
||||||
// load v_smem 128 rows 64 rows
|
// load v_smem 128 rows 64 rows
|
||||||
for (int fy = 0; fy < 2; fy++) { // 16 rows pre warp once, 64 rows all 4 warps once, need 2 iter
|
for (int fy = 0; fy < 2; fy++) { // 16 rows pre warp once, 64 rows all 4 warps once, need 2 iter
|
||||||
for (int fz = 0; fz < 1; fz++) { // 2 * 128b = 64 * int4 once, need 1 iter
|
for (int fz = 0; fz < 1; fz++) { // 2 * 128b = 64 * int4 noce, need 1 iter
|
||||||
v_smem.load_128b_async<SharedMemFillMode::kNoFill>(
|
v_smem.load_128b_async<SharedMemFillMode::kNoFill>(
|
||||||
v_smem_offset_w, cur_cache_v + v_read_idx, end_idx > 0);
|
v_smem_offset_w, cur_cache_v + v_read_idx, end_idx > 0);
|
||||||
v_smem_offset_w =
|
v_smem_offset_w =
|
||||||
@@ -704,7 +704,7 @@ __global__ void append_cache_kv_c4(
|
|||||||
// deal v_smem 128 rows 64 cols
|
// deal v_smem 128 rows 64 cols
|
||||||
for (int fy = 0; fy < 2; fy++) { // 16 rows pre warp once, 64 rows all 4 warps once, need 2 iter
|
for (int fy = 0; fy < 2; fy++) { // 16 rows pre warp once, 64 rows all 4 warps once, need 2 iter
|
||||||
uint32_t dim_idx = fy * NUM_WARPS * 16 + wid * 16 + tid / 4;
|
uint32_t dim_idx = fy * NUM_WARPS * 16 + wid * 16 + tid / 4;
|
||||||
for (int fz = 0; fz < 1; fz++) { // 2 * 128b = 64 * int4 once, need 1 iter
|
for (int fz = 0; fz < 1; fz++) { // 2 * 128b = 64 * int4 noce, need 1 iter
|
||||||
uint32_t kv_idx = fz * 64 + tid % 4 * 2;
|
uint32_t kv_idx = fz * 64 + tid % 4 * 2;
|
||||||
v_smem.ldmatrix_m8n8x4(v_smem_offset_r, v_frag);
|
v_smem.ldmatrix_m8n8x4(v_smem_offset_r, v_frag);
|
||||||
// layout
|
// layout
|
||||||
@@ -1000,7 +1000,7 @@ std::vector<paddle::Tensor> GQARopeWriteCacheKernel(
|
|||||||
stream,
|
stream,
|
||||||
const_cast<paddle::Tensor*>(&key_cache),
|
const_cast<paddle::Tensor*>(&key_cache),
|
||||||
const_cast<paddle::Tensor*>(&value_cache));
|
const_cast<paddle::Tensor*>(&value_cache));
|
||||||
} else if (cache_quant_type == "cache_int8" || cache_quant_type == "cache_fp8" || cache_quant_type == "block_wise_fp8") {
|
} else if (cache_quant_type == "cache_int8" || cache_quant_type == "cache_fp8") {
|
||||||
CascadeAppendWriteCacheKVC8QKV<data_t, 128, 64>(
|
CascadeAppendWriteCacheKVC8QKV<data_t, 128, 64>(
|
||||||
meta_data,
|
meta_data,
|
||||||
*const_cast<paddle::Tensor*>(&key_cache),
|
*const_cast<paddle::Tensor*>(&key_cache),
|
||||||
@@ -1018,7 +1018,7 @@ std::vector<paddle::Tensor> GQARopeWriteCacheKernel(
|
|||||||
kv_num_blocks_data,
|
kv_num_blocks_data,
|
||||||
max_seq_len,
|
max_seq_len,
|
||||||
false, // is_scale_channel_wise
|
false, // is_scale_channel_wise
|
||||||
cache_quant_type,
|
cache_quant_type == "cache_fp8", // is_fp8
|
||||||
stream,
|
stream,
|
||||||
const_cast<paddle::Tensor*>(&key_cache),
|
const_cast<paddle::Tensor*>(&key_cache),
|
||||||
const_cast<paddle::Tensor*>(&value_cache));
|
const_cast<paddle::Tensor*>(&value_cache));
|
||||||
|
|||||||
@@ -15,7 +15,6 @@
|
|||||||
|
|
||||||
#include <cuda_runtime.h>
|
#include <cuda_runtime.h>
|
||||||
#include <stdint.h>
|
#include <stdint.h>
|
||||||
#include <cooperative_groups/memcpy_async.h>
|
|
||||||
|
|
||||||
enum class SharedMemFillMode { kFillZero, kNoFill };
|
enum class SharedMemFillMode { kFillZero, kNoFill };
|
||||||
|
|
||||||
@@ -43,35 +42,18 @@ __device__ __forceinline__ void ldmatrix_m8n8x4_trans_impl(uint32_t* R,
|
|||||||
}
|
}
|
||||||
|
|
||||||
__device__ __forceinline__ void commit_group() {
|
__device__ __forceinline__ void commit_group() {
|
||||||
#ifdef PADDLE_WITH_CUSTOM_DEVICE_METAX_GPU
|
|
||||||
{}
|
|
||||||
#else
|
|
||||||
asm volatile("cp.async.commit_group;\n" ::);
|
asm volatile("cp.async.commit_group;\n" ::);
|
||||||
#endif
|
|
||||||
}
|
}
|
||||||
|
|
||||||
template <size_t n>
|
template <size_t n>
|
||||||
__device__ __forceinline__ void wait_group() {
|
__device__ __forceinline__ void wait_group() {
|
||||||
#ifdef PADDLE_WITH_CUSTOM_DEVICE_METAX_GPU
|
|
||||||
cooperative_groups::wait(cooperative_groups::this_thread_block());
|
|
||||||
#else
|
|
||||||
asm volatile("cp.async.wait_group %0;\n" ::"n"(n));
|
asm volatile("cp.async.wait_group %0;\n" ::"n"(n));
|
||||||
#endif
|
|
||||||
}
|
}
|
||||||
|
|
||||||
template <PrefetchMode prefetch_mode, typename T>
|
template <PrefetchMode prefetch_mode, typename T>
|
||||||
__device__ __forceinline__ void load_128b(T* smem_ptr, const T* gmem_ptr) {
|
__device__ __forceinline__ void load_128b(T* smem_ptr, const T* gmem_ptr) {
|
||||||
uint32_t smem_int_ptr =
|
uint32_t smem_int_ptr =
|
||||||
static_cast<uint32_t>(__cvta_generic_to_shared(smem_ptr));
|
static_cast<uint32_t>(__cvta_generic_to_shared(smem_ptr));
|
||||||
#ifdef PADDLE_WITH_CUSTOM_DEVICE_METAX_GPU
|
|
||||||
if constexpr (prefetch_mode == PrefetchMode::kPrefetch) {
|
|
||||||
memset(__cvta_shared_to_generic(smem_int_ptr), 0, 16);
|
|
||||||
memcpy(__cvta_shared_to_generic(smem_int_ptr), (void *)gmem_ptr, 16);
|
|
||||||
} else {
|
|
||||||
memset(__cvta_shared_to_generic(smem_int_ptr), 0, 16);
|
|
||||||
memcpy(__cvta_shared_to_generic(smem_int_ptr), (void *)gmem_ptr, 16);
|
|
||||||
}
|
|
||||||
#else
|
|
||||||
if constexpr (prefetch_mode == PrefetchMode::kPrefetch) {
|
if constexpr (prefetch_mode == PrefetchMode::kPrefetch) {
|
||||||
asm volatile(
|
asm volatile(
|
||||||
"cp.async.cg.shared.global.L2::128B [%0], [%1], %2, %3;\n" ::"r"(
|
"cp.async.cg.shared.global.L2::128B [%0], [%1], %2, %3;\n" ::"r"(
|
||||||
@@ -86,7 +68,6 @@ __device__ __forceinline__ void load_128b(T* smem_ptr, const T* gmem_ptr) {
|
|||||||
"n"(16),
|
"n"(16),
|
||||||
"r"(16));
|
"r"(16));
|
||||||
}
|
}
|
||||||
#endif
|
|
||||||
}
|
}
|
||||||
|
|
||||||
template <PrefetchMode prefetch_mode, SharedMemFillMode fill_mode, typename T>
|
template <PrefetchMode prefetch_mode, SharedMemFillMode fill_mode, typename T>
|
||||||
@@ -95,28 +76,6 @@ __device__ __forceinline__ void pred_load_128b(T* smem_ptr,
|
|||||||
bool predicate) {
|
bool predicate) {
|
||||||
uint32_t smem_int_ptr =
|
uint32_t smem_int_ptr =
|
||||||
static_cast<uint32_t>(__cvta_generic_to_shared(smem_ptr));
|
static_cast<uint32_t>(__cvta_generic_to_shared(smem_ptr));
|
||||||
#ifdef PADDLE_WITH_CUSTOM_DEVICE_METAX_GPU
|
|
||||||
if constexpr (fill_mode == SharedMemFillMode::kFillZero) {
|
|
||||||
int src_in_bytes = predicate ? 16 : 0;
|
|
||||||
if constexpr (prefetch_mode == PrefetchMode::kPrefetch) {
|
|
||||||
memset(__cvta_shared_to_generic(smem_int_ptr), 0, 16);
|
|
||||||
memcpy(__cvta_shared_to_generic(smem_int_ptr), (void *)gmem_ptr, src_in_bytes);
|
|
||||||
} else {
|
|
||||||
memset(__cvta_shared_to_generic(smem_int_ptr), 0, 16);
|
|
||||||
memcpy(__cvta_shared_to_generic(smem_int_ptr), (void *)gmem_ptr, src_in_bytes);
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
if constexpr (prefetch_mode == PrefetchMode::kPrefetch) {
|
|
||||||
if (predicate) {
|
|
||||||
memcpy(__cvta_shared_to_generic(smem_int_ptr), (void *)gmem_ptr, 16);
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
if (predicate) {
|
|
||||||
memcpy(__cvta_shared_to_generic(smem_int_ptr), (void *)gmem_ptr, 16);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
#else
|
|
||||||
if constexpr (fill_mode == SharedMemFillMode::kFillZero) {
|
if constexpr (fill_mode == SharedMemFillMode::kFillZero) {
|
||||||
int src_in_bytes = predicate ? 16 : 0;
|
int src_in_bytes = predicate ? 16 : 0;
|
||||||
if constexpr (prefetch_mode == PrefetchMode::kPrefetch) {
|
if constexpr (prefetch_mode == PrefetchMode::kPrefetch) {
|
||||||
@@ -156,7 +115,6 @@ __device__ __forceinline__ void pred_load_128b(T* smem_ptr,
|
|||||||
"n"(16));
|
"n"(16));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
#endif
|
|
||||||
}
|
}
|
||||||
|
|
||||||
template <PrefetchMode prefetch_mode, SharedMemFillMode fill_mode, typename T>
|
template <PrefetchMode prefetch_mode, SharedMemFillMode fill_mode, typename T>
|
||||||
@@ -165,17 +123,6 @@ __device__ __forceinline__ void pred_load_64b(T* smem_ptr,
|
|||||||
bool predicate) {
|
bool predicate) {
|
||||||
uint32_t smem_int_ptr =
|
uint32_t smem_int_ptr =
|
||||||
static_cast<uint32_t>(__cvta_generic_to_shared(smem_ptr));
|
static_cast<uint32_t>(__cvta_generic_to_shared(smem_ptr));
|
||||||
#ifdef PADDLE_WITH_CUSTOM_DEVICE_METAX_GPU
|
|
||||||
if constexpr (fill_mode == SharedMemFillMode::kFillZero) {
|
|
||||||
int src_in_bytes = predicate ? 8 : 0;
|
|
||||||
memset(__cvta_shared_to_generic(smem_int_ptr), 0, 8);
|
|
||||||
memcpy(__cvta_shared_to_generic(smem_int_ptr), (void *)gmem_ptr, src_in_bytes);
|
|
||||||
} else {
|
|
||||||
if (predicate) {
|
|
||||||
memcpy(__cvta_shared_to_generic(smem_int_ptr), (void *)gmem_ptr, 8);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
#else
|
|
||||||
if constexpr (fill_mode == SharedMemFillMode::kFillZero) {
|
if constexpr (fill_mode == SharedMemFillMode::kFillZero) {
|
||||||
int src_in_bytes = predicate ? 8 : 0;
|
int src_in_bytes = predicate ? 8 : 0;
|
||||||
asm volatile(
|
asm volatile(
|
||||||
@@ -194,7 +141,6 @@ __device__ __forceinline__ void pred_load_64b(T* smem_ptr,
|
|||||||
"l"(gmem_ptr),
|
"l"(gmem_ptr),
|
||||||
"n"(8));
|
"n"(8));
|
||||||
}
|
}
|
||||||
#endif
|
|
||||||
}
|
}
|
||||||
|
|
||||||
template <PrefetchMode prefetch_mode, SharedMemFillMode fill_mode, typename T>
|
template <PrefetchMode prefetch_mode, SharedMemFillMode fill_mode, typename T>
|
||||||
@@ -203,17 +149,6 @@ __device__ __forceinline__ void pred_load_32b(T* smem_ptr,
|
|||||||
bool predicate) {
|
bool predicate) {
|
||||||
uint32_t smem_int_ptr =
|
uint32_t smem_int_ptr =
|
||||||
static_cast<uint32_t>(__cvta_generic_to_shared(smem_ptr));
|
static_cast<uint32_t>(__cvta_generic_to_shared(smem_ptr));
|
||||||
#ifdef PADDLE_WITH_CUSTOM_DEVICE_METAX_GPU
|
|
||||||
if constexpr (fill_mode == SharedMemFillMode::kFillZero) {
|
|
||||||
int src_in_bytes = predicate ? 4 : 0;
|
|
||||||
memset(__cvta_shared_to_generic(smem_int_ptr), 0, 4);
|
|
||||||
memcpy(__cvta_shared_to_generic(smem_int_ptr), (void *)gmem_ptr, src_in_bytes);
|
|
||||||
} else {
|
|
||||||
if (predicate) {
|
|
||||||
memcpy(__cvta_shared_to_generic(smem_int_ptr), (void *)gmem_ptr, 4);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
#else
|
|
||||||
if constexpr (fill_mode == SharedMemFillMode::kFillZero) {
|
if constexpr (fill_mode == SharedMemFillMode::kFillZero) {
|
||||||
int src_in_bytes = predicate ? 4 : 0;
|
int src_in_bytes = predicate ? 4 : 0;
|
||||||
asm volatile(
|
asm volatile(
|
||||||
@@ -232,7 +167,6 @@ __device__ __forceinline__ void pred_load_32b(T* smem_ptr,
|
|||||||
"l"(gmem_ptr),
|
"l"(gmem_ptr),
|
||||||
"n"(4));
|
"n"(4));
|
||||||
}
|
}
|
||||||
#endif
|
|
||||||
}
|
}
|
||||||
|
|
||||||
template <size_t num_bits, PrefetchMode prefetch_mode, typename T>
|
template <size_t num_bits, PrefetchMode prefetch_mode, typename T>
|
||||||
|
|||||||
@@ -1,4 +1,4 @@
|
|||||||
// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
|
// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
|
||||||
//
|
//
|
||||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
// you may not use this file except in compliance with the License.
|
// you may not use this file except in compliance with the License.
|
||||||
@@ -12,94 +12,27 @@
|
|||||||
// See the License for the specific language governing permissions and
|
// See the License for the specific language governing permissions and
|
||||||
// limitations under the License.
|
// limitations under the License.
|
||||||
#pragma once
|
#pragma once
|
||||||
|
|
||||||
#include "helper.h"
|
#include "helper.h"
|
||||||
#include "utils.cuh"
|
#include "utils.cuh"
|
||||||
#include "multiquery_decoder_attention_impl.cuh"
|
|
||||||
|
|
||||||
template <typename T>
|
template <typename T>
|
||||||
void DecodeMLAAttentionKernel(
|
void DecodeMLAAttentionKernel(
|
||||||
const AppendAttnMetaData& meta_data,
|
const AppendAttnMetaData& meta_data,
|
||||||
const paddle::Tensor &q, // [token_num, num_heads, head_dim]
|
const paddle::Tensor &q, // [token_num, num_heads, head_dim]
|
||||||
const paddle::Tensor &cache_k,
|
const paddle::Tensor &cache_k,
|
||||||
const paddle::Tensor &cache_v,
|
const paddle::Tensor &cache_v,
|
||||||
const paddle::optional<paddle::Tensor>& attn_mask,
|
const paddle::optional<paddle::Tensor>& attn_mask,
|
||||||
const paddle::optional<paddle::Tensor>& shift_bias,
|
const paddle::optional<paddle::Tensor>& shift_bias,
|
||||||
const paddle::optional<paddle::Tensor>& smooth_weight,
|
const paddle::optional<paddle::Tensor>& smooth_weight,
|
||||||
const paddle::Tensor &seq_lens_q, // q_seq_len is 1
|
const paddle::Tensor &seq_lens_q, // q_seq_len is 1
|
||||||
const paddle::Tensor &seq_lens_kv,
|
const paddle::Tensor &seq_lens_kv,
|
||||||
const paddle::Tensor &batch_id_per_token,
|
const paddle::Tensor &batch_id_per_token,
|
||||||
const paddle::Tensor &cu_seqlens_q,
|
const paddle::Tensor &cu_seqlens_q,
|
||||||
const paddle::Tensor &block_table,
|
const paddle::Tensor &block_table,
|
||||||
int max_seq_len,
|
int max_seq_len,
|
||||||
int max_dec_len,
|
int max_dec_len,
|
||||||
float softmax_scale,
|
float softmax_scale,
|
||||||
float in_scale,
|
float in_scale,
|
||||||
bool causal,
|
bool causal,
|
||||||
cudaStream_t &stream,
|
cudaStream_t &stream,
|
||||||
paddle::Tensor *out) {
|
paddle::Tensor *out);
|
||||||
const auto token_num = meta_data.token_nums;
|
|
||||||
const auto block_size = meta_data.block_size;
|
|
||||||
const auto bsz = meta_data.batch_size;
|
|
||||||
const auto num_heads = meta_data.q_num_heads;
|
|
||||||
const auto group_size = meta_data.q_num_heads / meta_data.kv_num_heads;
|
|
||||||
const auto head_dim_qk = meta_data.head_dims;
|
|
||||||
const auto head_dim_v = meta_data.head_dims_v;
|
|
||||||
const float rope_scale = 0.0;
|
|
||||||
const float rope_theta = 0.0;
|
|
||||||
const uint32_t deal_each_time = get_cascade_attention_deal_each_time();
|
|
||||||
const uint32_t num_stage = get_cascade_attention_num_stages();
|
|
||||||
const uint32_t num_threads = get_cascade_attention_num_threads();
|
|
||||||
|
|
||||||
DISPATCH_CAUSAL(causal, CAUSAL,
|
|
||||||
{DISPATCH_MLA_GROUP_SIZE(group_size, GROUP_SIZE,
|
|
||||||
{DISPATCH_MLA_HEAD_DIM(head_dim_qk, HEAD_DIM_QK,
|
|
||||||
{DISPATCH_MLA_HEAD_DIM(head_dim_v, HEAD_DIM_V,
|
|
||||||
{DISPATCH_BLOCK_SIZE(block_size, BLOCK_SIZE,
|
|
||||||
{DISPATCH_DEAL_EACH_TIME(deal_each_time, DEAL_EACH_TIME,
|
|
||||||
{MultiQueryDecoderAttention<T, GROUP_SIZE, HEAD_DIM_QK, HEAD_DIM_V, BLOCK_SIZE, CAUSAL, 2, 16, DEAL_EACH_TIME>(
|
|
||||||
meta_data, stream, q, cache_k, cache_v, attn_mask, shift_bias, smooth_weight, seq_lens_q, seq_lens_kv, batch_id_per_token, cu_seqlens_q,
|
|
||||||
block_table, max_seq_len, max_dec_len, rope_scale, rope_theta, softmax_scale, in_scale, out);})})})})})});
|
|
||||||
}
|
|
||||||
|
|
||||||
template void DecodeMLAAttentionKernel<paddle::bfloat16>(
|
|
||||||
const AppendAttnMetaData& meta_data,
|
|
||||||
const paddle::Tensor &q, // [token_num, num_heads, head_dim]
|
|
||||||
const paddle::Tensor &cache_k,
|
|
||||||
const paddle::Tensor &cache_v,
|
|
||||||
const paddle::optional<paddle::Tensor>& attn_mask,
|
|
||||||
const paddle::optional<paddle::Tensor>& shift_bias,
|
|
||||||
const paddle::optional<paddle::Tensor>& smooth_weight,
|
|
||||||
const paddle::Tensor &seq_lens_q, // q_seq_len is 1
|
|
||||||
const paddle::Tensor &seq_lens_kv,
|
|
||||||
const paddle::Tensor &batch_id_per_token,
|
|
||||||
const paddle::Tensor &cu_seqlens_q,
|
|
||||||
const paddle::Tensor &block_table,
|
|
||||||
int max_seq_len,
|
|
||||||
int max_dec_len,
|
|
||||||
float softmax_scale,
|
|
||||||
float in_scale,
|
|
||||||
bool causal,
|
|
||||||
cudaStream_t &stream,
|
|
||||||
paddle::Tensor *out);
|
|
||||||
|
|
||||||
template void DecodeMLAAttentionKernel<paddle::float16>(
|
|
||||||
const AppendAttnMetaData& meta_data,
|
|
||||||
const paddle::Tensor &q, // [token_num, num_heads, head_dim]
|
|
||||||
const paddle::Tensor &cache_k,
|
|
||||||
const paddle::Tensor &cache_v,
|
|
||||||
const paddle::optional<paddle::Tensor>& attn_mask,
|
|
||||||
const paddle::optional<paddle::Tensor>& shift_bias,
|
|
||||||
const paddle::optional<paddle::Tensor>& smooth_weight,
|
|
||||||
const paddle::Tensor &seq_lens_q, // q_seq_len is 1
|
|
||||||
const paddle::Tensor &seq_lens_kv,
|
|
||||||
const paddle::Tensor &batch_id_per_token,
|
|
||||||
const paddle::Tensor &cu_seqlens_q,
|
|
||||||
const paddle::Tensor &block_table,
|
|
||||||
int max_seq_len,
|
|
||||||
int max_dec_len,
|
|
||||||
float softmax_scale,
|
|
||||||
float in_scale,
|
|
||||||
bool causal,
|
|
||||||
cudaStream_t &stream,
|
|
||||||
paddle::Tensor *out);
|
|
||||||
|
|||||||
File diff suppressed because it is too large
Load Diff
@@ -1,56 +0,0 @@
|
|||||||
// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
|
|
||||||
//
|
|
||||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
// you may not use this file except in compliance with the License.
|
|
||||||
// You may obtain a copy of the License at
|
|
||||||
//
|
|
||||||
// http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
//
|
|
||||||
// Unless required by applicable law or agreed to in writing, software
|
|
||||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
// See the License for the specific language governing permissions and
|
|
||||||
// limitations under the License.
|
|
||||||
#pragma once
|
|
||||||
|
|
||||||
#include "append_attention_func.cuh"
|
|
||||||
|
|
||||||
template <typename T,
|
|
||||||
uint32_t GROUP_SIZE,
|
|
||||||
uint32_t HEAD_DIM,
|
|
||||||
uint32_t BLOCK_SIZE,
|
|
||||||
bool CAUSAL,
|
|
||||||
uint32_t BLOCK_SHAPE_Q,
|
|
||||||
uint32_t NUM_WARP_Q,
|
|
||||||
typename OutT,
|
|
||||||
bool ENABLE_PREFILL = true>
|
|
||||||
void MultiQueryAppendAttention(
|
|
||||||
const AppendAttnMetaData &meta_data,
|
|
||||||
const paddle::Tensor &qkv,
|
|
||||||
const paddle::Tensor &cache_k,
|
|
||||||
const paddle::Tensor &cache_v,
|
|
||||||
const paddle::optional<paddle::Tensor> &attn_mask,
|
|
||||||
const paddle::optional<paddle::Tensor> &shift_bias,
|
|
||||||
const paddle::optional<paddle::Tensor> &smooth_weight,
|
|
||||||
const paddle::optional<paddle::Tensor> &sinks,
|
|
||||||
const paddle::Tensor &seq_lens_q,
|
|
||||||
const paddle::Tensor &seq_lens_kv,
|
|
||||||
const paddle::Tensor &seq_lens_encoder,
|
|
||||||
const paddle::Tensor &batch_id_per_token,
|
|
||||||
const paddle::Tensor &cu_seqlens_q,
|
|
||||||
const paddle::Tensor &block_table,
|
|
||||||
const paddle::Tensor &batch_ids,
|
|
||||||
const paddle::Tensor &tile_ids_per_batch,
|
|
||||||
const int num_blocks_x_cpu,
|
|
||||||
const int max_seq_len,
|
|
||||||
const int max_dec_len,
|
|
||||||
const float quant_max_bound,
|
|
||||||
const float quant_min_bound,
|
|
||||||
const float in_scale,
|
|
||||||
const int max_partition_size,
|
|
||||||
const int encoder_max_partition_size,
|
|
||||||
const int speculate_max_draft_token_num,
|
|
||||||
const bool is_decoder,
|
|
||||||
cudaStream_t &stream,
|
|
||||||
paddle::Tensor *out,
|
|
||||||
const int sliding_window);
|
|
||||||
File diff suppressed because it is too large
Load Diff
@@ -1,60 +0,0 @@
|
|||||||
// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
|
|
||||||
//
|
|
||||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
// you may not use this file except in compliance with the License.
|
|
||||||
// You may obtain a copy of the License at
|
|
||||||
//
|
|
||||||
// http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
//
|
|
||||||
// Unless required by applicable law or agreed to in writing, software
|
|
||||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
// See the License for the specific language governing permissions and
|
|
||||||
// limitations under the License.
|
|
||||||
#pragma once
|
|
||||||
|
|
||||||
#include "append_attention_func.cuh"
|
|
||||||
|
|
||||||
template <typename T,
|
|
||||||
uint32_t GROUP_SIZE,
|
|
||||||
uint32_t HEAD_DIM,
|
|
||||||
uint32_t BLOCK_SIZE,
|
|
||||||
bool CAUSAL,
|
|
||||||
uint32_t BLOCK_SHAPE_Q,
|
|
||||||
uint32_t NUM_WARP_Q,
|
|
||||||
typename OutT = T,
|
|
||||||
bool ENABLE_PREFILL = true>
|
|
||||||
void MultiQueryAppendC4Attention(
|
|
||||||
const AppendAttnMetaData &meta_data,
|
|
||||||
const paddle::Tensor &qkv,
|
|
||||||
const paddle::Tensor &cache_k,
|
|
||||||
const paddle::Tensor &cache_v,
|
|
||||||
const paddle::optional<paddle::Tensor> &attn_mask,
|
|
||||||
const paddle::Tensor &cache_k_scale,
|
|
||||||
const paddle::Tensor &cache_v_scale,
|
|
||||||
const paddle::optional<paddle::Tensor> &cache_k_zp,
|
|
||||||
const paddle::optional<paddle::Tensor> &cache_v_zp,
|
|
||||||
const paddle::optional<paddle::Tensor> &shift_bias,
|
|
||||||
const paddle::optional<paddle::Tensor> &smooth_weight,
|
|
||||||
const paddle::optional<paddle::Tensor> &sinks,
|
|
||||||
const paddle::Tensor &seq_lens_q,
|
|
||||||
const paddle::Tensor &seq_lens_kv,
|
|
||||||
const paddle::Tensor &seq_lens_encoder,
|
|
||||||
const paddle::Tensor &batch_id_per_token,
|
|
||||||
const paddle::Tensor &cu_seqlens_q,
|
|
||||||
const paddle::Tensor &block_table,
|
|
||||||
const paddle::Tensor &batch_ids,
|
|
||||||
const paddle::Tensor &tile_ids_per_batch,
|
|
||||||
const int num_blocks_x_cpu,
|
|
||||||
const int max_seq_len,
|
|
||||||
const int max_dec_len,
|
|
||||||
const float quant_max_bound,
|
|
||||||
const float quant_min_bound,
|
|
||||||
const float in_scale,
|
|
||||||
const int max_partition_size,
|
|
||||||
const int encoder_max_partition_size,
|
|
||||||
const int speculate_max_draft_token_num,
|
|
||||||
const bool is_decoder,
|
|
||||||
cudaStream_t &stream,
|
|
||||||
paddle::Tensor *out,
|
|
||||||
const int sliding_window);
|
|
||||||
File diff suppressed because it is too large
Load Diff
@@ -1,60 +0,0 @@
|
|||||||
// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
|
|
||||||
//
|
|
||||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
// you may not use this file except in compliance with the License.
|
|
||||||
// You may obtain a copy of the License at
|
|
||||||
//
|
|
||||||
// http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
//
|
|
||||||
// Unless required by applicable law or agreed to in writing, software
|
|
||||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
// See the License for the specific language governing permissions and
|
|
||||||
// limitations under the License.
|
|
||||||
#pragma once
|
|
||||||
|
|
||||||
#include "append_attention_func.cuh"
|
|
||||||
|
|
||||||
template <typename T,
|
|
||||||
uint32_t GROUP_SIZE,
|
|
||||||
uint32_t HEAD_DIM,
|
|
||||||
uint32_t BLOCK_SIZE,
|
|
||||||
bool CAUSAL,
|
|
||||||
uint32_t BLOCK_SHAPE_Q,
|
|
||||||
uint32_t NUM_WARP_Q,
|
|
||||||
typename OutT = T,
|
|
||||||
bool ENABLE_PREFILL = true,
|
|
||||||
bool IsFP8 = false,
|
|
||||||
bool IsDynamicC8 = false>
|
|
||||||
void MultiQueryAppendC8Attention(
|
|
||||||
const AppendAttnMetaData &meta_data,
|
|
||||||
const paddle::Tensor &qkv,
|
|
||||||
const paddle::Tensor &cache_k,
|
|
||||||
const paddle::Tensor &cache_v,
|
|
||||||
const paddle::optional<paddle::Tensor> &attn_mask,
|
|
||||||
const paddle::Tensor &cache_k_scale,
|
|
||||||
const paddle::Tensor &cache_v_scale,
|
|
||||||
const paddle::optional<paddle::Tensor> &shift_bias,
|
|
||||||
const paddle::optional<paddle::Tensor> &smooth_weight,
|
|
||||||
const paddle::optional<paddle::Tensor> &sinks,
|
|
||||||
const paddle::Tensor &seq_lens_q,
|
|
||||||
const paddle::Tensor &seq_lens_kv,
|
|
||||||
const paddle::Tensor &seq_lens_encoder,
|
|
||||||
const paddle::Tensor &batch_id_per_token,
|
|
||||||
const paddle::Tensor &cu_seqlens_q,
|
|
||||||
const paddle::Tensor &block_table,
|
|
||||||
const paddle::Tensor &batch_ids,
|
|
||||||
const paddle::Tensor &tile_ids_per_batch,
|
|
||||||
const int num_blocks_x_cpu,
|
|
||||||
const int max_seq_len,
|
|
||||||
const int max_dec_len,
|
|
||||||
const float quant_max_bound,
|
|
||||||
const float quant_min_bound,
|
|
||||||
const float in_scale,
|
|
||||||
const int max_partition_size,
|
|
||||||
const int encoder_max_partition_size,
|
|
||||||
const int speculate_max_draft_token_num,
|
|
||||||
const bool is_decoder,
|
|
||||||
cudaStream_t &stream,
|
|
||||||
paddle::Tensor *out,
|
|
||||||
const int sliding_window);
|
|
||||||
@@ -1,39 +0,0 @@
|
|||||||
// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
|
|
||||||
//
|
|
||||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
// you may not use this file except in compliance with the License.
|
|
||||||
// You may obtain a copy of the License at
|
|
||||||
//
|
|
||||||
// http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
//
|
|
||||||
// Unless required by applicable law or agreed to in writing, software
|
|
||||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
// See the License for the specific language governing permissions and
|
|
||||||
// limitations under the License.
|
|
||||||
#pragma once
|
|
||||||
|
|
||||||
#include "decode_attention_func.cuh"
|
|
||||||
|
|
||||||
template <typename T, uint32_t GROUP_SIZE, uint32_t HEAD_DIM_QK, uint32_t HEAD_DIM_V, uint32_t BLOCK_SIZE, bool CAUSAL, uint32_t NUM_STAGE, uint32_t cache_bytes, uint32_t DEAL_EACH_TIME>
|
|
||||||
void MultiQueryDecoderAttention(
|
|
||||||
const AppendAttnMetaData& meta_data,
|
|
||||||
cudaStream_t &stream,
|
|
||||||
const paddle::Tensor &q,
|
|
||||||
const paddle::Tensor &cache_k, // [max_block_num, num_kv_heads, block_size, head_dim]
|
|
||||||
const paddle::Tensor &cache_v, // [num_kv_heads, head_dim]
|
|
||||||
const paddle::optional<paddle::Tensor>& attn_mask,
|
|
||||||
const paddle::optional<paddle::Tensor>& shift_bias,
|
|
||||||
const paddle::optional<paddle::Tensor>& smooth_weight,
|
|
||||||
const paddle::Tensor &seq_lens_q,
|
|
||||||
const paddle::Tensor &seq_lens_kv,
|
|
||||||
const paddle::Tensor &batch_id_per_token,
|
|
||||||
const paddle::Tensor &cu_seqlens_q,
|
|
||||||
const paddle::Tensor &block_table,
|
|
||||||
const int max_seq_len,
|
|
||||||
const int max_dec_len,
|
|
||||||
const float rope_scale,
|
|
||||||
const float rope_theta,
|
|
||||||
const float softmax_scale,
|
|
||||||
const float in_scale,
|
|
||||||
paddle::Tensor *out);
|
|
||||||
@@ -18,167 +18,6 @@
|
|||||||
#include "mma_tensor_op.cuh"
|
#include "mma_tensor_op.cuh"
|
||||||
#include "utils.cuh"
|
#include "utils.cuh"
|
||||||
|
|
||||||
template <typename T, int VecSize = 1, typename InT = T>
|
|
||||||
__global__ void append_speculate_cache_T_rope_qk_norm_kernel(
|
|
||||||
const InT* __restrict__ qkv, // [token_num, num_heads + 2 * gqa_group_size,
|
|
||||||
// head_size]
|
|
||||||
T* __restrict__ key_cache, // [num_blocks, gqa_group_size, block_size,
|
|
||||||
// head_size // 2]
|
|
||||||
T* __restrict__ value_cache, // [num_blocks, gqa_group_size, block_size,
|
|
||||||
// head_size // 2]
|
|
||||||
T* __restrict__ q_out,
|
|
||||||
const int* __restrict__ block_tables, // [bsz, max_blocks_per_seq]
|
|
||||||
const int* __restrict__ batch_id_per_token, // [num_tokens]
|
|
||||||
const int* __restrict__ cu_seqlens_q,
|
|
||||||
const int* __restrict__ seq_lens_decoder, // [bsz]
|
|
||||||
const float* __restrict__ cos_emb,
|
|
||||||
const float* __restrict__ sin_emb,
|
|
||||||
const float*
|
|
||||||
qkv_out_scales, // [(num_heads + 2 * gqa_group_size) * head_size]
|
|
||||||
const T* qkv_biases, // [num_head + 2 * gqa_group_size, dim_head]
|
|
||||||
const int max_seq_len,
|
|
||||||
const int max_blocks_per_seq,
|
|
||||||
const int num_heads,
|
|
||||||
const int output_inner_dim,
|
|
||||||
const int head_size,
|
|
||||||
const int block_size,
|
|
||||||
const int elem_cnt,
|
|
||||||
const int gqa_group_size,
|
|
||||||
const float* q_norm_weight,
|
|
||||||
const float* k_norm_weight,
|
|
||||||
const float rms_norm_eps,
|
|
||||||
const bool rope_3d) {
|
|
||||||
using LoadT = AlignedVector<T, VecSize>;
|
|
||||||
using LoadFloat = AlignedVector<float, VecSize>;
|
|
||||||
using LoadInT = AlignedVector<InT, VecSize>;
|
|
||||||
constexpr int HalfVecSize = VecSize / 2;
|
|
||||||
using LoadEmbT = AlignedVector<float, HalfVecSize>;
|
|
||||||
LoadInT src_vec;
|
|
||||||
LoadFloat scale_vec;
|
|
||||||
LoadT bias_vec;
|
|
||||||
LoadEmbT cos_emb_vec;
|
|
||||||
LoadEmbT sin_emb_vec;
|
|
||||||
LoadFloat tmp_vec;
|
|
||||||
LoadFloat q_norm_vec;
|
|
||||||
LoadFloat k_norm_vec;
|
|
||||||
|
|
||||||
int64_t global_warp_idx = blockDim.y * blockIdx.x + threadIdx.y;
|
|
||||||
int64_t all_warp_num = gridDim.x * blockDim.y;
|
|
||||||
int64_t all_head_dim = elem_cnt / head_size;
|
|
||||||
|
|
||||||
const int64_t hidden_size = (num_heads + 2 * gqa_group_size) * head_size;
|
|
||||||
const int half_head_size = head_size / 2;
|
|
||||||
for (int global_hi = global_warp_idx; global_hi < all_head_dim;
|
|
||||||
global_hi += all_warp_num) {
|
|
||||||
int64_t linear_index = global_hi * head_size + threadIdx.x * VecSize;
|
|
||||||
const int token_id = linear_index / hidden_size;
|
|
||||||
|
|
||||||
const int ori_bi = batch_id_per_token[token_id];
|
|
||||||
if (ori_bi == -1) continue; // NOTE(gongshaotian): For CUDAGraph padding
|
|
||||||
if (seq_lens_decoder[ori_bi] == 0) continue;
|
|
||||||
const int bias = linear_index % hidden_size;
|
|
||||||
const int hi = bias / head_size; // q + k + v
|
|
||||||
const int h_bias = bias % head_size;
|
|
||||||
const int start_token_idx = cu_seqlens_q[ori_bi];
|
|
||||||
const int write_seq_id =
|
|
||||||
seq_lens_decoder[ori_bi] + token_id - start_token_idx;
|
|
||||||
if (write_seq_id == 0) continue;
|
|
||||||
|
|
||||||
const int* block_table_now = block_tables + ori_bi * max_blocks_per_seq;
|
|
||||||
const int block_idx = block_table_now[write_seq_id / block_size];
|
|
||||||
if (block_idx < 0) {
|
|
||||||
return; // NOTE(gongshaotian): For CUDAGraph padding
|
|
||||||
}
|
|
||||||
const int block_offset = write_seq_id % block_size;
|
|
||||||
|
|
||||||
const int write_q_idx =
|
|
||||||
token_id * output_inner_dim * head_size + hi * head_size + h_bias;
|
|
||||||
|
|
||||||
const int bias_idx = hi * head_size + h_bias;
|
|
||||||
Load<InT, VecSize>(&qkv[linear_index], &src_vec);
|
|
||||||
if (qkv_biases) {
|
|
||||||
Load<T, VecSize>(&qkv_biases[bias_idx], &bias_vec);
|
|
||||||
}
|
|
||||||
if (qkv_out_scales) {
|
|
||||||
Load<float, VecSize>(&qkv_out_scales[bias_idx], &scale_vec);
|
|
||||||
}
|
|
||||||
if (hi < num_heads + gqa_group_size) {
|
|
||||||
// q k rope
|
|
||||||
const int64_t emb_idx = write_seq_id * half_head_size + h_bias / 2;
|
|
||||||
uint32_t new_emb_idx =
|
|
||||||
rope_3d ? emb_idx + ori_bi * max_seq_len * head_size : emb_idx;
|
|
||||||
Load<float, HalfVecSize>(&cos_emb[new_emb_idx], &cos_emb_vec);
|
|
||||||
Load<float, HalfVecSize>(&sin_emb[new_emb_idx], &sin_emb_vec);
|
|
||||||
}
|
|
||||||
float thread_m2 = 0.0f;
|
|
||||||
float warp_m2 = 0.0f;
|
|
||||||
#pragma unroll
|
|
||||||
for (int i = 0; i < HalfVecSize; i++) {
|
|
||||||
// add_bias + rope
|
|
||||||
float input_left = static_cast<float>(src_vec[2 * i]);
|
|
||||||
float input_right = static_cast<float>(src_vec[2 * i + 1]);
|
|
||||||
if (qkv_out_scales) {
|
|
||||||
input_left *= scale_vec[2 * i];
|
|
||||||
input_right *= scale_vec[2 * i + 1];
|
|
||||||
}
|
|
||||||
if (qkv_biases) {
|
|
||||||
input_left = input_left + static_cast<float>(bias_vec[2 * i]);
|
|
||||||
input_right = input_right + static_cast<float>(bias_vec[2 * i + 1]);
|
|
||||||
}
|
|
||||||
if (hi < num_heads + gqa_group_size) {
|
|
||||||
const float cos_tmp = cos_emb_vec[i];
|
|
||||||
const float sin_tmp = sin_emb_vec[i];
|
|
||||||
float tmp1 = input_left * cos_tmp - input_right * sin_tmp;
|
|
||||||
float tmp2 = input_right * cos_tmp + input_left * sin_tmp;
|
|
||||||
thread_m2 += tmp1 * tmp1 + tmp2 * tmp2;
|
|
||||||
tmp_vec[2 * i] = tmp1;
|
|
||||||
tmp_vec[2 * i + 1] = tmp2;
|
|
||||||
} else {
|
|
||||||
bias_vec[2 * i] = static_cast<T>(input_left);
|
|
||||||
bias_vec[2 * i + 1] = static_cast<T>(input_right);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if (hi < (num_heads + gqa_group_size)) {
|
|
||||||
WelfordWarpAllReduce<float, 32>(thread_m2, &warp_m2);
|
|
||||||
float row_variance = max(warp_m2 / head_size, 0.0f);
|
|
||||||
float row_inv_var = Rsqrt(row_variance + rms_norm_eps);
|
|
||||||
if (hi < num_heads) {
|
|
||||||
Load<float, VecSize>(&q_norm_weight[threadIdx.x * VecSize],
|
|
||||||
&q_norm_vec);
|
|
||||||
#pragma unroll
|
|
||||||
for (int i = 0; i < VecSize; i++) {
|
|
||||||
bias_vec[i] =
|
|
||||||
static_cast<T>(tmp_vec[i] * row_inv_var * q_norm_vec[i]);
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
Load<float, VecSize>(&k_norm_weight[threadIdx.x * VecSize],
|
|
||||||
&k_norm_vec);
|
|
||||||
#pragma unroll
|
|
||||||
for (int i = 0; i < VecSize; i++) {
|
|
||||||
bias_vec[i] =
|
|
||||||
static_cast<T>(tmp_vec[i] * row_inv_var * k_norm_vec[i]);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if (hi < num_heads) {
|
|
||||||
// write q
|
|
||||||
Store<T, VecSize>(bias_vec, &q_out[write_q_idx]);
|
|
||||||
} else {
|
|
||||||
// write k/v
|
|
||||||
const int kv_head_idx = (hi - num_heads) % gqa_group_size;
|
|
||||||
const int tgt_idx = (block_idx * gqa_group_size * block_size * head_size +
|
|
||||||
kv_head_idx * block_size * head_size +
|
|
||||||
block_offset * head_size + h_bias);
|
|
||||||
// write
|
|
||||||
if (hi < num_heads + gqa_group_size) {
|
|
||||||
Store<T, VecSize>(bias_vec, &key_cache[tgt_idx]);
|
|
||||||
} else {
|
|
||||||
Store<T, VecSize>(bias_vec, &value_cache[tgt_idx]);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
template <int VecSize = 4, int HeadDim = 128>
|
template <int VecSize = 4, int HeadDim = 128>
|
||||||
__global__ void append_clear_cache_int8_block(
|
__global__ void append_clear_cache_int8_block(
|
||||||
uint8_t* __restrict__ key_cache, // [num_blocks, gqa_group_size,
|
uint8_t* __restrict__ key_cache, // [num_blocks, gqa_group_size,
|
||||||
@@ -186,7 +25,7 @@ __global__ void append_clear_cache_int8_block(
|
|||||||
uint8_t* __restrict__ value_cache, // [num_blocks, gqa_group_size,
|
uint8_t* __restrict__ value_cache, // [num_blocks, gqa_group_size,
|
||||||
// block_size, head_size // 2]
|
// block_size, head_size // 2]
|
||||||
const int* __restrict__ seq_lens,
|
const int* __restrict__ seq_lens,
|
||||||
const int* __restrict__ block_tables, // [bsz, max_blocks_per_seq]
|
const int* __restrict__ block_tables, // [bsz, max_blocks_per_seq]
|
||||||
const int* __restrict__ batch_id_per_token, // [num_tokens]
|
const int* __restrict__ batch_id_per_token, // [num_tokens]
|
||||||
const int* __restrict__ cu_seqlens_q,
|
const int* __restrict__ cu_seqlens_q,
|
||||||
const int* __restrict__ seq_lens_encoder, // [bsz]
|
const int* __restrict__ seq_lens_encoder, // [bsz]
|
||||||
@@ -204,7 +43,6 @@ __global__ void append_clear_cache_int8_block(
|
|||||||
const int token_id = blockIdx.x;
|
const int token_id = blockIdx.x;
|
||||||
|
|
||||||
const int bid = batch_id_per_token[token_id];
|
const int bid = batch_id_per_token[token_id];
|
||||||
if (bid == -1) return; // NOTE(gongshaotian): For CUDAGraph padding
|
|
||||||
|
|
||||||
const int start_token_idx = cu_seqlens_q[bid];
|
const int start_token_idx = cu_seqlens_q[bid];
|
||||||
const int head_idx = blockIdx.y * NUM_WARPS + wid;
|
const int head_idx = blockIdx.y * NUM_WARPS + wid;
|
||||||
@@ -253,6 +91,7 @@ __global__ void append_clear_cache_int8_block(
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
template <int VecSize = 4, int HeadDim = 128>
|
template <int VecSize = 4, int HeadDim = 128>
|
||||||
__global__ void append_clear_cache_int4_block(
|
__global__ void append_clear_cache_int4_block(
|
||||||
uint8_t* __restrict__ key_cache, // [num_blocks, gqa_group_size,
|
uint8_t* __restrict__ key_cache, // [num_blocks, gqa_group_size,
|
||||||
@@ -260,7 +99,7 @@ __global__ void append_clear_cache_int4_block(
|
|||||||
uint8_t* __restrict__ value_cache, // [num_blocks, gqa_group_size,
|
uint8_t* __restrict__ value_cache, // [num_blocks, gqa_group_size,
|
||||||
// block_size, head_size // 2]
|
// block_size, head_size // 2]
|
||||||
const int* __restrict__ seq_lens,
|
const int* __restrict__ seq_lens,
|
||||||
const int* __restrict__ block_tables, // [bsz, max_blocks_per_seq]
|
const int* __restrict__ block_tables, // [bsz, max_blocks_per_seq]
|
||||||
const int* __restrict__ batch_id_per_token, // [num_tokens]
|
const int* __restrict__ batch_id_per_token, // [num_tokens]
|
||||||
const int* __restrict__ cu_seqlens_q,
|
const int* __restrict__ cu_seqlens_q,
|
||||||
const int* __restrict__ seq_lens_encoder, // [bsz]
|
const int* __restrict__ seq_lens_encoder, // [bsz]
|
||||||
@@ -278,7 +117,6 @@ __global__ void append_clear_cache_int4_block(
|
|||||||
const int token_id = blockIdx.x;
|
const int token_id = blockIdx.x;
|
||||||
|
|
||||||
const int bid = batch_id_per_token[token_id];
|
const int bid = batch_id_per_token[token_id];
|
||||||
if (bid == -1) return; // NOTE(gongshaotian): For CUDAGraph padding
|
|
||||||
|
|
||||||
const int start_token_idx = cu_seqlens_q[bid];
|
const int start_token_idx = cu_seqlens_q[bid];
|
||||||
const int head_idx = blockIdx.y * NUM_WARPS + wid;
|
const int head_idx = blockIdx.y * NUM_WARPS + wid;
|
||||||
@@ -339,7 +177,7 @@ __global__ void append_speculate_cache_rope_kernel(
|
|||||||
T* __restrict__ value_cache, // [num_blocks, gqa_group_size, block_size,
|
T* __restrict__ value_cache, // [num_blocks, gqa_group_size, block_size,
|
||||||
// head_size // 2]
|
// head_size // 2]
|
||||||
T* __restrict__ q_out,
|
T* __restrict__ q_out,
|
||||||
const int* __restrict__ block_tables, // [bsz, max_blocks_per_seq]
|
const int* __restrict__ block_tables, // [bsz, max_blocks_per_seq]
|
||||||
const int* __restrict__ batch_id_per_token, // [num_tokens]
|
const int* __restrict__ batch_id_per_token, // [num_tokens]
|
||||||
const int* __restrict__ cu_seqlens_q,
|
const int* __restrict__ cu_seqlens_q,
|
||||||
const int* __restrict__ seq_lens_decoder, // [bsz]
|
const int* __restrict__ seq_lens_decoder, // [bsz]
|
||||||
@@ -355,8 +193,7 @@ __global__ void append_speculate_cache_rope_kernel(
|
|||||||
const int head_size,
|
const int head_size,
|
||||||
const int block_size,
|
const int block_size,
|
||||||
const int elem_cnt,
|
const int elem_cnt,
|
||||||
const int gqa_group_size,
|
const int gqa_group_size) {
|
||||||
const bool rope_3d) {
|
|
||||||
using LoadT = AlignedVector<T, VecSize>;
|
using LoadT = AlignedVector<T, VecSize>;
|
||||||
using LoadFloat = AlignedVector<float, VecSize>;
|
using LoadFloat = AlignedVector<float, VecSize>;
|
||||||
using LoadInT = AlignedVector<InT, VecSize>;
|
using LoadInT = AlignedVector<InT, VecSize>;
|
||||||
@@ -378,8 +215,6 @@ __global__ void append_speculate_cache_rope_kernel(
|
|||||||
linear_index += step) {
|
linear_index += step) {
|
||||||
const int token_id = linear_index / hidden_size;
|
const int token_id = linear_index / hidden_size;
|
||||||
const int ori_bi = batch_id_per_token[token_id];
|
const int ori_bi = batch_id_per_token[token_id];
|
||||||
if (ori_bi == -1) continue; // NOTE(gongshaotian): For CUDAGraph padding
|
|
||||||
|
|
||||||
if (seq_lens_decoder[ori_bi] == 0) continue;
|
if (seq_lens_decoder[ori_bi] == 0) continue;
|
||||||
const int bias = linear_index % hidden_size;
|
const int bias = linear_index % hidden_size;
|
||||||
const int hi = bias / head_size; // q + k + v
|
const int hi = bias / head_size; // q + k + v
|
||||||
@@ -392,7 +227,15 @@ __global__ void append_speculate_cache_rope_kernel(
|
|||||||
const int* block_table_now = block_tables + ori_bi * max_blocks_per_seq;
|
const int* block_table_now = block_tables + ori_bi * max_blocks_per_seq;
|
||||||
const int block_idx = block_table_now[write_seq_id / block_size];
|
const int block_idx = block_table_now[write_seq_id / block_size];
|
||||||
if (block_idx < 0) {
|
if (block_idx < 0) {
|
||||||
return; // NOTE(gongshaotian): For CUDAGraph padding
|
printf(
|
||||||
|
"Fatal Error!!!, block idx %d when write_seq_id is %d\n some key var "
|
||||||
|
"%d %d %d %d\n",
|
||||||
|
block_idx,
|
||||||
|
write_seq_id,
|
||||||
|
ori_bi,
|
||||||
|
seq_lens_decoder[ori_bi],
|
||||||
|
token_id,
|
||||||
|
cu_seqlens_q[ori_bi]);
|
||||||
}
|
}
|
||||||
const int block_offset = write_seq_id % block_size;
|
const int block_offset = write_seq_id % block_size;
|
||||||
|
|
||||||
@@ -410,10 +253,8 @@ __global__ void append_speculate_cache_rope_kernel(
|
|||||||
if (hi < num_heads + gqa_group_size) {
|
if (hi < num_heads + gqa_group_size) {
|
||||||
// q k rope
|
// q k rope
|
||||||
const int64_t emb_idx = write_seq_id * half_head_size + h_bias / 2;
|
const int64_t emb_idx = write_seq_id * half_head_size + h_bias / 2;
|
||||||
int64_t new_emb_idx =
|
Load<float, HalfVecSize>(&cos_emb[emb_idx], &cos_emb_vec);
|
||||||
rope_3d ? emb_idx + ori_bi * max_seq_len * head_size : emb_idx;
|
Load<float, HalfVecSize>(&sin_emb[emb_idx], &sin_emb_vec);
|
||||||
Load<float, HalfVecSize>(&cos_emb[new_emb_idx], &cos_emb_vec);
|
|
||||||
Load<float, HalfVecSize>(&sin_emb[new_emb_idx], &sin_emb_vec);
|
|
||||||
}
|
}
|
||||||
#pragma unroll
|
#pragma unroll
|
||||||
for (int i = 0; i < HalfVecSize; i++) {
|
for (int i = 0; i < HalfVecSize; i++) {
|
||||||
@@ -469,7 +310,7 @@ __global__ void append_speculate_cache_neox_rope_kernel(
|
|||||||
T* __restrict__ value_cache, // [num_blocks, gqa_group_size, block_size,
|
T* __restrict__ value_cache, // [num_blocks, gqa_group_size, block_size,
|
||||||
// head_size // 2]
|
// head_size // 2]
|
||||||
T* __restrict__ qkv_out,
|
T* __restrict__ qkv_out,
|
||||||
const int* __restrict__ block_tables, // [bsz, max_blocks_per_seq]
|
const int* __restrict__ block_tables, // [bsz, max_blocks_per_seq]
|
||||||
const int* __restrict__ batch_id_per_token, // [num_tokens]
|
const int* __restrict__ batch_id_per_token, // [num_tokens]
|
||||||
const int* __restrict__ cu_seqlens_q,
|
const int* __restrict__ cu_seqlens_q,
|
||||||
const int* __restrict__ seq_lens_decoder, // [bsz]
|
const int* __restrict__ seq_lens_decoder, // [bsz]
|
||||||
@@ -485,8 +326,7 @@ __global__ void append_speculate_cache_neox_rope_kernel(
|
|||||||
const int head_size,
|
const int head_size,
|
||||||
const int block_size,
|
const int block_size,
|
||||||
const int elem_cnt,
|
const int elem_cnt,
|
||||||
const int gqa_group_size,
|
const int gqa_group_size) {
|
||||||
const bool rope_3d) {
|
|
||||||
using LoadT = AlignedVector<T, VecSize>;
|
using LoadT = AlignedVector<T, VecSize>;
|
||||||
using LoadFloat = AlignedVector<float, VecSize>;
|
using LoadFloat = AlignedVector<float, VecSize>;
|
||||||
using LoadInT = AlignedVector<InT, VecSize>;
|
using LoadInT = AlignedVector<InT, VecSize>;
|
||||||
@@ -508,7 +348,6 @@ __global__ void append_speculate_cache_neox_rope_kernel(
|
|||||||
linear_index += step) {
|
linear_index += step) {
|
||||||
const int token_id = linear_index / half_hidden_size;
|
const int token_id = linear_index / half_hidden_size;
|
||||||
const int ori_bi = batch_id_per_token[token_id];
|
const int ori_bi = batch_id_per_token[token_id];
|
||||||
if (ori_bi == -1) continue; // NOTE(gongshaotian): For CUDAGraph padding
|
|
||||||
if (seq_lens_decoder[ori_bi] == 0) continue;
|
if (seq_lens_decoder[ori_bi] == 0) continue;
|
||||||
const int bias = linear_index % half_hidden_size;
|
const int bias = linear_index % half_hidden_size;
|
||||||
const int hi = bias / half_head_size; // q + k + v
|
const int hi = bias / half_head_size; // q + k + v
|
||||||
@@ -521,7 +360,15 @@ __global__ void append_speculate_cache_neox_rope_kernel(
|
|||||||
const int* block_table_now = block_tables + ori_bi * max_blocks_per_seq;
|
const int* block_table_now = block_tables + ori_bi * max_blocks_per_seq;
|
||||||
const int block_idx = block_table_now[write_seq_id / block_size];
|
const int block_idx = block_table_now[write_seq_id / block_size];
|
||||||
if (block_idx < 0) {
|
if (block_idx < 0) {
|
||||||
return; // NOTE(gongshaotian): For CUDAGraph padding
|
printf(
|
||||||
|
"Fatal Error!!!, block idx %d when write_seq_id is %d\n some key var "
|
||||||
|
"%d %d %d %d\n",
|
||||||
|
block_idx,
|
||||||
|
write_seq_id,
|
||||||
|
ori_bi,
|
||||||
|
seq_lens_decoder[ori_bi],
|
||||||
|
token_id,
|
||||||
|
cu_seqlens_q[ori_bi]);
|
||||||
}
|
}
|
||||||
const int block_offset = write_seq_id % block_size;
|
const int block_offset = write_seq_id % block_size;
|
||||||
|
|
||||||
@@ -543,10 +390,8 @@ __global__ void append_speculate_cache_neox_rope_kernel(
|
|||||||
if (hi < num_heads + gqa_group_size) {
|
if (hi < num_heads + gqa_group_size) {
|
||||||
// q k rope
|
// q k rope
|
||||||
const int64_t emb_idx = write_seq_id * head_size + h_bias;
|
const int64_t emb_idx = write_seq_id * head_size + h_bias;
|
||||||
int64_t new_emb_idx =
|
Load<float, VecSize>(&cos_emb[emb_idx], &cos_emb_vec);
|
||||||
rope_3d ? emb_idx + ori_bi * max_seq_len * head_size * 2 : emb_idx;
|
Load<float, VecSize>(&sin_emb[emb_idx], &sin_emb_vec);
|
||||||
Load<float, VecSize>(&cos_emb[new_emb_idx], &cos_emb_vec);
|
|
||||||
Load<float, VecSize>(&sin_emb[new_emb_idx], &sin_emb_vec);
|
|
||||||
}
|
}
|
||||||
#pragma unroll
|
#pragma unroll
|
||||||
for (int i = 0; i < VecSize; i++) {
|
for (int i = 0; i < VecSize; i++) {
|
||||||
@@ -598,277 +443,6 @@ __global__ void append_speculate_cache_neox_rope_kernel(
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
template <typename T,
|
|
||||||
int VecSize = 4,
|
|
||||||
int RoundType = 0,
|
|
||||||
int HeadDim = 128,
|
|
||||||
bool IsFP8 = false>
|
|
||||||
__global__ void append_speculate_cache_fp8_rope_qk_norm_dynamic_kernel(
|
|
||||||
const T* __restrict__ quant_qkv, // [num_head, num_heads + 2 *
|
|
||||||
// gqa_group_size, head_size]
|
|
||||||
uint8_t* __restrict__ key_cache, // [num_blocks, gqa_group_size,
|
|
||||||
// block_size, head_size // 2]
|
|
||||||
uint8_t* __restrict__ value_cache, // [num_blocks, gqa_group_size,
|
|
||||||
// block_size, head_size // 2]
|
|
||||||
T* __restrict__ qkv_out,
|
|
||||||
const int* __restrict__ block_tables, // [bsz, max_blocks_per_seq]
|
|
||||||
const int* __restrict__ batch_id_per_token, // [num_tokens]
|
|
||||||
const int* __restrict__ cu_seqlens_q,
|
|
||||||
const int* __restrict__ seq_lens, // [bsz]
|
|
||||||
const int* __restrict__ seq_lens_encoder, // [bsz]
|
|
||||||
const float* __restrict__ cos_emb,
|
|
||||||
const float* __restrict__ sin_emb,
|
|
||||||
T* __restrict__ cache_k_scale,
|
|
||||||
T* __restrict__ cache_v_scale,
|
|
||||||
const float* q_norm_weight,
|
|
||||||
const float* k_norm_weight,
|
|
||||||
const int max_seq_len,
|
|
||||||
const int max_blocks_per_seq,
|
|
||||||
const int num_heads,
|
|
||||||
const int block_size,
|
|
||||||
const float max_bound,
|
|
||||||
const float min_bound,
|
|
||||||
const int gqa_group_size,
|
|
||||||
const bool rope_3d,
|
|
||||||
const float rms_norm_eps) {
|
|
||||||
static_assert(HeadDim == 128, "just support HeadDim be 128 now!");
|
|
||||||
static_assert(VecSize == 4, "just support VecSize be 4 now, 32 * 4!");
|
|
||||||
constexpr int NUM_WARPS = 4;
|
|
||||||
const int tid = threadIdx.x;
|
|
||||||
const int wid = tid / 32;
|
|
||||||
const int lane_id = tid % 32;
|
|
||||||
const int token_id = blockIdx.x;
|
|
||||||
|
|
||||||
const int bid = batch_id_per_token[token_id];
|
|
||||||
if (bid == -1) return; // NOTE(gongshaotian): For CUDAGraph padding
|
|
||||||
|
|
||||||
const int start_token_idx = cu_seqlens_q[bid];
|
|
||||||
const int head_idx = blockIdx.y * NUM_WARPS + wid;
|
|
||||||
int q_head_idx, k_head_idx, v_idx;
|
|
||||||
const int64_t hidden_size = (num_heads + 2 * gqa_group_size) * HeadDim;
|
|
||||||
constexpr int half_head_size = HeadDim / 2;
|
|
||||||
if (seq_lens_encoder[bid] > 0) return;
|
|
||||||
const int write_seq_id = seq_lens[bid] + token_id - start_token_idx;
|
|
||||||
if (write_seq_id == 0) return;
|
|
||||||
const int* block_table_now = block_tables + bid * max_blocks_per_seq;
|
|
||||||
const int block_idx = __ldg(&block_table_now[write_seq_id / block_size]);
|
|
||||||
const int block_offset = write_seq_id % block_size;
|
|
||||||
|
|
||||||
int cache_offset;
|
|
||||||
if (head_idx < num_heads) {
|
|
||||||
cache_offset = 0;
|
|
||||||
} else if (head_idx < num_heads + 2 * gqa_group_size) {
|
|
||||||
cache_offset = block_idx * gqa_group_size * block_size +
|
|
||||||
(head_idx - num_heads) % gqa_group_size * block_size +
|
|
||||||
block_offset;
|
|
||||||
}
|
|
||||||
T* cache_k_scale_now = cache_k_scale + cache_offset;
|
|
||||||
T* cache_v_scale_now = cache_v_scale + cache_offset;
|
|
||||||
|
|
||||||
float thread_m2 = 0.0f;
|
|
||||||
float warp_m2 = 0.0f;
|
|
||||||
|
|
||||||
if (head_idx < num_heads) {
|
|
||||||
// q
|
|
||||||
using LoadT = AlignedVector<T, VecSize>;
|
|
||||||
using LoadBiasT = AlignedVector<T, VecSize>;
|
|
||||||
using LoadOutScaleT = AlignedVector<float, VecSize>;
|
|
||||||
constexpr int HalfVecSize = VecSize / 2;
|
|
||||||
using LoadEmbT = AlignedVector<float, HalfVecSize>;
|
|
||||||
|
|
||||||
LoadT src_vec;
|
|
||||||
LoadBiasT bias_vec;
|
|
||||||
LoadOutScaleT out_scale_vec;
|
|
||||||
LoadEmbT cos_emb_vec;
|
|
||||||
LoadEmbT sin_emb_vec;
|
|
||||||
const T* qkv_now = quant_qkv + token_id * hidden_size;
|
|
||||||
T* qkv_out_now = qkv_out + token_id * hidden_size;
|
|
||||||
#pragma unroll
|
|
||||||
for (uint32_t head_bias = lane_id * VecSize; head_bias < HeadDim;
|
|
||||||
head_bias += 32 * VecSize) {
|
|
||||||
const int bias_idx = head_idx * HeadDim + head_bias;
|
|
||||||
Load<T, VecSize>(&qkv_now[bias_idx], &src_vec);
|
|
||||||
|
|
||||||
// q rope
|
|
||||||
const uint32_t emb_idx = write_seq_id * half_head_size + head_bias / 2;
|
|
||||||
uint32_t new_emb_idx =
|
|
||||||
rope_3d ? emb_idx + bid * max_seq_len * HeadDim : emb_idx;
|
|
||||||
Load<float, HalfVecSize>(&cos_emb[new_emb_idx], &cos_emb_vec);
|
|
||||||
Load<float, HalfVecSize>(&sin_emb[new_emb_idx], &sin_emb_vec);
|
|
||||||
#pragma unroll
|
|
||||||
for (int i = 0; i < HalfVecSize; i++) {
|
|
||||||
// dequant + add_bias + rope
|
|
||||||
float input_left = static_cast<float>(src_vec[2 * i]);
|
|
||||||
float input_right = static_cast<float>(src_vec[2 * i + 1]);
|
|
||||||
const float cos_tmp = cos_emb_vec[i];
|
|
||||||
const float sin_tmp = sin_emb_vec[i];
|
|
||||||
float tmp1 = input_left * cos_tmp - input_right * sin_tmp;
|
|
||||||
float tmp2 = input_right * cos_tmp + input_left * sin_tmp;
|
|
||||||
thread_m2 += tmp1 * tmp1 + tmp2 * tmp2;
|
|
||||||
bias_vec[2 * i] = static_cast<T>(tmp1);
|
|
||||||
bias_vec[2 * i + 1] = static_cast<T>(tmp2);
|
|
||||||
}
|
|
||||||
// qk norm
|
|
||||||
if (q_norm_weight) {
|
|
||||||
WelfordWarpAllReduce<float, 32>(thread_m2, &warp_m2);
|
|
||||||
float row_variance = max(warp_m2 / HeadDim, 0.0f);
|
|
||||||
float row_inv_var = Rsqrt(row_variance + rms_norm_eps);
|
|
||||||
LoadOutScaleT q_norm_vec;
|
|
||||||
Load<float, VecSize>(&q_norm_weight[lane_id * VecSize], &q_norm_vec);
|
|
||||||
#pragma unroll
|
|
||||||
for (int i = 0; i < VecSize; i++) {
|
|
||||||
bias_vec[i] = static_cast<T>(static_cast<float>(bias_vec[i]) *
|
|
||||||
row_inv_var * q_norm_vec[i]);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
Store<T, VecSize>(bias_vec, &qkv_out_now[bias_idx]);
|
|
||||||
}
|
|
||||||
} else if (head_idx < num_heads + 2 * gqa_group_size) {
|
|
||||||
// k
|
|
||||||
constexpr int KV_VEC_SIZE = 16 / sizeof(uint8_t); // 16
|
|
||||||
using LoadPadKVT = AlignedVector<uint8_t, KV_VEC_SIZE>;
|
|
||||||
const uint32_t kv_head_idx = (head_idx - num_heads) % gqa_group_size;
|
|
||||||
|
|
||||||
constexpr int K_VEC_SIZE = 4;
|
|
||||||
constexpr int HALF_K_VEC_SIZE = 2;
|
|
||||||
using LoadKVResT = AlignedVector<uint8_t, K_VEC_SIZE>;
|
|
||||||
using LoadKVT = AlignedVector<uint8_t, HALF_K_VEC_SIZE>;
|
|
||||||
using LoadT = AlignedVector<T, HALF_K_VEC_SIZE>;
|
|
||||||
using LoadBiasT = AlignedVector<T, HALF_K_VEC_SIZE>;
|
|
||||||
using LoadOutScaleT = AlignedVector<float, HALF_K_VEC_SIZE>;
|
|
||||||
using LoadEmbT = AlignedVector<float, 1>;
|
|
||||||
LoadKVResT cache_vec;
|
|
||||||
LoadT src_vec1, src_vec2;
|
|
||||||
LoadBiasT bias_vec1, bias_vec2;
|
|
||||||
LoadOutScaleT out_scale_vec1, out_scale_vec2;
|
|
||||||
LoadEmbT cos_emb_vec1, cos_emb_vec2;
|
|
||||||
LoadEmbT sin_emb_vec1, sin_emb_vec2;
|
|
||||||
|
|
||||||
const T* qkv_now = quant_qkv + token_id * hidden_size;
|
|
||||||
const int head_bias = lane_id / 4 * 16 + lane_id % 4 * 2;
|
|
||||||
const int bias_idx = head_idx * HeadDim + head_bias;
|
|
||||||
Load<T, HALF_K_VEC_SIZE>(&qkv_now[bias_idx], &src_vec1);
|
|
||||||
Load<T, HALF_K_VEC_SIZE>(&qkv_now[bias_idx + 8], &src_vec2);
|
|
||||||
T scale = T(1.0f);
|
|
||||||
const int k_head_idx = head_idx - num_heads;
|
|
||||||
const int v_head_idx = head_idx - num_heads - gqa_group_size;
|
|
||||||
if (head_idx < num_heads + gqa_group_size) {
|
|
||||||
const uint32_t emb_idx = write_seq_id * half_head_size + head_bias / 2;
|
|
||||||
uint32_t new_emb_idx =
|
|
||||||
rope_3d ? emb_idx + bid * max_seq_len * HeadDim : emb_idx;
|
|
||||||
Load<float, 1>(&cos_emb[new_emb_idx], &cos_emb_vec1);
|
|
||||||
Load<float, 1>(&cos_emb[new_emb_idx + 4], &cos_emb_vec2);
|
|
||||||
Load<float, 1>(&sin_emb[new_emb_idx], &sin_emb_vec1);
|
|
||||||
Load<float, 1>(&sin_emb[new_emb_idx + 4], &sin_emb_vec2);
|
|
||||||
}
|
|
||||||
|
|
||||||
float input_left = static_cast<float>(src_vec1[0]);
|
|
||||||
float input_right = static_cast<float>(src_vec1[1]);
|
|
||||||
if (head_idx < num_heads + gqa_group_size) {
|
|
||||||
float cos_tmp = cos_emb_vec1[0];
|
|
||||||
float sin_tmp = sin_emb_vec1[0];
|
|
||||||
float tmp1 = input_left * cos_tmp - input_right * sin_tmp;
|
|
||||||
float tmp2 = input_right * cos_tmp + input_left * sin_tmp;
|
|
||||||
thread_m2 += tmp1 * tmp1 + tmp2 * tmp2;
|
|
||||||
bias_vec1[0] = static_cast<T>(tmp1);
|
|
||||||
bias_vec1[1] = static_cast<T>(tmp2);
|
|
||||||
} else {
|
|
||||||
bias_vec1[0] = static_cast<T>(input_left);
|
|
||||||
bias_vec1[1] = static_cast<T>(input_right);
|
|
||||||
}
|
|
||||||
|
|
||||||
input_left = static_cast<float>(src_vec2[0]);
|
|
||||||
input_right = static_cast<float>(src_vec2[1]);
|
|
||||||
if (head_idx < num_heads + gqa_group_size) {
|
|
||||||
float cos_tmp = cos_emb_vec2[0];
|
|
||||||
float sin_tmp = sin_emb_vec2[0];
|
|
||||||
float tmp1 = input_left * cos_tmp - input_right * sin_tmp;
|
|
||||||
float tmp2 = input_right * cos_tmp + input_left * sin_tmp;
|
|
||||||
thread_m2 += tmp1 * tmp1 + tmp2 * tmp2;
|
|
||||||
bias_vec2[0] = static_cast<T>(tmp1);
|
|
||||||
bias_vec2[1] = static_cast<T>(tmp2);
|
|
||||||
} else {
|
|
||||||
bias_vec2[0] = static_cast<T>(input_left);
|
|
||||||
bias_vec2[1] = static_cast<T>(input_right);
|
|
||||||
}
|
|
||||||
if (k_norm_weight) {
|
|
||||||
if (head_idx < num_heads + gqa_group_size) {
|
|
||||||
LoadOutScaleT k_norm_vec1, k_norm_vec2;
|
|
||||||
Load<float, HALF_K_VEC_SIZE>(&k_norm_weight[head_bias], &k_norm_vec1);
|
|
||||||
Load<float, HALF_K_VEC_SIZE>(&k_norm_weight[head_bias + 8],
|
|
||||||
&k_norm_vec2);
|
|
||||||
// qk norm
|
|
||||||
WelfordWarpAllReduce<float, 32>(thread_m2, &warp_m2);
|
|
||||||
float row_variance = max(warp_m2 / HeadDim, 0.0f);
|
|
||||||
float row_inv_var = Rsqrt(row_variance + rms_norm_eps);
|
|
||||||
|
|
||||||
for (int i = 0; i < HALF_K_VEC_SIZE; i++) {
|
|
||||||
bias_vec1[i] = static_cast<T>(static_cast<float>(bias_vec1[i]) *
|
|
||||||
row_inv_var * k_norm_vec1[i]);
|
|
||||||
bias_vec2[i] = static_cast<T>(static_cast<float>(bias_vec2[i]) *
|
|
||||||
row_inv_var * k_norm_vec2[i]);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
// reduce max, 1 head per warp
|
|
||||||
T local_max = -INFINITY;
|
|
||||||
#pragma unroll
|
|
||||||
for (int i = 0; i < HALF_K_VEC_SIZE; i++) {
|
|
||||||
local_max = __hmax(local_max, __habs(bias_vec1[i]));
|
|
||||||
local_max = __hmax(local_max, __habs(bias_vec2[i]));
|
|
||||||
}
|
|
||||||
#pragma unroll
|
|
||||||
for (int m_offset = 16; m_offset > 0; m_offset /= 2) {
|
|
||||||
local_max =
|
|
||||||
__hmax(local_max, __shfl_xor_sync(0xffffffff, local_max, m_offset));
|
|
||||||
}
|
|
||||||
|
|
||||||
scale = __hdiv(448, local_max);
|
|
||||||
|
|
||||||
if (lane_id == 0) {
|
|
||||||
if (head_idx < num_heads + gqa_group_size) {
|
|
||||||
cache_k_scale_now[0] = __hdiv(1, scale);
|
|
||||||
} else {
|
|
||||||
cache_v_scale_now[0] = __hdiv(1, scale);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
#pragma unroll
|
|
||||||
for (uint32_t i = 0; i < HALF_K_VEC_SIZE; i++) {
|
|
||||||
cache_vec[i] = QuantToC8<T, true, IsFP8, RoundType>(
|
|
||||||
scale, bias_vec1[i], max_bound, min_bound);
|
|
||||||
cache_vec[i + HALF_K_VEC_SIZE] = QuantToC8<T, true, IsFP8, RoundType>(
|
|
||||||
scale, bias_vec2[i], max_bound, min_bound);
|
|
||||||
}
|
|
||||||
if (head_idx < num_heads + gqa_group_size) {
|
|
||||||
const int start_block_16 =
|
|
||||||
block_offset / 16 * 16 + block_offset % 8 + lane_id / 4 % 2 * 8;
|
|
||||||
const uint32_t tgt_cache_idx =
|
|
||||||
block_idx * gqa_group_size * block_size * HeadDim +
|
|
||||||
kv_head_idx * block_size * HeadDim + start_block_16 * HeadDim +
|
|
||||||
lane_id / 4 / 2 * 32 + (block_offset % 16) / 8 * 16 + lane_id % 4 * 4;
|
|
||||||
Store<uint8_t, K_VEC_SIZE>(cache_vec, &key_cache[tgt_cache_idx]);
|
|
||||||
} else {
|
|
||||||
const uint32_t base_tgt_cache_idx =
|
|
||||||
block_idx * gqa_group_size * HeadDim * block_size +
|
|
||||||
kv_head_idx * HeadDim * block_size +
|
|
||||||
(lane_id / 4 * 16 + lane_id % 4 * 2) * block_size +
|
|
||||||
block_offset / 16 % 2 * 8 * block_size + block_offset / 16 / 2 * 32;
|
|
||||||
const uint32_t tgt_cache_idx1 = base_tgt_cache_idx +
|
|
||||||
block_offset % 8 / 2 * 4 // per 4
|
|
||||||
+ block_offset % 16 / 8 * 2 // per 2
|
|
||||||
+ block_offset % 2; // per 1
|
|
||||||
const uint32_t tgt_cache_idx2 = tgt_cache_idx1 + block_size;
|
|
||||||
const uint32_t tgt_cache_idx3 = tgt_cache_idx1 + 16;
|
|
||||||
const uint32_t tgt_cache_idx4 = tgt_cache_idx3 + block_size;
|
|
||||||
value_cache[tgt_cache_idx1] = cache_vec[0];
|
|
||||||
value_cache[tgt_cache_idx2] = cache_vec[1];
|
|
||||||
value_cache[tgt_cache_idx3] = cache_vec[2];
|
|
||||||
value_cache[tgt_cache_idx4] = cache_vec[3];
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
template <typename T,
|
template <typename T,
|
||||||
int VecSize = 4,
|
int VecSize = 4,
|
||||||
int RoundType = 0,
|
int RoundType = 0,
|
||||||
@@ -883,7 +457,7 @@ __global__ void append_speculate_cache_int8_rope_kernel(
|
|||||||
uint8_t* __restrict__ value_cache, // [num_blocks, gqa_group_size,
|
uint8_t* __restrict__ value_cache, // [num_blocks, gqa_group_size,
|
||||||
// block_size, head_size // 2]
|
// block_size, head_size // 2]
|
||||||
T* __restrict__ qkv_out,
|
T* __restrict__ qkv_out,
|
||||||
const int* __restrict__ block_tables, // [bsz, max_blocks_per_seq]
|
const int* __restrict__ block_tables, // [bsz, max_blocks_per_seq]
|
||||||
const int* __restrict__ batch_id_per_token, // [num_tokens]
|
const int* __restrict__ batch_id_per_token, // [num_tokens]
|
||||||
const int* __restrict__ cu_seqlens_q,
|
const int* __restrict__ cu_seqlens_q,
|
||||||
const int* __restrict__ seq_lens, // [bsz]
|
const int* __restrict__ seq_lens, // [bsz]
|
||||||
@@ -902,8 +476,7 @@ __global__ void append_speculate_cache_int8_rope_kernel(
|
|||||||
const int block_size,
|
const int block_size,
|
||||||
const float max_bound,
|
const float max_bound,
|
||||||
const float min_bound,
|
const float min_bound,
|
||||||
const int gqa_group_size,
|
const int gqa_group_size) {
|
||||||
const bool rope_3d) {
|
|
||||||
static_assert(HeadDim == 128, "just support HeadDim be 128 now!");
|
static_assert(HeadDim == 128, "just support HeadDim be 128 now!");
|
||||||
static_assert(VecSize == 4, "just support VecSize be 4 now, 32 * 4!");
|
static_assert(VecSize == 4, "just support VecSize be 4 now, 32 * 4!");
|
||||||
constexpr int NUM_WARPS = 4;
|
constexpr int NUM_WARPS = 4;
|
||||||
@@ -913,7 +486,6 @@ __global__ void append_speculate_cache_int8_rope_kernel(
|
|||||||
const int token_id = blockIdx.x;
|
const int token_id = blockIdx.x;
|
||||||
|
|
||||||
const int bid = batch_id_per_token[token_id];
|
const int bid = batch_id_per_token[token_id];
|
||||||
if (bid == -1) return; // NOTE(gongshaotian): For CUDAGraph padding
|
|
||||||
|
|
||||||
const int start_token_idx = cu_seqlens_q[bid];
|
const int start_token_idx = cu_seqlens_q[bid];
|
||||||
const int head_idx = blockIdx.y * NUM_WARPS + wid;
|
const int head_idx = blockIdx.y * NUM_WARPS + wid;
|
||||||
@@ -950,10 +522,8 @@ __global__ void append_speculate_cache_int8_rope_kernel(
|
|||||||
|
|
||||||
// q rope
|
// q rope
|
||||||
const uint32_t emb_idx = write_seq_id * half_head_size + head_bias / 2;
|
const uint32_t emb_idx = write_seq_id * half_head_size + head_bias / 2;
|
||||||
uint32_t new_emb_idx =
|
Load<float, HalfVecSize>(&cos_emb[emb_idx], &cos_emb_vec);
|
||||||
rope_3d ? emb_idx + bid * max_seq_len * HeadDim : emb_idx;
|
Load<float, HalfVecSize>(&sin_emb[emb_idx], &sin_emb_vec);
|
||||||
Load<float, HalfVecSize>(&cos_emb[new_emb_idx], &cos_emb_vec);
|
|
||||||
Load<float, HalfVecSize>(&sin_emb[new_emb_idx], &sin_emb_vec);
|
|
||||||
if (qkv_out_scales) {
|
if (qkv_out_scales) {
|
||||||
Load<float, VecSize>(&qkv_out_scales[bias_idx], &out_scale_vec);
|
Load<float, VecSize>(&qkv_out_scales[bias_idx], &out_scale_vec);
|
||||||
}
|
}
|
||||||
@@ -1013,12 +583,10 @@ __global__ void append_speculate_cache_int8_rope_kernel(
|
|||||||
T scale;
|
T scale;
|
||||||
if (head_idx < num_heads + gqa_group_size) {
|
if (head_idx < num_heads + gqa_group_size) {
|
||||||
const uint32_t emb_idx = write_seq_id * half_head_size + head_bias / 2;
|
const uint32_t emb_idx = write_seq_id * half_head_size + head_bias / 2;
|
||||||
uint32_t new_emb_idx =
|
Load<float, 1>(&cos_emb[emb_idx], &cos_emb_vec1);
|
||||||
rope_3d ? emb_idx + bid * max_seq_len * HeadDim : emb_idx;
|
Load<float, 1>(&cos_emb[emb_idx + 4], &cos_emb_vec2);
|
||||||
Load<float, 1>(&cos_emb[new_emb_idx], &cos_emb_vec1);
|
Load<float, 1>(&sin_emb[emb_idx], &sin_emb_vec1);
|
||||||
Load<float, 1>(&cos_emb[new_emb_idx + 4], &cos_emb_vec2);
|
Load<float, 1>(&sin_emb[emb_idx + 4], &sin_emb_vec2);
|
||||||
Load<float, 1>(&sin_emb[new_emb_idx], &sin_emb_vec1);
|
|
||||||
Load<float, 1>(&sin_emb[new_emb_idx + 4], &sin_emb_vec2);
|
|
||||||
scale = __ldg(&cache_k_scales[kv_head_idx]);
|
scale = __ldg(&cache_k_scales[kv_head_idx]);
|
||||||
} else {
|
} else {
|
||||||
scale = __ldg(&cache_v_scales[kv_head_idx]);
|
scale = __ldg(&cache_v_scales[kv_head_idx]);
|
||||||
@@ -1076,10 +644,8 @@ __global__ void append_speculate_cache_int8_rope_kernel(
|
|||||||
}
|
}
|
||||||
#pragma unroll
|
#pragma unroll
|
||||||
for (uint32_t i = 0; i < HALF_K_VEC_SIZE; i++) {
|
for (uint32_t i = 0; i < HALF_K_VEC_SIZE; i++) {
|
||||||
cache_vec[i] = QuantToC8<T, true, IsFP8, RoundType>(
|
cache_vec[i] = QuantToC8<T,true, IsFP8, RoundType>(scale, bias_vec1[i], max_bound, min_bound);
|
||||||
scale, bias_vec1[i], max_bound, min_bound);
|
cache_vec[i + HALF_K_VEC_SIZE] = QuantToC8<T,true, IsFP8, RoundType>(scale, bias_vec2[i], max_bound, min_bound);
|
||||||
cache_vec[i + HALF_K_VEC_SIZE] = QuantToC8<T, true, IsFP8, RoundType>(
|
|
||||||
scale, bias_vec2[i], max_bound, min_bound);
|
|
||||||
}
|
}
|
||||||
if (head_idx < num_heads + gqa_group_size) {
|
if (head_idx < num_heads + gqa_group_size) {
|
||||||
const int start_block_16 =
|
const int start_block_16 =
|
||||||
@@ -1123,7 +689,7 @@ __global__ void append_speculate_cache_int8_neox_rope_kernel(
|
|||||||
uint8_t* __restrict__ value_cache, // [num_blocks, gqa_group_size,
|
uint8_t* __restrict__ value_cache, // [num_blocks, gqa_group_size,
|
||||||
// block_size, head_size // 2]
|
// block_size, head_size // 2]
|
||||||
T* __restrict__ qkv_out,
|
T* __restrict__ qkv_out,
|
||||||
const int* __restrict__ block_tables, // [bsz, max_blocks_per_seq]
|
const int* __restrict__ block_tables, // [bsz, max_blocks_per_seq]
|
||||||
const int* __restrict__ batch_id_per_token, // [num_tokens]
|
const int* __restrict__ batch_id_per_token, // [num_tokens]
|
||||||
const int* __restrict__ cu_seqlens_q,
|
const int* __restrict__ cu_seqlens_q,
|
||||||
const int* __restrict__ seq_lens, // [bsz]
|
const int* __restrict__ seq_lens, // [bsz]
|
||||||
@@ -1142,8 +708,7 @@ __global__ void append_speculate_cache_int8_neox_rope_kernel(
|
|||||||
const int block_size,
|
const int block_size,
|
||||||
const float max_bound,
|
const float max_bound,
|
||||||
const float min_bound,
|
const float min_bound,
|
||||||
const int gqa_group_size,
|
const int gqa_group_size) {
|
||||||
const bool rope_3d) {
|
|
||||||
static_assert(HeadDim == 128, "just support HeadDim be 128 now!");
|
static_assert(HeadDim == 128, "just support HeadDim be 128 now!");
|
||||||
static_assert(VecSize == 4, "just support VecSize be 4 now, 32 * 4!");
|
static_assert(VecSize == 4, "just support VecSize be 4 now, 32 * 4!");
|
||||||
constexpr int NUM_WARPS = 4;
|
constexpr int NUM_WARPS = 4;
|
||||||
@@ -1153,7 +718,6 @@ __global__ void append_speculate_cache_int8_neox_rope_kernel(
|
|||||||
const int token_id = blockIdx.x;
|
const int token_id = blockIdx.x;
|
||||||
|
|
||||||
const int bid = batch_id_per_token[token_id];
|
const int bid = batch_id_per_token[token_id];
|
||||||
if (bid == -1) return; // NOTE(gongshaotian): For CUDAGraph padding
|
|
||||||
|
|
||||||
const int start_token_idx = cu_seqlens_q[bid];
|
const int start_token_idx = cu_seqlens_q[bid];
|
||||||
const int head_idx = blockIdx.y * NUM_WARPS + wid;
|
const int head_idx = blockIdx.y * NUM_WARPS + wid;
|
||||||
@@ -1193,10 +757,8 @@ __global__ void append_speculate_cache_int8_neox_rope_kernel(
|
|||||||
|
|
||||||
// q rope
|
// q rope
|
||||||
const uint32_t emb_idx = write_seq_id * HeadDim + head_bias;
|
const uint32_t emb_idx = write_seq_id * HeadDim + head_bias;
|
||||||
uint32_t new_emb_idx =
|
Load<float, VecSize>(&cos_emb[emb_idx], &cos_emb_vec);
|
||||||
rope_3d ? emb_idx + bid * max_seq_len * HeadDim * 2 : emb_idx;
|
Load<float, VecSize>(&sin_emb[emb_idx], &sin_emb_vec);
|
||||||
Load<float, VecSize>(&cos_emb[new_emb_idx], &cos_emb_vec);
|
|
||||||
Load<float, VecSize>(&sin_emb[new_emb_idx], &sin_emb_vec);
|
|
||||||
if (qkv_out_scales) {
|
if (qkv_out_scales) {
|
||||||
Load<float, VecSize>(&qkv_out_scales[bias_idx_left],
|
Load<float, VecSize>(&qkv_out_scales[bias_idx_left],
|
||||||
&left_out_scale_vec);
|
&left_out_scale_vec);
|
||||||
@@ -1291,12 +853,10 @@ __global__ void append_speculate_cache_int8_neox_rope_kernel(
|
|||||||
|
|
||||||
T scale;
|
T scale;
|
||||||
const uint32_t emb_idx = write_seq_id * HeadDim + head_bias;
|
const uint32_t emb_idx = write_seq_id * HeadDim + head_bias;
|
||||||
uint32_t new_emb_idx =
|
Load<float, HALF_K_VEC_SIZE>(&cos_emb[emb_idx], &cos_emb_vec1);
|
||||||
rope_3d ? emb_idx + bid * max_seq_len * HeadDim * 2 : emb_idx;
|
Load<float, HALF_K_VEC_SIZE>(&cos_emb[emb_idx + 8], &cos_emb_vec2);
|
||||||
Load<float, HALF_K_VEC_SIZE>(&cos_emb[new_emb_idx], &cos_emb_vec1);
|
Load<float, HALF_K_VEC_SIZE>(&sin_emb[emb_idx], &sin_emb_vec1);
|
||||||
Load<float, HALF_K_VEC_SIZE>(&cos_emb[new_emb_idx + 8], &cos_emb_vec2);
|
Load<float, HALF_K_VEC_SIZE>(&sin_emb[emb_idx + 8], &sin_emb_vec2);
|
||||||
Load<float, HALF_K_VEC_SIZE>(&sin_emb[new_emb_idx], &sin_emb_vec1);
|
|
||||||
Load<float, HALF_K_VEC_SIZE>(&sin_emb[new_emb_idx + 8], &sin_emb_vec2);
|
|
||||||
scale = __ldg(&cache_k_scales[kv_head_idx]);
|
scale = __ldg(&cache_k_scales[kv_head_idx]);
|
||||||
#pragma unroll
|
#pragma unroll
|
||||||
for (int i = 0; i < HALF_K_VEC_SIZE; i++) {
|
for (int i = 0; i < HALF_K_VEC_SIZE; i++) {
|
||||||
@@ -1507,7 +1067,7 @@ __global__ void append_speculate_cache_int4_rope_kernel(
|
|||||||
uint8_t* __restrict__ value_cache, // [num_blocks, gqa_group_size,
|
uint8_t* __restrict__ value_cache, // [num_blocks, gqa_group_size,
|
||||||
// block_size, head_size // 2]
|
// block_size, head_size // 2]
|
||||||
T* __restrict__ qkv_out,
|
T* __restrict__ qkv_out,
|
||||||
const int* __restrict__ block_tables, // [bsz, max_blocks_per_seq]
|
const int* __restrict__ block_tables, // [bsz, max_blocks_per_seq]
|
||||||
const int* __restrict__ batch_id_per_token, // [num_tokens]
|
const int* __restrict__ batch_id_per_token, // [num_tokens]
|
||||||
const int* __restrict__ cu_seqlens_q,
|
const int* __restrict__ cu_seqlens_q,
|
||||||
const int* __restrict__ seq_lens, // [bsz]
|
const int* __restrict__ seq_lens, // [bsz]
|
||||||
@@ -1528,8 +1088,7 @@ __global__ void append_speculate_cache_int4_rope_kernel(
|
|||||||
const int block_size,
|
const int block_size,
|
||||||
const float max_bound,
|
const float max_bound,
|
||||||
const float min_bound,
|
const float min_bound,
|
||||||
const int gqa_group_size,
|
const int gqa_group_size) {
|
||||||
const bool rope_3d) {
|
|
||||||
static_assert(HeadDim == 128, "just support HeadDim be 128 now!");
|
static_assert(HeadDim == 128, "just support HeadDim be 128 now!");
|
||||||
static_assert(VecSize == 4, "just support VecSize be 4 now, 32 * 4!");
|
static_assert(VecSize == 4, "just support VecSize be 4 now, 32 * 4!");
|
||||||
constexpr int NUM_WARPS = 4;
|
constexpr int NUM_WARPS = 4;
|
||||||
@@ -1540,7 +1099,6 @@ __global__ void append_speculate_cache_int4_rope_kernel(
|
|||||||
const int token_id = blockIdx.x;
|
const int token_id = blockIdx.x;
|
||||||
|
|
||||||
const int bid = batch_id_per_token[token_id];
|
const int bid = batch_id_per_token[token_id];
|
||||||
if (bid == -1) return; // NOTE(gongshaotian): For CUDAGraph padding
|
|
||||||
|
|
||||||
const int start_token_idx = cu_seqlens_q[bid];
|
const int start_token_idx = cu_seqlens_q[bid];
|
||||||
const int head_idx = blockIdx.y * NUM_WARPS + wid;
|
const int head_idx = blockIdx.y * NUM_WARPS + wid;
|
||||||
@@ -1587,10 +1145,8 @@ __global__ void append_speculate_cache_int4_rope_kernel(
|
|||||||
// Load<float, VecSize>(&qkv_out_scales[bias_idx], &out_scale_vec);
|
// Load<float, VecSize>(&qkv_out_scales[bias_idx], &out_scale_vec);
|
||||||
// q rope
|
// q rope
|
||||||
const uint32_t emb_idx = write_seq_id * half_head_size + head_bias / 2;
|
const uint32_t emb_idx = write_seq_id * half_head_size + head_bias / 2;
|
||||||
uint32_t new_emb_idx =
|
Load<float, HalfVecSize>(&cos_emb[emb_idx], &cos_emb_vec);
|
||||||
rope_3d ? emb_idx + bid * max_seq_len * HeadDim : emb_idx;
|
Load<float, HalfVecSize>(&sin_emb[emb_idx], &sin_emb_vec);
|
||||||
Load<float, HalfVecSize>(&cos_emb[new_emb_idx], &cos_emb_vec);
|
|
||||||
Load<float, HalfVecSize>(&sin_emb[new_emb_idx], &sin_emb_vec);
|
|
||||||
#pragma unroll
|
#pragma unroll
|
||||||
for (int i = 0; i < HalfVecSize; i++) {
|
for (int i = 0; i < HalfVecSize; i++) {
|
||||||
// dequant + add_bias + rope
|
// dequant + add_bias + rope
|
||||||
@@ -1679,12 +1235,10 @@ __global__ void append_speculate_cache_int4_rope_kernel(
|
|||||||
// &out_scale_vec2);
|
// &out_scale_vec2);
|
||||||
if (head_idx < num_heads + gqa_group_size) {
|
if (head_idx < num_heads + gqa_group_size) {
|
||||||
const uint32_t emb_idx = write_seq_id * half_head_size + head_bias / 2;
|
const uint32_t emb_idx = write_seq_id * half_head_size + head_bias / 2;
|
||||||
uint32_t new_emb_idx =
|
Load<float, 1>(&cos_emb[emb_idx], &cos_emb_vec1);
|
||||||
rope_3d ? emb_idx + bid * max_seq_len * HeadDim : emb_idx;
|
Load<float, 1>(&cos_emb[emb_idx + 4], &cos_emb_vec2);
|
||||||
Load<float, 1>(&cos_emb[new_emb_idx], &cos_emb_vec1);
|
Load<float, 1>(&sin_emb[emb_idx], &sin_emb_vec1);
|
||||||
Load<float, 1>(&cos_emb[new_emb_idx + 4], &cos_emb_vec2);
|
Load<float, 1>(&sin_emb[emb_idx + 4], &sin_emb_vec2);
|
||||||
Load<float, 1>(&sin_emb[new_emb_idx], &sin_emb_vec1);
|
|
||||||
Load<float, 1>(&sin_emb[new_emb_idx + 4], &sin_emb_vec2);
|
|
||||||
Load<T, HALF_K_VEC_SIZE>(&cache_k_scales[cache_idx], &scale_vec1);
|
Load<T, HALF_K_VEC_SIZE>(&cache_k_scales[cache_idx], &scale_vec1);
|
||||||
Load<T, HALF_K_VEC_SIZE>(&cache_k_scales[cache_idx + 8], &scale_vec2);
|
Load<T, HALF_K_VEC_SIZE>(&cache_k_scales[cache_idx + 8], &scale_vec2);
|
||||||
Load<T, HALF_K_VEC_SIZE>(&cache_k_zero_points[cache_idx], &zp_vec1);
|
Load<T, HALF_K_VEC_SIZE>(&cache_k_zero_points[cache_idx], &zp_vec1);
|
||||||
@@ -1787,6 +1341,7 @@ __global__ void append_speculate_cache_int4_rope_kernel(
|
|||||||
}
|
}
|
||||||
Store<uint8_t, K_VEC_SIZE>(cache_vec, &key_cache[tgt_cache_idx]);
|
Store<uint8_t, K_VEC_SIZE>(cache_vec, &key_cache[tgt_cache_idx]);
|
||||||
} else {
|
} else {
|
||||||
|
|
||||||
const uint32_t base_tgt_cache_idx =
|
const uint32_t base_tgt_cache_idx =
|
||||||
block_idx * gqa_group_size * HeadDim * half_block_size +
|
block_idx * gqa_group_size * HeadDim * half_block_size +
|
||||||
kv_head_idx * HeadDim * half_block_size +
|
kv_head_idx * HeadDim * half_block_size +
|
||||||
@@ -1855,7 +1410,7 @@ __global__ void append_speculate_cache_int4_neox_rope_kernel(
|
|||||||
uint8_t* __restrict__ value_cache, // [num_blocks, gqa_group_size,
|
uint8_t* __restrict__ value_cache, // [num_blocks, gqa_group_size,
|
||||||
// block_size, head_size // 2]
|
// block_size, head_size // 2]
|
||||||
T* __restrict__ qkv_out,
|
T* __restrict__ qkv_out,
|
||||||
const int* __restrict__ block_tables, // [bsz, max_blocks_per_seq]
|
const int* __restrict__ block_tables, // [bsz, max_blocks_per_seq]
|
||||||
const int* __restrict__ batch_id_per_token, // [num_tokens]
|
const int* __restrict__ batch_id_per_token, // [num_tokens]
|
||||||
const int* __restrict__ cu_seqlens_q,
|
const int* __restrict__ cu_seqlens_q,
|
||||||
const int* __restrict__ seq_lens, // [bsz]
|
const int* __restrict__ seq_lens, // [bsz]
|
||||||
@@ -1876,8 +1431,7 @@ __global__ void append_speculate_cache_int4_neox_rope_kernel(
|
|||||||
const int block_size,
|
const int block_size,
|
||||||
const float max_bound,
|
const float max_bound,
|
||||||
const float min_bound,
|
const float min_bound,
|
||||||
const int gqa_group_size,
|
const int gqa_group_size) {
|
||||||
const bool rope_3d) {
|
|
||||||
static_assert(HeadDim == 128, "just support HeadDim be 128 now!");
|
static_assert(HeadDim == 128, "just support HeadDim be 128 now!");
|
||||||
static_assert(VecSize == 4, "just support VecSize be 4 now, 32 * 4!");
|
static_assert(VecSize == 4, "just support VecSize be 4 now, 32 * 4!");
|
||||||
constexpr int NUM_WARPS = 4;
|
constexpr int NUM_WARPS = 4;
|
||||||
@@ -1888,7 +1442,6 @@ __global__ void append_speculate_cache_int4_neox_rope_kernel(
|
|||||||
const int token_id = blockIdx.x;
|
const int token_id = blockIdx.x;
|
||||||
|
|
||||||
const int bid = batch_id_per_token[token_id];
|
const int bid = batch_id_per_token[token_id];
|
||||||
if (bid == -1) return; // NOTE(gongshaotian): For CUDAGraph padding
|
|
||||||
|
|
||||||
const int start_token_idx = cu_seqlens_q[bid];
|
const int start_token_idx = cu_seqlens_q[bid];
|
||||||
const int head_idx = blockIdx.y * NUM_WARPS + wid;
|
const int head_idx = blockIdx.y * NUM_WARPS + wid;
|
||||||
@@ -2028,12 +1581,10 @@ __global__ void append_speculate_cache_int4_neox_rope_kernel(
|
|||||||
&right_out_scale_vec2);
|
&right_out_scale_vec2);
|
||||||
|
|
||||||
const uint32_t emb_idx = write_seq_id * HeadDim + head_bias;
|
const uint32_t emb_idx = write_seq_id * HeadDim + head_bias;
|
||||||
uint32_t new_emb_idx =
|
Load<float, HALF_K_VEC_SIZE>(&cos_emb[emb_idx], &cos_emb_vec1);
|
||||||
rope_3d ? emb_idx + bid * max_seq_len * HeadDim : emb_idx;
|
Load<float, HALF_K_VEC_SIZE>(&cos_emb[emb_idx + 8], &cos_emb_vec2);
|
||||||
Load<float, HALF_K_VEC_SIZE>(&cos_emb[new_emb_idx], &cos_emb_vec1);
|
Load<float, HALF_K_VEC_SIZE>(&sin_emb[emb_idx], &sin_emb_vec1);
|
||||||
Load<float, HALF_K_VEC_SIZE>(&cos_emb[new_emb_idx + 8], &cos_emb_vec2);
|
Load<float, HALF_K_VEC_SIZE>(&sin_emb[emb_idx + 8], &sin_emb_vec2);
|
||||||
Load<float, HALF_K_VEC_SIZE>(&sin_emb[new_emb_idx], &sin_emb_vec1);
|
|
||||||
Load<float, HALF_K_VEC_SIZE>(&sin_emb[new_emb_idx + 8], &sin_emb_vec2);
|
|
||||||
Load<T, HALF_K_VEC_SIZE>(&cache_k_scales[left_cache_idx],
|
Load<T, HALF_K_VEC_SIZE>(&cache_k_scales[left_cache_idx],
|
||||||
&left_scale_vec1);
|
&left_scale_vec1);
|
||||||
Load<T, HALF_K_VEC_SIZE>(&cache_k_scales[left_cache_idx + 8],
|
Load<T, HALF_K_VEC_SIZE>(&cache_k_scales[left_cache_idx + 8],
|
||||||
@@ -2067,6 +1618,7 @@ __global__ void append_speculate_cache_int4_neox_rope_kernel(
|
|||||||
right_bias_vec1[i] =
|
right_bias_vec1[i] =
|
||||||
static_cast<T>(input_right * cos_tmp + input_left * sin_tmp);
|
static_cast<T>(input_right * cos_tmp + input_left * sin_tmp);
|
||||||
|
|
||||||
|
|
||||||
input_left = static_cast<float>(left_src_vec2[i]);
|
input_left = static_cast<float>(left_src_vec2[i]);
|
||||||
input_right = static_cast<float>(right_src_vec2[i]);
|
input_right = static_cast<float>(right_src_vec2[i]);
|
||||||
cos_tmp = cos_emb_vec2[i];
|
cos_tmp = cos_emb_vec2[i];
|
||||||
|
|||||||
@@ -15,78 +15,6 @@
|
|||||||
#include "speculate_write_cache_with_rope_kernel.h"
|
#include "speculate_write_cache_with_rope_kernel.h"
|
||||||
#include "utils.cuh"
|
#include "utils.cuh"
|
||||||
|
|
||||||
template <typename T, typename QKV_TYPE>
|
|
||||||
void append_speculate_cache_rope_qk_norm(const QKV_TYPE* qkv,
|
|
||||||
T* key_cache,
|
|
||||||
T* value_cache,
|
|
||||||
T* qkv_out,
|
|
||||||
const int* block_tables,
|
|
||||||
const int* batch_id_per_token,
|
|
||||||
const int* cu_seqlens_q,
|
|
||||||
const int* seq_lens,
|
|
||||||
const int* seq_lens_encoder,
|
|
||||||
const float* cos_emb,
|
|
||||||
const float* sin_emb,
|
|
||||||
const float* qkv_out_scales,
|
|
||||||
const T* qkv_biases,
|
|
||||||
const int max_seq_len,
|
|
||||||
const int max_blocks_per_seq,
|
|
||||||
const int num_heads,
|
|
||||||
const int kv_num_heads,
|
|
||||||
const int dim_head,
|
|
||||||
const int block_size,
|
|
||||||
const int bsz,
|
|
||||||
const int token_num,
|
|
||||||
const cudaStream_t& stream,
|
|
||||||
const bool use_neox_style,
|
|
||||||
const float* q_norm_weight,
|
|
||||||
const float* k_norm_weight,
|
|
||||||
const float rms_norm_eps,
|
|
||||||
const bool rope_3d) {
|
|
||||||
int output_inner_dim = num_heads + 2 * kv_num_heads;
|
|
||||||
const uint32_t elem_nums =
|
|
||||||
use_neox_style ? token_num * (num_heads + 2 * kv_num_heads) * dim_head / 2
|
|
||||||
: token_num * (num_heads + 2 * kv_num_heads) * dim_head;
|
|
||||||
constexpr int HEAD_DIM = 128;
|
|
||||||
|
|
||||||
constexpr int PackSize = HEAD_DIM / kWarpSize;
|
|
||||||
const int pack_num = elem_nums / PackSize;
|
|
||||||
const int blocksize = 128;
|
|
||||||
int grid_size = 1;
|
|
||||||
GetNumBlocks<128>(pack_num, &grid_size);
|
|
||||||
if (use_neox_style) {
|
|
||||||
PD_THROW(
|
|
||||||
"append_speculate_cache_rope_qk_norm not support neox rope yet");
|
|
||||||
} else {
|
|
||||||
dim3 block_dim(kWarpSize, blocksize / kWarpSize, 1);
|
|
||||||
append_speculate_cache_T_rope_qk_norm_kernel<T, PackSize>
|
|
||||||
<<<grid_size, block_dim, 0, stream>>>(qkv,
|
|
||||||
key_cache,
|
|
||||||
value_cache,
|
|
||||||
qkv_out,
|
|
||||||
block_tables,
|
|
||||||
batch_id_per_token,
|
|
||||||
cu_seqlens_q,
|
|
||||||
seq_lens,
|
|
||||||
cos_emb,
|
|
||||||
sin_emb,
|
|
||||||
qkv_out_scales,
|
|
||||||
qkv_biases,
|
|
||||||
max_seq_len,
|
|
||||||
max_blocks_per_seq,
|
|
||||||
num_heads,
|
|
||||||
output_inner_dim,
|
|
||||||
dim_head,
|
|
||||||
block_size,
|
|
||||||
elem_nums,
|
|
||||||
kv_num_heads,
|
|
||||||
q_norm_weight,
|
|
||||||
k_norm_weight,
|
|
||||||
rms_norm_eps,
|
|
||||||
rope_3d);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// rope + write
|
// rope + write
|
||||||
template <typename T, typename QKV_TYPE>
|
template <typename T, typename QKV_TYPE>
|
||||||
void append_speculate_cache_rope(const QKV_TYPE* qkv,
|
void append_speculate_cache_rope(const QKV_TYPE* qkv,
|
||||||
@@ -111,8 +39,7 @@ void append_speculate_cache_rope(const QKV_TYPE* qkv,
|
|||||||
const int bsz,
|
const int bsz,
|
||||||
const int token_num,
|
const int token_num,
|
||||||
const cudaStream_t& stream,
|
const cudaStream_t& stream,
|
||||||
const bool use_neox_style,
|
const bool use_neox_style) {
|
||||||
const bool rope_3d) {
|
|
||||||
int output_inner_dim = num_heads + 2 * kv_num_heads;
|
int output_inner_dim = num_heads + 2 * kv_num_heads;
|
||||||
|
|
||||||
const uint32_t elem_nums =
|
const uint32_t elem_nums =
|
||||||
@@ -146,8 +73,7 @@ void append_speculate_cache_rope(const QKV_TYPE* qkv,
|
|||||||
dim_head,
|
dim_head,
|
||||||
block_size,
|
block_size,
|
||||||
elem_nums,
|
elem_nums,
|
||||||
kv_num_heads,
|
kv_num_heads);
|
||||||
rope_3d);
|
|
||||||
} else {
|
} else {
|
||||||
append_speculate_cache_rope_kernel<T, PackSize>
|
append_speculate_cache_rope_kernel<T, PackSize>
|
||||||
<<<grid_size, threads_per_block, 0, stream>>>(
|
<<<grid_size, threads_per_block, 0, stream>>>(
|
||||||
@@ -170,83 +96,10 @@ void append_speculate_cache_rope(const QKV_TYPE* qkv,
|
|||||||
dim_head,
|
dim_head,
|
||||||
block_size,
|
block_size,
|
||||||
elem_nums,
|
elem_nums,
|
||||||
kv_num_heads,
|
kv_num_heads);
|
||||||
rope_3d);
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
template <typename T>
|
|
||||||
void append_speculate_cache_fp8_dynamic_rope(const T* qkv,
|
|
||||||
uint8_t* key_cache,
|
|
||||||
uint8_t* value_cache,
|
|
||||||
T* qkv_out,
|
|
||||||
const int* block_tables,
|
|
||||||
const int* batch_id_per_token,
|
|
||||||
const int* cu_seqlens_q,
|
|
||||||
const int* seq_lens,
|
|
||||||
const int* seq_lens_encoder,
|
|
||||||
const float* cos_emb,
|
|
||||||
const float* sin_emb,
|
|
||||||
T* cache_k_scale,
|
|
||||||
T* cache_v_scale,
|
|
||||||
const float* q_norm_weight,
|
|
||||||
const float* k_norm_weight,
|
|
||||||
const int max_seq_len,
|
|
||||||
const int max_blocks_per_seq,
|
|
||||||
const int num_heads,
|
|
||||||
const int kv_num_heads,
|
|
||||||
const int dim_head,
|
|
||||||
const int block_size,
|
|
||||||
const int bsz,
|
|
||||||
const int token_num,
|
|
||||||
const cudaStream_t& stream,
|
|
||||||
const bool rope_3d,
|
|
||||||
const float rms_norm_eps) {
|
|
||||||
constexpr int num_warps = 4;
|
|
||||||
const int all_warps =
|
|
||||||
((num_heads + 2 * kv_num_heads) + num_warps - 1) / num_warps * num_warps;
|
|
||||||
dim3 grids(token_num, all_warps / num_warps);
|
|
||||||
|
|
||||||
append_clear_cache_int8_block<4>
|
|
||||||
<<<grids, num_warps * 32, 0, stream>>>(key_cache,
|
|
||||||
value_cache,
|
|
||||||
seq_lens,
|
|
||||||
block_tables,
|
|
||||||
batch_id_per_token,
|
|
||||||
cu_seqlens_q,
|
|
||||||
seq_lens_encoder,
|
|
||||||
max_seq_len,
|
|
||||||
max_blocks_per_seq,
|
|
||||||
num_heads,
|
|
||||||
block_size,
|
|
||||||
kv_num_heads);
|
|
||||||
append_speculate_cache_fp8_rope_qk_norm_dynamic_kernel<T, 4, 0, 128, true>
|
|
||||||
<<<grids, num_warps * 32, 0, stream>>>(qkv,
|
|
||||||
key_cache,
|
|
||||||
value_cache,
|
|
||||||
qkv_out,
|
|
||||||
block_tables,
|
|
||||||
batch_id_per_token,
|
|
||||||
cu_seqlens_q,
|
|
||||||
seq_lens,
|
|
||||||
seq_lens_encoder,
|
|
||||||
cos_emb,
|
|
||||||
sin_emb,
|
|
||||||
cache_k_scale,
|
|
||||||
cache_v_scale,
|
|
||||||
q_norm_weight,
|
|
||||||
k_norm_weight,
|
|
||||||
max_seq_len,
|
|
||||||
max_blocks_per_seq,
|
|
||||||
num_heads,
|
|
||||||
block_size,
|
|
||||||
127.0f,
|
|
||||||
-127.0f,
|
|
||||||
kv_num_heads,
|
|
||||||
rope_3d,
|
|
||||||
rms_norm_eps);
|
|
||||||
}
|
|
||||||
|
|
||||||
template <typename T, typename QKV_TYPE, bool IsFP8=false>
|
template <typename T, typename QKV_TYPE, bool IsFP8=false>
|
||||||
void append_speculate_cache_int8_rope(const QKV_TYPE* qkv,
|
void append_speculate_cache_int8_rope(const QKV_TYPE* qkv,
|
||||||
uint8_t* key_cache,
|
uint8_t* key_cache,
|
||||||
@@ -272,8 +125,7 @@ void append_speculate_cache_int8_rope(const QKV_TYPE* qkv,
|
|||||||
const int bsz,
|
const int bsz,
|
||||||
const int token_num,
|
const int token_num,
|
||||||
const cudaStream_t& stream,
|
const cudaStream_t& stream,
|
||||||
const bool use_neox_style,
|
const bool use_neox_style) {
|
||||||
const bool rope_3d) {
|
|
||||||
constexpr int num_warps = 4;
|
constexpr int num_warps = 4;
|
||||||
const int all_warps =
|
const int all_warps =
|
||||||
((num_heads + 2 * kv_num_heads) + num_warps - 1) / num_warps * num_warps;
|
((num_heads + 2 * kv_num_heads) + num_warps - 1) / num_warps * num_warps;
|
||||||
@@ -315,8 +167,7 @@ void append_speculate_cache_int8_rope(const QKV_TYPE* qkv,
|
|||||||
block_size,
|
block_size,
|
||||||
127.0f,
|
127.0f,
|
||||||
-127.0f,
|
-127.0f,
|
||||||
kv_num_heads,
|
kv_num_heads);
|
||||||
rope_3d);
|
|
||||||
} else {
|
} else {
|
||||||
append_speculate_cache_int8_rope_kernel<T, 4, 0, 128, QKV_TYPE, IsFP8>
|
append_speculate_cache_int8_rope_kernel<T, 4, 0, 128, QKV_TYPE, IsFP8>
|
||||||
<<<grids, num_warps * 32, 0, stream>>>(qkv,
|
<<<grids, num_warps * 32, 0, stream>>>(qkv,
|
||||||
@@ -340,8 +191,7 @@ void append_speculate_cache_int8_rope(const QKV_TYPE* qkv,
|
|||||||
block_size,
|
block_size,
|
||||||
127.0f,
|
127.0f,
|
||||||
-127.0f,
|
-127.0f,
|
||||||
kv_num_heads,
|
kv_num_heads);
|
||||||
rope_3d);
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -372,8 +222,7 @@ void append_speculate_cache_int4_rope(const QKV_TYPE* qkv,
|
|||||||
const int bsz,
|
const int bsz,
|
||||||
const int token_num,
|
const int token_num,
|
||||||
const cudaStream_t& stream,
|
const cudaStream_t& stream,
|
||||||
const bool use_neox_style,
|
const bool use_neox_style) {
|
||||||
const bool rope_3d) {
|
|
||||||
constexpr int num_warps = 4;
|
constexpr int num_warps = 4;
|
||||||
const int all_warps =
|
const int all_warps =
|
||||||
((num_heads + 2 * kv_num_heads) + num_warps - 1) / num_warps * num_warps;
|
((num_heads + 2 * kv_num_heads) + num_warps - 1) / num_warps * num_warps;
|
||||||
@@ -417,8 +266,7 @@ void append_speculate_cache_int4_rope(const QKV_TYPE* qkv,
|
|||||||
block_size,
|
block_size,
|
||||||
7.0f,
|
7.0f,
|
||||||
-8.0f,
|
-8.0f,
|
||||||
kv_num_heads,
|
kv_num_heads);
|
||||||
rope_3d);
|
|
||||||
} else {
|
} else {
|
||||||
append_speculate_cache_int4_rope_kernel<T, 4>
|
append_speculate_cache_int4_rope_kernel<T, 4>
|
||||||
<<<grids, num_warps * 32, 0, stream>>>(qkv,
|
<<<grids, num_warps * 32, 0, stream>>>(qkv,
|
||||||
@@ -444,8 +292,7 @@ void append_speculate_cache_int4_rope(const QKV_TYPE* qkv,
|
|||||||
block_size,
|
block_size,
|
||||||
7.0f,
|
7.0f,
|
||||||
-8.0f,
|
-8.0f,
|
||||||
kv_num_heads,
|
kv_num_heads);
|
||||||
rope_3d);
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
template <typename T, typename QKV_TYPE>
|
template <typename T, typename QKV_TYPE>
|
||||||
@@ -466,15 +313,11 @@ void SpeculateWriteCacheWithRoPEKernel(
|
|||||||
const paddle::optional<paddle::Tensor>& cache_v_zp,
|
const paddle::optional<paddle::Tensor>& cache_v_zp,
|
||||||
const std::string& cache_quant_type_str,
|
const std::string& cache_quant_type_str,
|
||||||
const bool use_neox_rotary_style,
|
const bool use_neox_rotary_style,
|
||||||
const bool rope_3d,
|
|
||||||
const int max_seq_len,
|
const int max_seq_len,
|
||||||
cudaStream_t& stream,
|
cudaStream_t& stream,
|
||||||
paddle::Tensor* qkv_out,
|
paddle::Tensor* qkv_out,
|
||||||
paddle::Tensor* key_cache_out,
|
paddle::Tensor* key_cache_out,
|
||||||
paddle::Tensor* value_cache_out,
|
paddle::Tensor* value_cache_out) {
|
||||||
const paddle::optional<paddle::Tensor>& q_norm_weight,
|
|
||||||
const paddle::optional<paddle::Tensor>& k_norm_weight,
|
|
||||||
const float rms_norm_eps) {
|
|
||||||
typedef cascade_attn_type_traits<T> traits_;
|
typedef cascade_attn_type_traits<T> traits_;
|
||||||
typedef cascade_attn_type_traits<QKV_TYPE> qkt_nv_type_;
|
typedef cascade_attn_type_traits<QKV_TYPE> qkt_nv_type_;
|
||||||
typedef typename traits_::type DataType_;
|
typedef typename traits_::type DataType_;
|
||||||
@@ -499,243 +342,142 @@ void SpeculateWriteCacheWithRoPEKernel(
|
|||||||
? rotary_embs.get().data<float>() + max_seq_len * dim_head
|
? rotary_embs.get().data<float>() + max_seq_len * dim_head
|
||||||
: rotary_embs.get().data<float>() + max_seq_len * dim_head / 2;
|
: rotary_embs.get().data<float>() + max_seq_len * dim_head / 2;
|
||||||
}
|
}
|
||||||
if (q_norm_weight && k_norm_weight) {
|
if (cache_quant_type_str == "none") {
|
||||||
if (cache_quant_type_str == "none") {
|
append_speculate_cache_rope(
|
||||||
append_speculate_cache_rope_qk_norm(
|
reinterpret_cast<const QKV_TYPE*>(qkv_ptr),
|
||||||
reinterpret_cast<const QKV_TYPE*>(qkv_ptr),
|
reinterpret_cast<DataType_*>(key_cache_out->data<T>()),
|
||||||
reinterpret_cast<DataType_*>(key_cache_out->data<T>()),
|
reinterpret_cast<DataType_*>(value_cache_out->data<T>()),
|
||||||
reinterpret_cast<DataType_*>(value_cache_out->data<T>()),
|
reinterpret_cast<DataType_*>(qkv_out->data<T>()),
|
||||||
reinterpret_cast<DataType_*>(qkv_out->data<T>()),
|
block_tables.data<int>(),
|
||||||
block_tables.data<int>(),
|
batch_id_per_token.data<int>(),
|
||||||
batch_id_per_token.data<int>(),
|
cu_seqlens_q.data<int>(),
|
||||||
cu_seqlens_q.data<int>(),
|
seq_lens.data<int>(),
|
||||||
seq_lens.data<int>(),
|
seq_lens_encoder.data<int>(),
|
||||||
seq_lens_encoder.data<int>(),
|
cos_emb,
|
||||||
cos_emb,
|
sin_emb,
|
||||||
sin_emb,
|
qkv_out_scales ? qkv_out_scales.get().data<float>() : nullptr,
|
||||||
qkv_out_scales ? qkv_out_scales.get().data<float>() : nullptr,
|
qkv_biases ? reinterpret_cast<DataType_*>(
|
||||||
qkv_biases ? reinterpret_cast<DataType_*>(
|
const_cast<T*>(qkv_biases.get().data<T>()))
|
||||||
const_cast<T*>(qkv_biases.get().data<T>()))
|
: nullptr,
|
||||||
: nullptr,
|
max_seq_len,
|
||||||
max_seq_len,
|
max_blocks_per_seq,
|
||||||
max_blocks_per_seq,
|
num_heads,
|
||||||
num_heads,
|
kv_num_heads,
|
||||||
kv_num_heads,
|
dim_head,
|
||||||
dim_head,
|
block_size,
|
||||||
block_size,
|
bsz,
|
||||||
bsz,
|
token_nums,
|
||||||
token_nums,
|
stream,
|
||||||
stream,
|
use_neox_rotary_style);
|
||||||
use_neox_rotary_style,
|
} else if (cache_quant_type_str == "cache_int8") {
|
||||||
reinterpret_cast<const float*>(q_norm_weight.get().data<float>()),
|
append_speculate_cache_int8_rope(
|
||||||
reinterpret_cast<const float*>(k_norm_weight.get().data<float>()),
|
reinterpret_cast<const QKV_TYPE*>(qkv_ptr),
|
||||||
rms_norm_eps,
|
key_cache_out->data<uint8_t>(),
|
||||||
rope_3d);
|
value_cache_out->data<uint8_t>(),
|
||||||
} else if (cache_quant_type_str == "block_wise_fp8") {
|
reinterpret_cast<DataType_*>(qkv_out->data<T>()),
|
||||||
append_speculate_cache_fp8_dynamic_rope(
|
block_tables.data<int>(),
|
||||||
reinterpret_cast<const DataType_*>(qkv_ptr),
|
batch_id_per_token.data<int>(),
|
||||||
key_cache_out->data<uint8_t>(),
|
cu_seqlens_q.data<int>(),
|
||||||
value_cache_out->data<uint8_t>(),
|
seq_lens.data<int>(),
|
||||||
reinterpret_cast<DataType_*>(qkv_out->data<T>()),
|
seq_lens_encoder.data<int>(),
|
||||||
block_tables.data<int>(),
|
cos_emb,
|
||||||
batch_id_per_token.data<int>(),
|
sin_emb,
|
||||||
cu_seqlens_q.data<int>(),
|
qkv_out_scales ? qkv_out_scales.get().data<float>() : nullptr,
|
||||||
seq_lens.data<int>(),
|
qkv_biases ? reinterpret_cast<DataType_*>(
|
||||||
seq_lens_encoder.data<int>(),
|
const_cast<T*>(qkv_biases.get().data<T>()))
|
||||||
cos_emb,
|
: nullptr,
|
||||||
sin_emb,
|
cache_k_scale ? reinterpret_cast<DataType_*>(
|
||||||
const_cast<DataType_*>(reinterpret_cast<const DataType_*>(cache_k_scale.get().data<T>())),
|
const_cast<T*>(cache_k_scale.get().data<T>()))
|
||||||
const_cast<DataType_*>(reinterpret_cast<const DataType_*>(cache_v_scale.get().data<T>())),
|
: nullptr,
|
||||||
q_norm_weight.get().data<float>(),
|
cache_v_scale ? reinterpret_cast<DataType_*>(
|
||||||
k_norm_weight.get().data<float>(),
|
const_cast<T*>(cache_v_scale.get().data<T>()))
|
||||||
max_seq_len,
|
: nullptr,
|
||||||
max_blocks_per_seq,
|
max_seq_len,
|
||||||
num_heads,
|
max_blocks_per_seq,
|
||||||
kv_num_heads,
|
num_heads,
|
||||||
dim_head,
|
kv_num_heads,
|
||||||
block_size,
|
dim_head,
|
||||||
bsz,
|
block_size,
|
||||||
token_nums,
|
bsz,
|
||||||
stream,
|
token_nums,
|
||||||
rope_3d,
|
stream,
|
||||||
rms_norm_eps
|
use_neox_rotary_style);
|
||||||
);
|
} else if (cache_quant_type_str == "cache_fp8") {
|
||||||
} else {
|
append_speculate_cache_int8_rope<DataType_, QKV_TYPE, true>(
|
||||||
PD_THROW(
|
reinterpret_cast<const QKV_TYPE*>(qkv_ptr),
|
||||||
"append_decode_cache_rope_qk_norm not support cachekv quant yet");
|
key_cache_out->data<uint8_t>(),
|
||||||
}
|
value_cache_out->data<uint8_t>(),
|
||||||
|
reinterpret_cast<DataType_*>(qkv_out->data<T>()),
|
||||||
|
block_tables.data<int>(),
|
||||||
|
batch_id_per_token.data<int>(),
|
||||||
|
cu_seqlens_q.data<int>(),
|
||||||
|
seq_lens.data<int>(),
|
||||||
|
seq_lens_encoder.data<int>(),
|
||||||
|
cos_emb,
|
||||||
|
sin_emb,
|
||||||
|
qkv_out_scales ? qkv_out_scales.get().data<float>() : nullptr,
|
||||||
|
qkv_biases ? reinterpret_cast<DataType_*>(
|
||||||
|
const_cast<T*>(qkv_biases.get().data<T>()))
|
||||||
|
: nullptr,
|
||||||
|
cache_k_scale ? reinterpret_cast<DataType_*>(
|
||||||
|
const_cast<T*>(cache_k_scale.get().data<T>()))
|
||||||
|
: nullptr,
|
||||||
|
cache_v_scale ? reinterpret_cast<DataType_*>(
|
||||||
|
const_cast<T*>(cache_v_scale.get().data<T>()))
|
||||||
|
: nullptr,
|
||||||
|
max_seq_len,
|
||||||
|
max_blocks_per_seq,
|
||||||
|
num_heads,
|
||||||
|
kv_num_heads,
|
||||||
|
dim_head,
|
||||||
|
block_size,
|
||||||
|
bsz,
|
||||||
|
token_nums,
|
||||||
|
stream,
|
||||||
|
use_neox_rotary_style);
|
||||||
|
} else if (cache_quant_type_str == "cache_int4_zp") {
|
||||||
|
append_speculate_cache_int4_rope(
|
||||||
|
reinterpret_cast<const QKV_TYPE*>(qkv_ptr),
|
||||||
|
key_cache_out->data<uint8_t>(),
|
||||||
|
value_cache_out->data<uint8_t>(),
|
||||||
|
reinterpret_cast<DataType_*>(const_cast<T*>(qkv_out->data<T>())),
|
||||||
|
block_tables.data<int>(),
|
||||||
|
batch_id_per_token.data<int>(),
|
||||||
|
cu_seqlens_q.data<int>(),
|
||||||
|
seq_lens.data<int>(),
|
||||||
|
seq_lens_encoder.data<int>(),
|
||||||
|
cos_emb,
|
||||||
|
sin_emb,
|
||||||
|
qkv_out_scales ? qkv_out_scales.get().data<float>() : nullptr,
|
||||||
|
qkv_biases ? reinterpret_cast<DataType_*>(
|
||||||
|
const_cast<T*>(qkv_biases.get().data<T>()))
|
||||||
|
: nullptr,
|
||||||
|
cache_k_scale ? reinterpret_cast<DataType_*>(
|
||||||
|
const_cast<T*>(cache_k_scale.get().data<T>()))
|
||||||
|
: nullptr,
|
||||||
|
cache_v_scale ? reinterpret_cast<DataType_*>(
|
||||||
|
const_cast<T*>(cache_v_scale.get().data<T>()))
|
||||||
|
: nullptr,
|
||||||
|
cache_k_zp ? reinterpret_cast<DataType_*>(
|
||||||
|
const_cast<T*>(cache_k_zp.get().data<T>()))
|
||||||
|
: nullptr,
|
||||||
|
cache_v_zp ? reinterpret_cast<DataType_*>(
|
||||||
|
const_cast<T*>(cache_v_zp.get().data<T>()))
|
||||||
|
: nullptr,
|
||||||
|
max_seq_len,
|
||||||
|
max_blocks_per_seq,
|
||||||
|
num_heads,
|
||||||
|
kv_num_heads,
|
||||||
|
dim_head,
|
||||||
|
block_size,
|
||||||
|
bsz,
|
||||||
|
token_nums,
|
||||||
|
stream,
|
||||||
|
use_neox_rotary_style);
|
||||||
} else {
|
} else {
|
||||||
if (cache_quant_type_str == "none") {
|
PD_THROW(
|
||||||
append_speculate_cache_rope(
|
"cache_quant_type_str should be one of [none, cache_int8, "
|
||||||
reinterpret_cast<const QKV_TYPE*>(qkv_ptr),
|
"cache_int4_zp]");
|
||||||
reinterpret_cast<DataType_*>(key_cache_out->data<T>()),
|
|
||||||
reinterpret_cast<DataType_*>(value_cache_out->data<T>()),
|
|
||||||
reinterpret_cast<DataType_*>(qkv_out->data<T>()),
|
|
||||||
block_tables.data<int>(),
|
|
||||||
batch_id_per_token.data<int>(),
|
|
||||||
cu_seqlens_q.data<int>(),
|
|
||||||
seq_lens.data<int>(),
|
|
||||||
seq_lens_encoder.data<int>(),
|
|
||||||
cos_emb,
|
|
||||||
sin_emb,
|
|
||||||
qkv_out_scales ? qkv_out_scales.get().data<float>() : nullptr,
|
|
||||||
qkv_biases ? reinterpret_cast<DataType_*>(
|
|
||||||
const_cast<T*>(qkv_biases.get().data<T>()))
|
|
||||||
: nullptr,
|
|
||||||
max_seq_len,
|
|
||||||
max_blocks_per_seq,
|
|
||||||
num_heads,
|
|
||||||
kv_num_heads,
|
|
||||||
dim_head,
|
|
||||||
block_size,
|
|
||||||
bsz,
|
|
||||||
token_nums,
|
|
||||||
stream,
|
|
||||||
use_neox_rotary_style,
|
|
||||||
rope_3d);
|
|
||||||
} else if (cache_quant_type_str == "cache_int8") {
|
|
||||||
append_speculate_cache_int8_rope(
|
|
||||||
reinterpret_cast<const QKV_TYPE*>(qkv_ptr),
|
|
||||||
key_cache_out->data<uint8_t>(),
|
|
||||||
value_cache_out->data<uint8_t>(),
|
|
||||||
reinterpret_cast<DataType_*>(qkv_out->data<T>()),
|
|
||||||
block_tables.data<int>(),
|
|
||||||
batch_id_per_token.data<int>(),
|
|
||||||
cu_seqlens_q.data<int>(),
|
|
||||||
seq_lens.data<int>(),
|
|
||||||
seq_lens_encoder.data<int>(),
|
|
||||||
cos_emb,
|
|
||||||
sin_emb,
|
|
||||||
qkv_out_scales ? qkv_out_scales.get().data<float>() : nullptr,
|
|
||||||
qkv_biases ? reinterpret_cast<DataType_*>(
|
|
||||||
const_cast<T*>(qkv_biases.get().data<T>()))
|
|
||||||
: nullptr,
|
|
||||||
cache_k_scale ? reinterpret_cast<DataType_*>(
|
|
||||||
const_cast<T*>(cache_k_scale.get().data<T>()))
|
|
||||||
: nullptr,
|
|
||||||
cache_v_scale ? reinterpret_cast<DataType_*>(
|
|
||||||
const_cast<T*>(cache_v_scale.get().data<T>()))
|
|
||||||
: nullptr,
|
|
||||||
max_seq_len,
|
|
||||||
max_blocks_per_seq,
|
|
||||||
num_heads,
|
|
||||||
kv_num_heads,
|
|
||||||
dim_head,
|
|
||||||
block_size,
|
|
||||||
bsz,
|
|
||||||
token_nums,
|
|
||||||
stream,
|
|
||||||
use_neox_rotary_style,
|
|
||||||
rope_3d);
|
|
||||||
} else if (cache_quant_type_str == "cache_fp8") {
|
|
||||||
append_speculate_cache_int8_rope<DataType_, QKV_TYPE, true>(
|
|
||||||
reinterpret_cast<const QKV_TYPE*>(qkv_ptr),
|
|
||||||
key_cache_out->data<uint8_t>(),
|
|
||||||
value_cache_out->data<uint8_t>(),
|
|
||||||
reinterpret_cast<DataType_*>(qkv_out->data<T>()),
|
|
||||||
block_tables.data<int>(),
|
|
||||||
batch_id_per_token.data<int>(),
|
|
||||||
cu_seqlens_q.data<int>(),
|
|
||||||
seq_lens.data<int>(),
|
|
||||||
seq_lens_encoder.data<int>(),
|
|
||||||
cos_emb,
|
|
||||||
sin_emb,
|
|
||||||
qkv_out_scales ? qkv_out_scales.get().data<float>() : nullptr,
|
|
||||||
qkv_biases ? reinterpret_cast<DataType_*>(
|
|
||||||
const_cast<T*>(qkv_biases.get().data<T>()))
|
|
||||||
: nullptr,
|
|
||||||
cache_k_scale ? reinterpret_cast<DataType_*>(
|
|
||||||
const_cast<T*>(cache_k_scale.get().data<T>()))
|
|
||||||
: nullptr,
|
|
||||||
cache_v_scale ? reinterpret_cast<DataType_*>(
|
|
||||||
const_cast<T*>(cache_v_scale.get().data<T>()))
|
|
||||||
: nullptr,
|
|
||||||
max_seq_len,
|
|
||||||
max_blocks_per_seq,
|
|
||||||
num_heads,
|
|
||||||
kv_num_heads,
|
|
||||||
dim_head,
|
|
||||||
block_size,
|
|
||||||
bsz,
|
|
||||||
token_nums,
|
|
||||||
stream,
|
|
||||||
use_neox_rotary_style,
|
|
||||||
rope_3d);
|
|
||||||
} else if (cache_quant_type_str == "block_wise_fp8") {
|
|
||||||
append_speculate_cache_fp8_dynamic_rope(
|
|
||||||
reinterpret_cast<const DataType_*>(qkv_ptr),
|
|
||||||
key_cache_out->data<uint8_t>(),
|
|
||||||
value_cache_out->data<uint8_t>(),
|
|
||||||
reinterpret_cast<DataType_*>(qkv_out->data<T>()),
|
|
||||||
block_tables.data<int>(),
|
|
||||||
batch_id_per_token.data<int>(),
|
|
||||||
cu_seqlens_q.data<int>(),
|
|
||||||
seq_lens.data<int>(),
|
|
||||||
seq_lens_encoder.data<int>(),
|
|
||||||
cos_emb,
|
|
||||||
sin_emb,
|
|
||||||
const_cast<DataType_*>(reinterpret_cast<const DataType_*>(cache_k_scale.get().data<T>())),
|
|
||||||
const_cast<DataType_*>(reinterpret_cast<const DataType_*>(cache_v_scale.get().data<T>())),
|
|
||||||
nullptr, // q_norm_weight
|
|
||||||
nullptr, // k_norm_weight
|
|
||||||
max_seq_len,
|
|
||||||
max_blocks_per_seq,
|
|
||||||
num_heads,
|
|
||||||
kv_num_heads,
|
|
||||||
dim_head,
|
|
||||||
block_size,
|
|
||||||
bsz,
|
|
||||||
token_nums,
|
|
||||||
stream,
|
|
||||||
rope_3d,
|
|
||||||
rms_norm_eps
|
|
||||||
);
|
|
||||||
} else if (cache_quant_type_str == "cache_int4_zp") {
|
|
||||||
append_speculate_cache_int4_rope(
|
|
||||||
reinterpret_cast<const QKV_TYPE*>(qkv_ptr),
|
|
||||||
key_cache_out->data<uint8_t>(),
|
|
||||||
value_cache_out->data<uint8_t>(),
|
|
||||||
reinterpret_cast<DataType_*>(const_cast<T*>(qkv_out->data<T>())),
|
|
||||||
block_tables.data<int>(),
|
|
||||||
batch_id_per_token.data<int>(),
|
|
||||||
cu_seqlens_q.data<int>(),
|
|
||||||
seq_lens.data<int>(),
|
|
||||||
seq_lens_encoder.data<int>(),
|
|
||||||
cos_emb,
|
|
||||||
sin_emb,
|
|
||||||
qkv_out_scales ? qkv_out_scales.get().data<float>() : nullptr,
|
|
||||||
qkv_biases ? reinterpret_cast<DataType_*>(
|
|
||||||
const_cast<T*>(qkv_biases.get().data<T>()))
|
|
||||||
: nullptr,
|
|
||||||
cache_k_scale ? reinterpret_cast<DataType_*>(
|
|
||||||
const_cast<T*>(cache_k_scale.get().data<T>()))
|
|
||||||
: nullptr,
|
|
||||||
cache_v_scale ? reinterpret_cast<DataType_*>(
|
|
||||||
const_cast<T*>(cache_v_scale.get().data<T>()))
|
|
||||||
: nullptr,
|
|
||||||
cache_k_zp ? reinterpret_cast<DataType_*>(
|
|
||||||
const_cast<T*>(cache_k_zp.get().data<T>()))
|
|
||||||
: nullptr,
|
|
||||||
cache_v_zp ? reinterpret_cast<DataType_*>(
|
|
||||||
const_cast<T*>(cache_v_zp.get().data<T>()))
|
|
||||||
: nullptr,
|
|
||||||
max_seq_len,
|
|
||||||
max_blocks_per_seq,
|
|
||||||
num_heads,
|
|
||||||
kv_num_heads,
|
|
||||||
dim_head,
|
|
||||||
block_size,
|
|
||||||
bsz,
|
|
||||||
token_nums,
|
|
||||||
stream,
|
|
||||||
use_neox_rotary_style,
|
|
||||||
rope_3d);
|
|
||||||
} else {
|
|
||||||
PD_THROW(
|
|
||||||
"cache_quant_type_str should be one of [none, cache_int8, "
|
|
||||||
"cache_int4_zp]");
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -758,15 +500,11 @@ template void SpeculateWriteCacheWithRoPEKernel<paddle::bfloat16, int>(
|
|||||||
const paddle::optional<paddle::Tensor>& cache_v_zp,
|
const paddle::optional<paddle::Tensor>& cache_v_zp,
|
||||||
const std::string& cache_quant_type_str,
|
const std::string& cache_quant_type_str,
|
||||||
const bool use_neox_rotary_style,
|
const bool use_neox_rotary_style,
|
||||||
const bool rope_3d,
|
|
||||||
const int max_seq_len,
|
const int max_seq_len,
|
||||||
cudaStream_t& stream,
|
cudaStream_t& stream,
|
||||||
paddle::Tensor* qkv_out,
|
paddle::Tensor* qkv_out,
|
||||||
paddle::Tensor* key_cache_out,
|
paddle::Tensor* key_cache_out,
|
||||||
paddle::Tensor* value_cache_out,
|
paddle::Tensor* value_cache_out);
|
||||||
const paddle::optional<paddle::Tensor>& q_norm_weight,
|
|
||||||
const paddle::optional<paddle::Tensor>& k_norm_weight,
|
|
||||||
const float rms_norm_eps);
|
|
||||||
|
|
||||||
template void
|
template void
|
||||||
SpeculateWriteCacheWithRoPEKernel<paddle::bfloat16, paddle::bfloat16>(
|
SpeculateWriteCacheWithRoPEKernel<paddle::bfloat16, paddle::bfloat16>(
|
||||||
@@ -788,15 +526,11 @@ SpeculateWriteCacheWithRoPEKernel<paddle::bfloat16, paddle::bfloat16>(
|
|||||||
const paddle::optional<paddle::Tensor>& cache_v_zp,
|
const paddle::optional<paddle::Tensor>& cache_v_zp,
|
||||||
const std::string& cache_quant_type_str,
|
const std::string& cache_quant_type_str,
|
||||||
const bool use_neox_rotary_style,
|
const bool use_neox_rotary_style,
|
||||||
const bool rope_3d,
|
|
||||||
const int max_seq_len,
|
const int max_seq_len,
|
||||||
cudaStream_t& stream,
|
cudaStream_t& stream,
|
||||||
paddle::Tensor* qkv_out,
|
paddle::Tensor* qkv_out,
|
||||||
paddle::Tensor* key_cache_out,
|
paddle::Tensor* key_cache_out,
|
||||||
paddle::Tensor* value_cache_out,
|
paddle::Tensor* value_cache_out);
|
||||||
const paddle::optional<paddle::Tensor>& q_norm_weight,
|
|
||||||
const paddle::optional<paddle::Tensor>& k_norm_weight,
|
|
||||||
const float rms_norm_eps);
|
|
||||||
|
|
||||||
template void SpeculateWriteCacheWithRoPEKernel<paddle::float16, int>(
|
template void SpeculateWriteCacheWithRoPEKernel<paddle::float16, int>(
|
||||||
const AppendAttnMetaData& meta_data,
|
const AppendAttnMetaData& meta_data,
|
||||||
@@ -817,15 +551,11 @@ template void SpeculateWriteCacheWithRoPEKernel<paddle::float16, int>(
|
|||||||
const paddle::optional<paddle::Tensor>& cache_v_zp,
|
const paddle::optional<paddle::Tensor>& cache_v_zp,
|
||||||
const std::string& cache_quant_type_str,
|
const std::string& cache_quant_type_str,
|
||||||
const bool use_neox_rotary_style,
|
const bool use_neox_rotary_style,
|
||||||
const bool rope_3d,
|
|
||||||
const int max_seq_len,
|
const int max_seq_len,
|
||||||
cudaStream_t& stream,
|
cudaStream_t& stream,
|
||||||
paddle::Tensor* qkv_out,
|
paddle::Tensor* qkv_out,
|
||||||
paddle::Tensor* key_cache_out,
|
paddle::Tensor* key_cache_out,
|
||||||
paddle::Tensor* value_cache_out,
|
paddle::Tensor* value_cache_out);
|
||||||
const paddle::optional<paddle::Tensor>& q_norm_weight,
|
|
||||||
const paddle::optional<paddle::Tensor>& k_norm_weight,
|
|
||||||
const float rms_norm_eps);
|
|
||||||
|
|
||||||
|
|
||||||
template void
|
template void
|
||||||
@@ -848,12 +578,8 @@ SpeculateWriteCacheWithRoPEKernel<paddle::float16, paddle::float16>(
|
|||||||
const paddle::optional<paddle::Tensor>& cache_v_zp,
|
const paddle::optional<paddle::Tensor>& cache_v_zp,
|
||||||
const std::string& cache_quant_type_str,
|
const std::string& cache_quant_type_str,
|
||||||
const bool use_neox_rotary_style,
|
const bool use_neox_rotary_style,
|
||||||
const bool rope_3d,
|
|
||||||
const int max_seq_len,
|
const int max_seq_len,
|
||||||
cudaStream_t& stream,
|
cudaStream_t& stream,
|
||||||
paddle::Tensor* qkv_out,
|
paddle::Tensor* qkv_out,
|
||||||
paddle::Tensor* key_cache_out,
|
paddle::Tensor* key_cache_out,
|
||||||
paddle::Tensor* value_cache_out,
|
paddle::Tensor* value_cache_out);
|
||||||
const paddle::optional<paddle::Tensor>& q_norm_weight,
|
|
||||||
const paddle::optional<paddle::Tensor>& k_norm_weight,
|
|
||||||
const float rms_norm_eps);
|
|
||||||
|
|||||||
@@ -35,12 +35,8 @@ void SpeculateWriteCacheWithRoPEKernel(
|
|||||||
const paddle::optional<paddle::Tensor>& cache_v_zp,
|
const paddle::optional<paddle::Tensor>& cache_v_zp,
|
||||||
const std::string& cache_quant_type_str,
|
const std::string& cache_quant_type_str,
|
||||||
const bool use_neox_rotary_style,
|
const bool use_neox_rotary_style,
|
||||||
const bool rope_3d,
|
|
||||||
const int max_seq_len,
|
const int max_seq_len,
|
||||||
cudaStream_t& stream,
|
cudaStream_t& stream,
|
||||||
paddle::Tensor* qkv_out,
|
paddle::Tensor* qkv_out,
|
||||||
paddle::Tensor* key_cache_out,
|
paddle::Tensor* key_cache_out,
|
||||||
paddle::Tensor* value_cache_out,
|
paddle::Tensor* value_cache_out);
|
||||||
const paddle::optional<paddle::Tensor>& q_norm_weight,
|
|
||||||
const paddle::optional<paddle::Tensor>& k_norm_weight,
|
|
||||||
const float rms_norm_eps);
|
|
||||||
|
|||||||
@@ -1,144 +0,0 @@
|
|||||||
{
|
|
||||||
"multiquery_attention_c8": {
|
|
||||||
"name": "multiquery_attention_c8",
|
|
||||||
"function_name": "MultiQueryAppendC8Attention",
|
|
||||||
"impl_file": "multiquery_attention_c8_impl.cuh",
|
|
||||||
"template_params": [
|
|
||||||
"T",
|
|
||||||
"GROUP_SIZE",
|
|
||||||
"HEAD_DIM",
|
|
||||||
"BLOCK_SIZE",
|
|
||||||
"CAUSAL",
|
|
||||||
"BLOCK_SHAPE_Q",
|
|
||||||
"NUM_WARP_Q",
|
|
||||||
"OutT",
|
|
||||||
"ENABLE_PREFILL",
|
|
||||||
"IsFP8",
|
|
||||||
"IsDynamicC8"
|
|
||||||
],
|
|
||||||
"dispatch_params": {
|
|
||||||
"GROUP_SIZE": [1, 2, 4, 5, 6, 7, 8, 12, 14, 16],
|
|
||||||
"HEAD_DIM": [128],
|
|
||||||
"BLOCK_SIZE": [64],
|
|
||||||
"CAUSAL": [0, 1],
|
|
||||||
"BLOCK_SHAPE_Q": [16, 32, 64, 128],
|
|
||||||
"ENABLE_PREFILL": [0, 1],
|
|
||||||
"IsFP8": [0, 1],
|
|
||||||
"IsDynamicC8": [0, 1]
|
|
||||||
},
|
|
||||||
"data_types": [
|
|
||||||
["paddle::float16", "paddle::float16", "float16_float16"],
|
|
||||||
["paddle::float16", "paddle::float8_e4m3fn", "float16_fp8"],
|
|
||||||
["paddle::float16", "int8_t", "float16_int8"],
|
|
||||||
["paddle::bfloat16", "paddle::bfloat16", "bfloat16_bfloat16"],
|
|
||||||
["paddle::bfloat16", "paddle::float8_e4m3fn", "bfloat16_fp8"],
|
|
||||||
["paddle::bfloat16", "int8_t", "bfloat16_int8"]
|
|
||||||
],
|
|
||||||
"max_instances_per_file": 80,
|
|
||||||
"file_prefix": "multiquery_attention_c8_",
|
|
||||||
"function_signature": "template void {function_name}{template_args}(\n const AppendAttnMetaData &meta_data,\n const paddle::Tensor &qkv,\n const paddle::Tensor &cache_k,\n const paddle::Tensor &cache_v,\n const paddle::optional<paddle::Tensor> &attn_mask,\n const paddle::Tensor &cache_k_scale,\n const paddle::Tensor &cache_v_scale,\n const paddle::optional<paddle::Tensor> &shift_bias,\n const paddle::optional<paddle::Tensor> &smooth_weight,\n const paddle::optional<paddle::Tensor> &sinks,\n const paddle::Tensor &seq_lens_q,\n const paddle::Tensor &seq_lens_kv,\n const paddle::Tensor &seq_lens_encoder,\n const paddle::Tensor &batch_id_per_token,\n const paddle::Tensor &cu_seqlens_q,\n const paddle::Tensor &block_table,\n const paddle::Tensor &batch_ids,\n const paddle::Tensor &tile_ids_per_batch,\n const int num_blocks_x_cpu,\n const int max_seq_len,\n const int max_dec_len,\n const float quant_max_bound,\n const float quant_min_bound,\n const float in_scale,\n const int max_partition_size,\n const int encoder_max_partition_size,\n const int speculate_max_draft_token_num,\n const bool is_decoder,\n cudaStream_t &stream,\n paddle::Tensor *out,\n const int sliding_window);\n\n"
|
|
||||||
},
|
|
||||||
"multiquery_attention_c4": {
|
|
||||||
"name": "multiquery_attention_c4",
|
|
||||||
"function_name": "MultiQueryAppendC4Attention",
|
|
||||||
"impl_file": "multiquery_attention_c4_impl.cuh",
|
|
||||||
"template_params": [
|
|
||||||
"T",
|
|
||||||
"GROUP_SIZE",
|
|
||||||
"HEAD_DIM",
|
|
||||||
"BLOCK_SIZE",
|
|
||||||
"CAUSAL",
|
|
||||||
"BLOCK_SHAPE_Q",
|
|
||||||
"NUM_WARP_Q",
|
|
||||||
"OutT",
|
|
||||||
"ENABLE_PREFILL"
|
|
||||||
],
|
|
||||||
"dispatch_params": {
|
|
||||||
"GROUP_SIZE": [1, 2, 4, 5, 6, 7, 8, 12, 14, 16],
|
|
||||||
"HEAD_DIM": [128],
|
|
||||||
"BLOCK_SIZE": [64],
|
|
||||||
"CAUSAL": [0, 1],
|
|
||||||
"BLOCK_SHAPE_Q": [16, 32, 64, 128],
|
|
||||||
"ENABLE_PREFILL": [0, 1]
|
|
||||||
},
|
|
||||||
"data_types": [
|
|
||||||
["paddle::float16", "paddle::float16", "float16_float16"],
|
|
||||||
["paddle::float16", "paddle::float8_e4m3fn", "float16_fp8"],
|
|
||||||
["paddle::float16", "int8_t", "float16_int8"],
|
|
||||||
["paddle::bfloat16", "paddle::bfloat16", "bfloat16_bfloat16"],
|
|
||||||
["paddle::bfloat16", "paddle::float8_e4m3fn", "bfloat16_fp8"],
|
|
||||||
["paddle::bfloat16", "int8_t", "bfloat16_int8"]
|
|
||||||
],
|
|
||||||
"max_instances_per_file": 160,
|
|
||||||
"file_prefix": "multiquery_attention_c4_",
|
|
||||||
"function_signature": "template void {function_name}{template_args}(\n const AppendAttnMetaData &meta_data,\n const paddle::Tensor &qkv,\n const paddle::Tensor &cache_k,\n const paddle::Tensor &cache_v,\n const paddle::optional<paddle::Tensor> &attn_mask,\n const paddle::Tensor &cache_k_scale,\n const paddle::Tensor &cache_v_scale,\n const paddle::optional<paddle::Tensor> &cache_k_zp,\n const paddle::optional<paddle::Tensor> &cache_v_zp,\n const paddle::optional<paddle::Tensor> &shift_bias,\n const paddle::optional<paddle::Tensor> &smooth_weight,\n const paddle::optional<paddle::Tensor> &sinks,\n const paddle::Tensor &seq_lens_q,\n const paddle::Tensor &seq_lens_kv,\n const paddle::Tensor &seq_lens_encoder,\n const paddle::Tensor &batch_id_per_token,\n const paddle::Tensor &cu_seqlens_q,\n const paddle::Tensor &block_table,\n const paddle::Tensor &batch_ids,\n const paddle::Tensor &tile_ids_per_batch,\n const int num_blocks_x_cpu,\n const int max_seq_len,\n const int max_dec_len,\n const float quant_max_bound,\n const float quant_min_bound,\n const float in_scale,\n const int max_partition_size,\n const int encoder_max_partition_size,\n const int speculate_max_draft_token_num,\n const bool is_decoder,\n cudaStream_t &stream,\n paddle::Tensor *out,\n const int sliding_window);\n\n"
|
|
||||||
},
|
|
||||||
"multiquery_attention_c16": {
|
|
||||||
"name": "multiquery_attention_c16",
|
|
||||||
"function_name": "MultiQueryAppendAttention",
|
|
||||||
"impl_file": "multiquery_attention_c16_impl.cuh",
|
|
||||||
"template_params": [
|
|
||||||
"T",
|
|
||||||
"GROUP_SIZE",
|
|
||||||
"HEAD_DIM",
|
|
||||||
"BLOCK_SIZE",
|
|
||||||
"CAUSAL",
|
|
||||||
"BLOCK_SHAPE_Q",
|
|
||||||
"NUM_WARP_Q",
|
|
||||||
"OutT",
|
|
||||||
"ENABLE_PREFILL"
|
|
||||||
],
|
|
||||||
"dispatch_params": {
|
|
||||||
"GROUP_SIZE": [1, 2, 4, 5, 6, 7, 8, 12, 14, 16],
|
|
||||||
"HEAD_DIM": [64,128],
|
|
||||||
"BLOCK_SIZE": [64],
|
|
||||||
"CAUSAL": [0, 1],
|
|
||||||
"BLOCK_SHAPE_Q": [16, 32, 64, 128],
|
|
||||||
"ENABLE_PREFILL": [0, 1]
|
|
||||||
},
|
|
||||||
"data_types": [
|
|
||||||
["paddle::float16", "paddle::float16", "float16_float16"],
|
|
||||||
["paddle::float16", "paddle::float8_e4m3fn", "float16_fp8"],
|
|
||||||
["paddle::float16", "int8_t", "float16_int8"],
|
|
||||||
["paddle::bfloat16", "paddle::bfloat16", "bfloat16_bfloat16"],
|
|
||||||
["paddle::bfloat16", "paddle::float8_e4m3fn", "bfloat16_fp8"],
|
|
||||||
["paddle::bfloat16", "int8_t", "bfloat16_int8"]
|
|
||||||
],
|
|
||||||
"max_instances_per_file": 160,
|
|
||||||
"file_prefix": "multiquery_attention_c16_",
|
|
||||||
"function_signature": "template void {function_name}{template_args}(\n const AppendAttnMetaData &meta_data,\n const paddle::Tensor &qkv,\n const paddle::Tensor &cache_k,\n const paddle::Tensor &cache_v,\n const paddle::optional<paddle::Tensor> &attn_mask,\n const paddle::optional<paddle::Tensor> &shift_bias,\n const paddle::optional<paddle::Tensor> &smooth_weight,\n const paddle::optional<paddle::Tensor> &sinks,\n const paddle::Tensor &seq_lens_q,\n const paddle::Tensor &seq_lens_kv,\n const paddle::Tensor &seq_lens_encoder,\n const paddle::Tensor &batch_id_per_token,\n const paddle::Tensor &cu_seqlens_q,\n const paddle::Tensor &block_table,\n const paddle::Tensor &batch_ids,\n const paddle::Tensor &tile_ids_per_batch,\n const int num_blocks_x_cpu,\n const int max_seq_len,\n const int max_dec_len,\n const float quant_max_bound,\n const float quant_min_bound,\n const float in_scale,\n const int max_partition_size,\n const int encoder_max_partition_size,\n const int speculate_max_draft_token_num,\n const bool is_decoder,\n cudaStream_t &stream,\n paddle::Tensor *out,\n const int sliding_window);\n\n"
|
|
||||||
},
|
|
||||||
"multiquery_decoder_attention": {
|
|
||||||
"name": "multiquery_decoder_attention",
|
|
||||||
"function_name": "MultiQueryDecoderAttention",
|
|
||||||
"impl_file": "multiquery_decoder_attention_impl.cuh",
|
|
||||||
"template_params": [
|
|
||||||
"T",
|
|
||||||
"GROUP_SIZE",
|
|
||||||
"HEAD_DIM_QK",
|
|
||||||
"HEAD_DIM_V",
|
|
||||||
"BLOCK_SIZE",
|
|
||||||
"CAUSAL",
|
|
||||||
"NUM_STAGE",
|
|
||||||
"cache_bytes",
|
|
||||||
"DEAL_EACH_TIME"
|
|
||||||
],
|
|
||||||
"dispatch_params": {
|
|
||||||
"GROUP_SIZE": [8, 16, 128],
|
|
||||||
"HEAD_DIM_QK": [128, 192, 512, 576],
|
|
||||||
"HEAD_DIM_V": [128, 192, 512, 576],
|
|
||||||
"BLOCK_SIZE": [64],
|
|
||||||
"CAUSAL": [0, 1],
|
|
||||||
"NUM_STAGE": [2],
|
|
||||||
"cache_bytes": [16],
|
|
||||||
"DEAL_EACH_TIME": [32, 64]
|
|
||||||
},
|
|
||||||
"data_types": [
|
|
||||||
["paddle::float16", "", "float16"],
|
|
||||||
["paddle::bfloat16", "", "bfloat16"]
|
|
||||||
],
|
|
||||||
"max_instances_per_file": 60,
|
|
||||||
"file_prefix": "multiquery_decoder_attention_",
|
|
||||||
"function_signature": "template void {function_name}{template_args}(\n const AppendAttnMetaData& meta_data,\n cudaStream_t &stream,\n const paddle::Tensor &q,\n const paddle::Tensor &cache_k,\n const paddle::Tensor &cache_v,\n const paddle::optional<paddle::Tensor>& attn_mask,\n const paddle::optional<paddle::Tensor>& shift_bias,\n const paddle::optional<paddle::Tensor>& smooth_weight,\n const paddle::Tensor &seq_lens_q,\n const paddle::Tensor &seq_lens_kv,\n const paddle::Tensor &batch_id_per_token,\n const paddle::Tensor &cu_seqlens_q,\n const paddle::Tensor &block_table,\n const int max_seq_len,\n const int max_dec_len,\n const float rope_scale,\n const float rope_theta,\n const float softmax_scale,\n const float in_scale,\n paddle::Tensor *out);\n\n"
|
|
||||||
}
|
|
||||||
}
|
|
||||||
@@ -0,0 +1,59 @@
|
|||||||
|
// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
|
||||||
|
//
|
||||||
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
// you may not use this file except in compliance with the License.
|
||||||
|
// You may obtain a copy of the License at
|
||||||
|
//
|
||||||
|
// http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
//
|
||||||
|
// Unless required by applicable law or agreed to in writing, software
|
||||||
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
// See the License for the specific language governing permissions and
|
||||||
|
// limitations under the License.
|
||||||
|
#include "../append_attention_c16_impl.cuh"
|
||||||
|
|
||||||
|
|
||||||
|
template void CascadeAppendAttentionC16Kernel<paddle::bfloat16, paddle::bfloat16>(
|
||||||
|
const AppendAttnMetaData& meta_data,
|
||||||
|
const paddle::Tensor& qkv, // [token_num, num_heads, head_dim]
|
||||||
|
const paddle::Tensor&
|
||||||
|
cache_k, // [max_block_num, num_heads, block_size, head_dim]
|
||||||
|
const paddle::Tensor&
|
||||||
|
cache_v, // [max_block_num, num_heads, head_dim, block_size]
|
||||||
|
const paddle::optional<paddle::Tensor>& attn_mask,
|
||||||
|
const paddle::optional<paddle::Tensor>&
|
||||||
|
cache_k_scale, // [num_kv_heads, head_dim]
|
||||||
|
const paddle::optional<paddle::Tensor>&
|
||||||
|
cache_v_scale, // [num_kv_heads, head_dim]
|
||||||
|
const paddle::optional<paddle::Tensor>&
|
||||||
|
cache_k_zp, // [num_kv_heads, head_dim]
|
||||||
|
const paddle::optional<paddle::Tensor>&
|
||||||
|
cache_v_zp, // [num_kv_heads, head_dim]
|
||||||
|
const paddle::optional<paddle::Tensor>&
|
||||||
|
shift_bias, // [num_kv_heads, head_dim]
|
||||||
|
const paddle::optional<paddle::Tensor>&
|
||||||
|
smooth_weight, // [num_kv_heads, head_dim]
|
||||||
|
const paddle::Tensor& seq_lens_q,
|
||||||
|
const paddle::Tensor& seq_lens_kv,
|
||||||
|
const paddle::Tensor& seq_lens_encoder,
|
||||||
|
const paddle::Tensor& batch_id_per_token,
|
||||||
|
const paddle::Tensor& cu_seqlens_q,
|
||||||
|
const paddle::Tensor& block_table,
|
||||||
|
const paddle::Tensor& batch_ids,
|
||||||
|
const paddle::Tensor& tile_ids_per_batch,
|
||||||
|
const int num_blocks,
|
||||||
|
const int block_shape_q,
|
||||||
|
const int max_seq_len,
|
||||||
|
const int max_dec_len,
|
||||||
|
const float quant_max_bound,
|
||||||
|
const float quant_min_bound,
|
||||||
|
const float in_scale,
|
||||||
|
const int max_partition_size,
|
||||||
|
const int encoder_max_partition_size,
|
||||||
|
const int speculate_max_draft_token_num,
|
||||||
|
const bool causal,
|
||||||
|
const bool is_decoder,
|
||||||
|
const bool enable_prefill,
|
||||||
|
cudaStream_t& stream,
|
||||||
|
paddle::Tensor* out);
|
||||||
@@ -0,0 +1,58 @@
|
|||||||
|
// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
|
||||||
|
//
|
||||||
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
// you may not use this file except in compliance with the License.
|
||||||
|
// You may obtain a copy of the License at
|
||||||
|
//
|
||||||
|
// http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
//
|
||||||
|
// Unless required by applicable law or agreed to in writing, software
|
||||||
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
// See the License for the specific language governing permissions and
|
||||||
|
// limitations under the License.
|
||||||
|
#include "../append_attention_c16_impl.cuh"
|
||||||
|
|
||||||
|
template void CascadeAppendAttentionC16Kernel<paddle::bfloat16, paddle::float8_e4m3fn>(
|
||||||
|
const AppendAttnMetaData& meta_data,
|
||||||
|
const paddle::Tensor& qkv, // [token_num, num_heads, head_dim]
|
||||||
|
const paddle::Tensor&
|
||||||
|
cache_k, // [max_block_num, num_heads, block_size, head_dim]
|
||||||
|
const paddle::Tensor&
|
||||||
|
cache_v, // [max_block_num, num_heads, head_dim, block_size]
|
||||||
|
const paddle::optional<paddle::Tensor>& attn_mask,
|
||||||
|
const paddle::optional<paddle::Tensor>&
|
||||||
|
cache_k_scale, // [num_kv_heads, head_dim]
|
||||||
|
const paddle::optional<paddle::Tensor>&
|
||||||
|
cache_v_scale, // [num_kv_heads, head_dim]
|
||||||
|
const paddle::optional<paddle::Tensor>&
|
||||||
|
cache_k_zp, // [num_kv_heads, head_dim]
|
||||||
|
const paddle::optional<paddle::Tensor>&
|
||||||
|
cache_v_zp, // [num_kv_heads, head_dim]
|
||||||
|
const paddle::optional<paddle::Tensor>&
|
||||||
|
shift_bias, // [num_kv_heads, head_dim]
|
||||||
|
const paddle::optional<paddle::Tensor>&
|
||||||
|
smooth_weight, // [num_kv_heads, head_dim]
|
||||||
|
const paddle::Tensor& seq_lens_q,
|
||||||
|
const paddle::Tensor& seq_lens_kv,
|
||||||
|
const paddle::Tensor& seq_lens_encoder,
|
||||||
|
const paddle::Tensor& batch_id_per_token,
|
||||||
|
const paddle::Tensor& cu_seqlens_q,
|
||||||
|
const paddle::Tensor& block_table,
|
||||||
|
const paddle::Tensor& batch_ids,
|
||||||
|
const paddle::Tensor& tile_ids_per_batch,
|
||||||
|
const int num_blocks,
|
||||||
|
const int block_shape_q,
|
||||||
|
const int max_seq_len,
|
||||||
|
const int max_dec_len,
|
||||||
|
const float quant_max_bound,
|
||||||
|
const float quant_min_bound,
|
||||||
|
const float in_scale,
|
||||||
|
const int max_partition_size,
|
||||||
|
const int encoder_max_partition_size,
|
||||||
|
const int speculate_max_draft_token_num,
|
||||||
|
const bool causal,
|
||||||
|
const bool is_decoder,
|
||||||
|
const bool enable_prefill,
|
||||||
|
cudaStream_t& stream,
|
||||||
|
paddle::Tensor* out);
|
||||||
@@ -0,0 +1,58 @@
|
|||||||
|
// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
|
||||||
|
//
|
||||||
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
// you may not use this file except in compliance with the License.
|
||||||
|
// You may obtain a copy of the License at
|
||||||
|
//
|
||||||
|
// http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
//
|
||||||
|
// Unless required by applicable law or agreed to in writing, software
|
||||||
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
// See the License for the specific language governing permissions and
|
||||||
|
// limitations under the License.
|
||||||
|
#include "../append_attention_c16_impl.cuh"
|
||||||
|
|
||||||
|
template void CascadeAppendAttentionC16Kernel<paddle::bfloat16, int8_t>(
|
||||||
|
const AppendAttnMetaData& meta_data,
|
||||||
|
const paddle::Tensor& qkv, // [token_num, num_heads, head_dim]
|
||||||
|
const paddle::Tensor&
|
||||||
|
cache_k, // [max_block_num, num_heads, block_size, head_dim]
|
||||||
|
const paddle::Tensor&
|
||||||
|
cache_v, // [max_block_num, num_heads, head_dim, block_size]
|
||||||
|
const paddle::optional<paddle::Tensor>& attn_mask,
|
||||||
|
const paddle::optional<paddle::Tensor>&
|
||||||
|
cache_k_scale, // [num_kv_heads, head_dim]
|
||||||
|
const paddle::optional<paddle::Tensor>&
|
||||||
|
cache_v_scale, // [num_kv_heads, head_dim]
|
||||||
|
const paddle::optional<paddle::Tensor>&
|
||||||
|
cache_k_zp, // [num_kv_heads, head_dim]
|
||||||
|
const paddle::optional<paddle::Tensor>&
|
||||||
|
cache_v_zp, // [num_kv_heads, head_dim]
|
||||||
|
const paddle::optional<paddle::Tensor>&
|
||||||
|
shift_bias, // [num_kv_heads, head_dim]
|
||||||
|
const paddle::optional<paddle::Tensor>&
|
||||||
|
smooth_weight, // [num_kv_heads, head_dim]
|
||||||
|
const paddle::Tensor& seq_lens_q,
|
||||||
|
const paddle::Tensor& seq_lens_kv,
|
||||||
|
const paddle::Tensor& seq_lens_encoder,
|
||||||
|
const paddle::Tensor& batch_id_per_token,
|
||||||
|
const paddle::Tensor& cu_seqlens_q,
|
||||||
|
const paddle::Tensor& block_table,
|
||||||
|
const paddle::Tensor& batch_ids,
|
||||||
|
const paddle::Tensor& tile_ids_per_batch,
|
||||||
|
const int num_blocks,
|
||||||
|
const int block_shape_q,
|
||||||
|
const int max_seq_len,
|
||||||
|
const int max_dec_len,
|
||||||
|
const float quant_max_bound,
|
||||||
|
const float quant_min_bound,
|
||||||
|
const float in_scale,
|
||||||
|
const int max_partition_size,
|
||||||
|
const int encoder_max_partition_size,
|
||||||
|
const int speculate_max_draft_token_num,
|
||||||
|
const bool causal,
|
||||||
|
const bool is_decoder,
|
||||||
|
const bool enable_prefill,
|
||||||
|
cudaStream_t& stream,
|
||||||
|
paddle::Tensor* out);
|
||||||
@@ -0,0 +1,58 @@
|
|||||||
|
// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
|
||||||
|
//
|
||||||
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
// you may not use this file except in compliance with the License.
|
||||||
|
// You may obtain a copy of the License at
|
||||||
|
//
|
||||||
|
// http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
//
|
||||||
|
// Unless required by applicable law or agreed to in writing, software
|
||||||
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
// See the License for the specific language governing permissions and
|
||||||
|
// limitations under the License.
|
||||||
|
#include "../append_attention_c16_impl.cuh"
|
||||||
|
|
||||||
|
template void CascadeAppendAttentionC16Kernel<paddle::float16, paddle::float16>(
|
||||||
|
const AppendAttnMetaData& meta_data,
|
||||||
|
const paddle::Tensor& qkv, // [token_num, num_heads, head_dim]
|
||||||
|
const paddle::Tensor&
|
||||||
|
cache_k, // [max_block_num, num_heads, block_size, head_dim]
|
||||||
|
const paddle::Tensor&
|
||||||
|
cache_v, // [max_block_num, num_heads, head_dim, block_size]
|
||||||
|
const paddle::optional<paddle::Tensor>& attn_mask,
|
||||||
|
const paddle::optional<paddle::Tensor>&
|
||||||
|
cache_k_scale, // [num_kv_heads, head_dim]
|
||||||
|
const paddle::optional<paddle::Tensor>&
|
||||||
|
cache_v_scale, // [num_kv_heads, head_dim]
|
||||||
|
const paddle::optional<paddle::Tensor>&
|
||||||
|
cache_k_zp, // [num_kv_heads, head_dim]
|
||||||
|
const paddle::optional<paddle::Tensor>&
|
||||||
|
cache_v_zp, // [num_kv_heads, head_dim]
|
||||||
|
const paddle::optional<paddle::Tensor>&
|
||||||
|
shift_bias, // [num_kv_heads, head_dim]
|
||||||
|
const paddle::optional<paddle::Tensor>&
|
||||||
|
smooth_weight, // [num_kv_heads, head_dim]
|
||||||
|
const paddle::Tensor& seq_lens_q,
|
||||||
|
const paddle::Tensor& seq_lens_kv,
|
||||||
|
const paddle::Tensor& seq_lens_encoder,
|
||||||
|
const paddle::Tensor& batch_id_per_token,
|
||||||
|
const paddle::Tensor& cu_seqlens_q,
|
||||||
|
const paddle::Tensor& block_table,
|
||||||
|
const paddle::Tensor& batch_ids,
|
||||||
|
const paddle::Tensor& tile_ids_per_batch,
|
||||||
|
const int num_blocks,
|
||||||
|
const int block_shape_q,
|
||||||
|
const int max_seq_len,
|
||||||
|
const int max_dec_len,
|
||||||
|
const float quant_max_bound,
|
||||||
|
const float quant_min_bound,
|
||||||
|
const float in_scale,
|
||||||
|
const int max_partition_size,
|
||||||
|
const int encoder_max_partition_size,
|
||||||
|
const int speculate_max_draft_token_num,
|
||||||
|
const bool causal,
|
||||||
|
const bool is_decoder,
|
||||||
|
const bool enable_prefill,
|
||||||
|
cudaStream_t& stream,
|
||||||
|
paddle::Tensor* out);
|
||||||
@@ -0,0 +1,58 @@
|
|||||||
|
// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
|
||||||
|
//
|
||||||
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
// you may not use this file except in compliance with the License.
|
||||||
|
// You may obtain a copy of the License at
|
||||||
|
//
|
||||||
|
// http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
//
|
||||||
|
// Unless required by applicable law or agreed to in writing, software
|
||||||
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
// See the License for the specific language governing permissions and
|
||||||
|
// limitations under the License.
|
||||||
|
#include "../append_attention_c16_impl.cuh"
|
||||||
|
|
||||||
|
template void CascadeAppendAttentionC16Kernel<paddle::float16, paddle::float8_e4m3fn>(
|
||||||
|
const AppendAttnMetaData& meta_data,
|
||||||
|
const paddle::Tensor& qkv, // [token_num, num_heads, head_dim]
|
||||||
|
const paddle::Tensor&
|
||||||
|
cache_k, // [max_block_num, num_heads, block_size, head_dim]
|
||||||
|
const paddle::Tensor&
|
||||||
|
cache_v, // [max_block_num, num_heads, head_dim, block_size]
|
||||||
|
const paddle::optional<paddle::Tensor>& attn_mask,
|
||||||
|
const paddle::optional<paddle::Tensor>&
|
||||||
|
cache_k_scale, // [num_kv_heads, head_dim]
|
||||||
|
const paddle::optional<paddle::Tensor>&
|
||||||
|
cache_v_scale, // [num_kv_heads, head_dim]
|
||||||
|
const paddle::optional<paddle::Tensor>&
|
||||||
|
cache_k_zp, // [num_kv_heads, head_dim]
|
||||||
|
const paddle::optional<paddle::Tensor>&
|
||||||
|
cache_v_zp, // [num_kv_heads, head_dim]
|
||||||
|
const paddle::optional<paddle::Tensor>&
|
||||||
|
shift_bias, // [num_kv_heads, head_dim]
|
||||||
|
const paddle::optional<paddle::Tensor>&
|
||||||
|
smooth_weight, // [num_kv_heads, head_dim]
|
||||||
|
const paddle::Tensor& seq_lens_q,
|
||||||
|
const paddle::Tensor& seq_lens_kv,
|
||||||
|
const paddle::Tensor& seq_lens_encoder,
|
||||||
|
const paddle::Tensor& batch_id_per_token,
|
||||||
|
const paddle::Tensor& cu_seqlens_q,
|
||||||
|
const paddle::Tensor& block_table,
|
||||||
|
const paddle::Tensor& batch_ids,
|
||||||
|
const paddle::Tensor& tile_ids_per_batch,
|
||||||
|
const int num_blocks,
|
||||||
|
const int block_shape_q,
|
||||||
|
const int max_seq_len,
|
||||||
|
const int max_dec_len,
|
||||||
|
const float quant_max_bound,
|
||||||
|
const float quant_min_bound,
|
||||||
|
const float in_scale,
|
||||||
|
const int max_partition_size,
|
||||||
|
const int encoder_max_partition_size,
|
||||||
|
const int speculate_max_draft_token_num,
|
||||||
|
const bool causal,
|
||||||
|
const bool is_decoder,
|
||||||
|
const bool enable_prefill,
|
||||||
|
cudaStream_t& stream,
|
||||||
|
paddle::Tensor* out);
|
||||||
@@ -0,0 +1,58 @@
|
|||||||
|
// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
|
||||||
|
//
|
||||||
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
// you may not use this file except in compliance with the License.
|
||||||
|
// You may obtain a copy of the License at
|
||||||
|
//
|
||||||
|
// http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
//
|
||||||
|
// Unless required by applicable law or agreed to in writing, software
|
||||||
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
// See the License for the specific language governing permissions and
|
||||||
|
// limitations under the License.
|
||||||
|
#include "../append_attention_c16_impl.cuh"
|
||||||
|
|
||||||
|
template void CascadeAppendAttentionC16Kernel<paddle::float16, int8_t>(
|
||||||
|
const AppendAttnMetaData& meta_data,
|
||||||
|
const paddle::Tensor& qkv, // [token_num, num_heads, head_dim]
|
||||||
|
const paddle::Tensor&
|
||||||
|
cache_k, // [max_block_num, num_heads, block_size, head_dim]
|
||||||
|
const paddle::Tensor&
|
||||||
|
cache_v, // [max_block_num, num_heads, head_dim, block_size]
|
||||||
|
const paddle::optional<paddle::Tensor>& attn_mask,
|
||||||
|
const paddle::optional<paddle::Tensor>&
|
||||||
|
cache_k_scale, // [num_kv_heads, head_dim]
|
||||||
|
const paddle::optional<paddle::Tensor>&
|
||||||
|
cache_v_scale, // [num_kv_heads, head_dim]
|
||||||
|
const paddle::optional<paddle::Tensor>&
|
||||||
|
cache_k_zp, // [num_kv_heads, head_dim]
|
||||||
|
const paddle::optional<paddle::Tensor>&
|
||||||
|
cache_v_zp, // [num_kv_heads, head_dim]
|
||||||
|
const paddle::optional<paddle::Tensor>&
|
||||||
|
shift_bias, // [num_kv_heads, head_dim]
|
||||||
|
const paddle::optional<paddle::Tensor>&
|
||||||
|
smooth_weight, // [num_kv_heads, head_dim]
|
||||||
|
const paddle::Tensor& seq_lens_q,
|
||||||
|
const paddle::Tensor& seq_lens_kv,
|
||||||
|
const paddle::Tensor& seq_lens_encoder,
|
||||||
|
const paddle::Tensor& batch_id_per_token,
|
||||||
|
const paddle::Tensor& cu_seqlens_q,
|
||||||
|
const paddle::Tensor& block_table,
|
||||||
|
const paddle::Tensor& batch_ids,
|
||||||
|
const paddle::Tensor& tile_ids_per_batch,
|
||||||
|
const int num_blocks,
|
||||||
|
const int block_shape_q,
|
||||||
|
const int max_seq_len,
|
||||||
|
const int max_dec_len,
|
||||||
|
const float quant_max_bound,
|
||||||
|
const float quant_min_bound,
|
||||||
|
const float in_scale,
|
||||||
|
const int max_partition_size,
|
||||||
|
const int encoder_max_partition_size,
|
||||||
|
const int speculate_max_draft_token_num,
|
||||||
|
const bool causal,
|
||||||
|
const bool is_decoder,
|
||||||
|
const bool enable_prefill,
|
||||||
|
cudaStream_t& stream,
|
||||||
|
paddle::Tensor* out);
|
||||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user