mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2025-10-05 08:37:06 +08:00
Compare commits
153 Commits
copilot/ad
...
develop
Author | SHA1 | Date | |
---|---|---|---|
![]() |
791b101195 | ||
![]() |
af3872215e | ||
![]() |
d14aadf70e | ||
![]() |
81959c7d88 | ||
![]() |
7c919070f7 | ||
![]() |
2b2b645296 | ||
![]() |
3740e33fea | ||
![]() |
70633c6641 | ||
![]() |
1282ebe1b1 | ||
![]() |
6265f4385f | ||
![]() |
59313ed7f9 | ||
![]() |
aa1cc09c5b | ||
![]() |
7b6cb72ab2 | ||
![]() |
3cef851468 | ||
![]() |
17e00d9f5d | ||
![]() |
aa045aa84f | ||
![]() |
79c2c52756 | ||
![]() |
5c6e859681 | ||
![]() |
f40d7c6d65 | ||
![]() |
331c4d2a74 | ||
![]() |
838de53de8 | ||
![]() |
55124f8491 | ||
![]() |
8a964329f4 | ||
![]() |
67e693b18b | ||
![]() |
12a3587cca | ||
![]() |
dd2e844ea3 | ||
![]() |
4ec00df2b0 | ||
![]() |
83d41d23b0 | ||
![]() |
c415885a94 | ||
![]() |
4515ad21e9 | ||
![]() |
0c6f1932c5 | ||
![]() |
87179cb744 | ||
![]() |
e36eccfdad | ||
![]() |
b433a93d9a | ||
![]() |
870364b547 | ||
![]() |
5ff10c8ced | ||
![]() |
18f4977aec | ||
![]() |
7c1fd19f0f | ||
![]() |
8b0ce8e3ab | ||
![]() |
9566ae8827 | ||
![]() |
e8318b7477 | ||
![]() |
3161014e49 | ||
![]() |
44010cee13 | ||
![]() |
f1b5392e20 | ||
![]() |
a1c5d930bb | ||
![]() |
b455fd39f3 | ||
![]() |
d6e59447f5 | ||
![]() |
ec99474e71 | ||
![]() |
62d1c48363 | ||
![]() |
1a6283424e | ||
![]() |
c96a535a5d | ||
![]() |
9082f625ba | ||
![]() |
813befadfa | ||
![]() |
c32aae901f | ||
![]() |
4325b737e7 | ||
![]() |
2c34a557f4 | ||
![]() |
83720da79f | ||
![]() |
772f0156f3 | ||
![]() |
504461b6b5 | ||
![]() |
5532e8a323 | ||
![]() |
5e1f13bd3b | ||
![]() |
c5671d7c09 | ||
![]() |
918ccdb123 | ||
![]() |
9845f0d010 | ||
![]() |
c4830ef24c | ||
![]() |
0b62648924 | ||
![]() |
c86945ef49 | ||
![]() |
da74a5f0b3 | ||
![]() |
718f32a6b0 | ||
![]() |
5c33be5a7d | ||
![]() |
91912cc2e1 | ||
![]() |
cc6e14d2ec | ||
![]() |
24180fba0a | ||
![]() |
ee9d8a840a | ||
![]() |
66a98b44ed | ||
![]() |
a685e5ad35 | ||
![]() |
ddf5606263 | ||
![]() |
c3b8ebeb18 | ||
![]() |
62b8b02e08 | ||
![]() |
98447beb4d | ||
![]() |
618ccdbfba | ||
![]() |
2745f37017 | ||
![]() |
896e3bb606 | ||
![]() |
0d3a57a2c6 | ||
![]() |
b52971749c | ||
![]() |
2adca04f1f | ||
![]() |
f9766f917b | ||
![]() |
2e9e53ff7e | ||
![]() |
c01a756912 | ||
![]() |
cd09913552 | ||
![]() |
67e6d8c691 | ||
![]() |
de8638b1e9 | ||
![]() |
4f8901489c | ||
![]() |
e79a1a7938 | ||
![]() |
d682c97dd3 | ||
![]() |
8e49d99009 | ||
![]() |
83bf1fd5aa | ||
![]() |
b70ca35c0b | ||
![]() |
befe463f01 | ||
![]() |
442543cd6b | ||
![]() |
ed2dcec829 | ||
![]() |
a04365a0c7 | ||
![]() |
03b3d6175d | ||
![]() |
17a27170bc | ||
![]() |
113e330030 | ||
![]() |
69aa2781a1 | ||
![]() |
46911f903d | ||
![]() |
b1b33211e8 | ||
![]() |
9409665713 | ||
![]() |
29ed617f0f | ||
![]() |
b1a5b756a3 | ||
![]() |
4408dc7f67 | ||
![]() |
ef4a1aa2da | ||
![]() |
f213ae1e86 | ||
![]() |
553adb299e | ||
![]() |
958abebeab | ||
![]() |
4871f18dad | ||
![]() |
987609c894 | ||
![]() |
9ac539471d | ||
![]() |
88ea565aba | ||
![]() |
c86b3357ce | ||
![]() |
06f4b49ca3 | ||
![]() |
805f29a06c | ||
![]() |
cab7a633fe | ||
![]() |
58e0785bab | ||
![]() |
8466219ec8 | ||
![]() |
82dab8a91a | ||
![]() |
37f1632732 | ||
![]() |
4859f40b20 | ||
![]() |
2056a428bd | ||
![]() |
850465e8ed | ||
![]() |
a47976e82d | ||
![]() |
abdcef30aa | ||
![]() |
d2ec7f6aa2 | ||
![]() |
fec58639db | ||
![]() |
d2d04c2d5e | ||
![]() |
d60f7c4661 | ||
![]() |
e4c64a71cc | ||
![]() |
2650f58740 | ||
![]() |
2af0f671b1 | ||
![]() |
a7392a0ff9 | ||
![]() |
637d96c6ae | ||
![]() |
7ee100903f | ||
![]() |
684e93269b | ||
![]() |
276f73cf83 | ||
![]() |
d3e4ae3d49 | ||
![]() |
453487d5b0 | ||
![]() |
9d0074a91a | ||
![]() |
c3b2a60fb8 | ||
![]() |
dbab579299 | ||
![]() |
f078a959b6 | ||
![]() |
3b1da6e4dd | ||
![]() |
b3fac5bde1 |
1
.github/workflows/_accuracy_test.yml
vendored
1
.github/workflows/_accuracy_test.yml
vendored
@@ -160,6 +160,7 @@ jobs:
|
||||
git config --global --add safe.directory /workspace/FastDeploy
|
||||
cd FastDeploy
|
||||
pushd tests/ce/deploy
|
||||
ps -ef | grep "${FD_ENGINE_QUEUE_PORT}" | grep -v grep | awk "{print \$2}" | xargs -r kill -9
|
||||
python3.10 deploy.py > dd.log 2>&1 &
|
||||
sleep 3
|
||||
curl -X POST http://0.0.0.0:${FLASK_PORT}/start \
|
||||
|
4
.github/workflows/_base_test.yml
vendored
4
.github/workflows/_base_test.yml
vendored
@@ -143,7 +143,8 @@ jobs:
|
||||
-v "${CACHE_DIR}/ConfigDir:/root/.config" \
|
||||
-e TZ="Asia/Shanghai" \
|
||||
--gpus '"device='"${DEVICES}"'"' ${docker_image} /bin/bash -xc '
|
||||
python -m pip install --pre paddlepaddle-gpu -i https://www.paddlepaddle.org.cn/packages/nightly/cu126/
|
||||
# python -m pip install --pre paddlepaddle-gpu -i https://www.paddlepaddle.org.cn/packages/nightly/cu126/
|
||||
python -m pip install paddlepaddle-gpu==3.3.0.dev20250917 -i https://www.paddlepaddle.org.cn/packages/nightly/cu126/
|
||||
|
||||
pip config set global.index-url https://mirrors.tuna.tsinghua.edu.cn/pypi/web/simple
|
||||
|
||||
@@ -160,6 +161,7 @@ jobs:
|
||||
git config --global --add safe.directory /workspace/FastDeploy
|
||||
cd FastDeploy
|
||||
pushd tests/ce/deploy
|
||||
ps -ef | grep "${FD_ENGINE_QUEUE_PORT}" | grep -v grep | awk "{print \$2}" | xargs -r kill -9
|
||||
python3.10 deploy.py > dd.log 2>&1 &
|
||||
sleep 3
|
||||
curl -X POST http://0.0.0.0:${FLASK_PORT}/start \
|
||||
|
11
.github/workflows/_build_linux.yml
vendored
11
.github/workflows/_build_linux.yml
vendored
@@ -55,7 +55,7 @@ on:
|
||||
jobs:
|
||||
fd-build:
|
||||
runs-on: [self-hosted, GPU-Build]
|
||||
timeout-minutes: 240
|
||||
timeout-minutes: 360
|
||||
outputs:
|
||||
wheel_path: ${{ steps.set_output.outputs.wheel_path }}
|
||||
steps:
|
||||
@@ -106,7 +106,12 @@ jobs:
|
||||
CARD_ID=$(echo "${runner_name}" | awk -F'-' '{print $NF}')
|
||||
gpu_id=$(echo "$CARD_ID" | fold -w1 | paste -sd,)
|
||||
|
||||
CACHE_DIR="${CACHE_DIR:-$(dirname "$(dirname "${{ github.workspace }}")")}"
|
||||
IFS='/' read -ra parts <<< "${GITHUB_WORKSPACE}"
|
||||
len=${#parts[@]}
|
||||
CCACHE_DEFAULT_DIR="/$(IFS=/; echo "${parts[*]:1:$((len-5))}")"
|
||||
echo "$CCACHE_DEFAULT_DIR"
|
||||
|
||||
CACHE_DIR="${CACHE_DIR:-$CCACHE_DEFAULT_DIR}"
|
||||
echo "CACHE_DIR is set to ${CACHE_DIR}"
|
||||
if [ ! -f "${CACHE_DIR}/gitconfig" ]; then
|
||||
touch "${CACHE_DIR}/gitconfig"
|
||||
@@ -127,6 +132,7 @@ jobs:
|
||||
-e "PADDLEVERSION=${PADDLEVERSION}" \
|
||||
-e "PADDLE_WHL_URL=${PADDLE_WHL_URL}" \
|
||||
-e "BRANCH_REF=${BRANCH_REF}" \
|
||||
-e "CCACHE_MAXSIZE=50G" \
|
||||
--gpus "\"device=${gpu_id}\"" ${docker_image} /bin/bash -c '
|
||||
if [[ -n "${FD_VERSION}" ]]; then
|
||||
export FASTDEPLOY_VERSION=${FD_VERSION}
|
||||
@@ -134,6 +140,7 @@ jobs:
|
||||
fi
|
||||
|
||||
git config --global --add safe.directory /workspace/FastDeploy
|
||||
chown -R $(whoami) /workspace/FastDeploy
|
||||
cd FastDeploy
|
||||
if [[ "${WITH_NIGHTLY_BUILD}" == "ON" ]];then
|
||||
GIT_COMMIT_TIME=$(git --no-pager show -s --format=%ci HEAD)
|
||||
|
2
.github/workflows/_ci_image_build.yml
vendored
2
.github/workflows/_ci_image_build.yml
vendored
@@ -29,7 +29,7 @@ jobs:
|
||||
outputs:
|
||||
docker_name_precheck: ${{ steps.docker_build.outputs.docker_name_precheck }}
|
||||
steps:
|
||||
- name: Code Prepare
|
||||
- name: Docker Build
|
||||
id: docker_build
|
||||
shell: bash
|
||||
env:
|
||||
|
1
.github/workflows/_logprob_test_linux.yml
vendored
1
.github/workflows/_logprob_test_linux.yml
vendored
@@ -147,6 +147,7 @@ jobs:
|
||||
--skip install
|
||||
|
||||
cd PaddleTest/framework/ServeTest
|
||||
ps -ef | grep "${FD_ENGINE_QUEUE_PORT}" | grep -v grep | awk "{print \$2}" | xargs -r kill -9
|
||||
python3.10 deploy.py > dd.log 2>&1 &
|
||||
sleep 3
|
||||
curl -X POST http://0.0.0.0:${FLASK_PORT}/start \
|
||||
|
3
.github/workflows/_pre_ce_test.yml
vendored
3
.github/workflows/_pre_ce_test.yml
vendored
@@ -82,6 +82,9 @@ jobs:
|
||||
FD_ENGINE_QUEUE_PORT=$((42058 + DEVICE_PORT * 100))
|
||||
FD_METRICS_PORT=$((42078 + DEVICE_PORT * 100))
|
||||
FD_CACHE_QUEUE_PORT=$((42098 + DEVICE_PORT * 100))
|
||||
FD_ZMQ_RECV_REQUEST_SERVER_PORT=$((42048 + DEVICE_PORT * 100))
|
||||
FD_ZMQ_SEND_RESPONSE_SERVER_PORT=$((42038 + DEVICE_PORT * 100))
|
||||
FD_ZMQ_CONTROL_CMD_SERVER_PORTS=$((42028 + DEVICE_PORT * 100))
|
||||
echo "Test ENV Parameter:"
|
||||
echo "========================================================="
|
||||
echo "FLASK_PORT=${FLASK_PORT}"
|
||||
|
2
.github/workflows/_unit_test_coverage.yml
vendored
2
.github/workflows/_unit_test_coverage.yml
vendored
@@ -41,7 +41,7 @@ jobs:
|
||||
|
||||
run_tests_with_coverage:
|
||||
runs-on: [self-hosted, GPU-h1z1-2Cards]
|
||||
timeout-minutes: 60
|
||||
timeout-minutes: 90
|
||||
needs: check_cov_skip
|
||||
if: needs.check_cov_skip.outputs.can-skip != 'true'
|
||||
outputs:
|
||||
|
2
.github/workflows/ce_job.yml
vendored
2
.github/workflows/ce_job.yml
vendored
@@ -9,7 +9,7 @@ on:
|
||||
permissions: read-all
|
||||
|
||||
concurrency:
|
||||
group: ${{ github.ref }}-${{ github.sha }}
|
||||
group: CE-Job-${{ github.ref }}-${{ github.sha }}
|
||||
cancel-in-progress: true
|
||||
|
||||
jobs:
|
||||
|
6
.github/workflows/ci_iluvatar.yml
vendored
6
.github/workflows/ci_iluvatar.yml
vendored
@@ -28,18 +28,22 @@ jobs:
|
||||
REPO="https://github.com/${{ github.repository }}.git"
|
||||
FULL_REPO="${{ github.repository }}"
|
||||
REPO_NAME="${FULL_REPO##*/}"
|
||||
BASE_BRANCH="${{ github.base_ref }}"
|
||||
# Clean the repository directory before starting
|
||||
docker run --rm --net=host -v $(pwd):/workspace -w /workspace \
|
||||
-e "REPO_NAME=${REPO_NAME}" \
|
||||
-e "BASE_BRANCH=${BASE_BRANCH}" \
|
||||
${docker_image} /bin/bash -c '
|
||||
if [ -d ${REPO_NAME} ]; then
|
||||
echo "Directory ${REPO_NAME} exists, removing it..."
|
||||
rm -rf ${REPO_NAME}
|
||||
fi
|
||||
'
|
||||
git config --global http.proxy "http://61.151.249.150:33128"
|
||||
git config --global https.proxy "http://61.151.249.150:33128"
|
||||
git config --global user.name "FastDeployCI"
|
||||
git config --global user.email "fastdeploy_ci@example.com"
|
||||
git clone ${REPO} ${REPO_NAME}
|
||||
git clone --recursive ${REPO} ${REPO_NAME} -b ${BASE_BRANCH}
|
||||
cd FastDeploy
|
||||
if [ "${{ github.event_name }}" = "pull_request" ]; then
|
||||
git fetch origin pull/${{ github.event.pull_request.number }}/head:pr/${{ github.event.pull_request.number }}
|
||||
|
2
.github/workflows/ci_image_update.yml
vendored
2
.github/workflows/ci_image_update.yml
vendored
@@ -8,7 +8,7 @@ on:
|
||||
permissions: read-all
|
||||
|
||||
concurrency:
|
||||
group: ${{ github.ref }}-${{ github.sha }}
|
||||
group: CI-Images-Build-${{ github.ref }}-${{ github.sha }}
|
||||
cancel-in-progress: true
|
||||
|
||||
|
||||
|
2
.github/workflows/publish_job.yml
vendored
2
.github/workflows/publish_job.yml
vendored
@@ -13,7 +13,7 @@ on:
|
||||
permissions: read-all
|
||||
|
||||
concurrency:
|
||||
group: ${{ github.ref }}-${{ github.sha }}
|
||||
group: Publish-Job-${{ github.ref }}-${{ github.sha }}
|
||||
cancel-in-progress: true
|
||||
|
||||
|
||||
|
10
.gitmodules
vendored
Normal file
10
.gitmodules
vendored
Normal file
@@ -0,0 +1,10 @@
|
||||
[submodule "custom_ops/third_party/DeepGEMM"]
|
||||
path = custom_ops/third_party/DeepGEMM
|
||||
url = https://github.com/deepseek-ai/DeepGEMM.git
|
||||
ignore = all
|
||||
[submodule "custom_ops/third_party/cutlass"]
|
||||
path = custom_ops/third_party/cutlass
|
||||
url = https://github.com/NVIDIA/cutlass.git
|
||||
[submodule "custom_ops/third_party/nlohmann_json"]
|
||||
path = custom_ops/third_party/nlohmann_json
|
||||
url = https://github.com/nlohmann/json.git
|
@@ -43,7 +43,7 @@ English | [简体中文](README_CN.md)
|
||||
- 🤝 **OpenAI API Server and vLLM Compatible**: One-command deployment with [vLLM](https://github.com/vllm-project/vllm/) interface compatibility.
|
||||
- 🧮 **Comprehensive Quantization Format Support**: W8A16, W8A8, W4A16, W4A8, W2A16, FP8, and more.
|
||||
- ⏩ **Advanced Acceleration Techniques**: Speculative decoding, Multi-Token Prediction (MTP) and Chunked Prefill.
|
||||
- 🖥️ **Multi-Hardware Support**: NVIDIA GPU, Kunlunxin XPU, Hygon DCU, Ascend NPU, Iluvatar GPU, Enflame GCU, MetaX GPU etc.
|
||||
- 🖥️ **Multi-Hardware Support**: NVIDIA GPU, Kunlunxin XPU, Hygon DCU, Ascend NPU, Iluvatar GPU, Enflame GCU, MetaX GPU, Intel Gaudi etc.
|
||||
|
||||
## Requirements
|
||||
|
||||
@@ -60,6 +60,7 @@ FastDeploy supports inference deployment on **NVIDIA GPUs**, **Kunlunxin XPUs**,
|
||||
- [Enflame GCU](./docs/get_started/installation/Enflame_gcu.md)
|
||||
- [Hygon DCU](./docs/get_started/installation/hygon_dcu.md)
|
||||
- [MetaX GPU](./docs/get_started/installation/metax_gpu.md)
|
||||
- [Intel Gaudi](./docs/get_started/installation/intel_gaudi.md)
|
||||
|
||||
**Note:** We are actively working on expanding hardware support. Additional hardware platforms including Ascend NPU are currently under development and testing. Stay tuned for updates!
|
||||
|
||||
|
@@ -41,7 +41,7 @@
|
||||
- 🤝 **OpenAI API服务与vLLM兼容**:单命令部署,兼容[vLLM](https://github.com/vllm-project/vllm/)接口
|
||||
- 🧮 **全量化格式支持**:W8A16、W8A8、W4A16、W4A8、W2A16、FP8等
|
||||
- ⏩ **高级加速技术**:推测解码、多令牌预测(MTP)及分块预填充
|
||||
- 🖥️ **多硬件支持**:NVIDIA GPU、昆仑芯XPU、海光DCU、昇腾NPU、天数智芯GPU、燧原GCU、沐曦GPU等
|
||||
- 🖥️ **多硬件支持**:NVIDIA GPU、昆仑芯XPU、海光DCU、昇腾NPU、天数智芯GPU、燧原GCU、沐曦GPU、英特尔Gaudi等
|
||||
|
||||
## 要求
|
||||
|
||||
@@ -58,6 +58,7 @@ FastDeploy 支持在**英伟达(NVIDIA)GPU**、**昆仑芯(Kunlunxin)XPU
|
||||
- [燧原 S60](./docs/zh/get_started/installation/Enflame_gcu.md)
|
||||
- [海光 DCU](./docs/zh/get_started/installation/hygon_dcu.md)
|
||||
- [沐曦 GPU](./docs/zh/get_started/installation/metax_gpu.md)
|
||||
- [英特尔 Gaudi](./docs/zh/get_started/installation/intel_gaudi.md)
|
||||
|
||||
**注意:** 我们正在积极拓展硬件支持范围。目前,包括昇腾(Ascend)NPU 等其他硬件平台正在开发测试中。敬请关注更新!
|
||||
|
||||
|
@@ -98,7 +98,7 @@ def main(args):
|
||||
raise ValueError("--max_concurrency should be same length as --s_itl_base_model")
|
||||
|
||||
for max_concurrency, s_itl in zip(args.max_concurrency, args.s_itl_base_model):
|
||||
# Wramup
|
||||
# Warmup
|
||||
print("Starting warmup...")
|
||||
with open(os.devnull, "w") as f:
|
||||
with contextlib.redirect_stdout(f):
|
||||
|
@@ -965,7 +965,7 @@ if __name__ == "__main__":
|
||||
parser.add_argument(
|
||||
"--backend",
|
||||
type=str,
|
||||
default="vllm",
|
||||
default="openai-chat",
|
||||
choices=list(ASYNC_REQUEST_FUNCS.keys()),
|
||||
)
|
||||
parser.add_argument(
|
||||
|
5
benchmarks/yaml/GLM45-air-32k-bf16.yaml
Normal file
5
benchmarks/yaml/GLM45-air-32k-bf16.yaml
Normal file
@@ -0,0 +1,5 @@
|
||||
max_model_len: 32768
|
||||
max_num_seqs: 128
|
||||
tensor_parallel_size: 4
|
||||
use_cudagraph: True
|
||||
load_choices: "default_v1"
|
6
benchmarks/yaml/GLM45-air-32k-wfp8afp8.yaml
Normal file
6
benchmarks/yaml/GLM45-air-32k-wfp8afp8.yaml
Normal file
@@ -0,0 +1,6 @@
|
||||
max_model_len: 32768
|
||||
max_num_seqs: 128
|
||||
tensor_parallel_size: 4
|
||||
use_cudagraph: True
|
||||
load_choices: "default_v1"
|
||||
quantization: wfp8afp8
|
@@ -6,3 +6,4 @@ tensor_parallel_size: 8
|
||||
max_num_batched_tokens: 4096
|
||||
max_num_partial_prefills: 3
|
||||
max_long_partial_prefills: 3
|
||||
quantization: wint4
|
||||
|
6
benchmarks/yaml/eb45-128k-wint4-tp1-plas.yaml
Normal file
6
benchmarks/yaml/eb45-128k-wint4-tp1-plas.yaml
Normal file
@@ -0,0 +1,6 @@
|
||||
tensor_parallel_size: 1
|
||||
max_model_len: 131072
|
||||
max_num_seqs: 32
|
||||
quantization: wint4
|
||||
max_num_batched_tokens: 8192
|
||||
plas_attention_config: '{"plas_encoder_top_k_left": 50, "plas_encoder_top_k_right": 60, "plas_decoder_top_k_left": 100, "plas_decoder_top_k_right": 120}'
|
@@ -6,3 +6,4 @@ tensor_parallel_size: 8
|
||||
max_num_batched_tokens: 4096
|
||||
max_num_partial_prefills: 3
|
||||
max_long_partial_prefills: 3
|
||||
quantization: wint8
|
||||
|
5
benchmarks/yaml/eb45-32k-wint2-tp4.yaml
Normal file
5
benchmarks/yaml/eb45-32k-wint2-tp4.yaml
Normal file
@@ -0,0 +1,5 @@
|
||||
max_model_len: 32768
|
||||
max_num_seqs: 256
|
||||
kv_cache_ratio: 0.75
|
||||
tensor_parallel_size: 4
|
||||
gpu_memory_utilization: 0.9
|
@@ -13,3 +13,4 @@ pd_comm_port: "2334"
|
||||
max_num_batched_tokens: 384
|
||||
max_num_partial_prefills: 3
|
||||
max_long_partial_prefills: 3
|
||||
quantization: wint4
|
||||
|
@@ -10,3 +10,4 @@ engine_worker_queue_port: 6677
|
||||
cache_transfer_protocol: "rdma,ipc"
|
||||
rdma_comm_ports: "7675,7676,7677,7678"
|
||||
pd_comm_port: "2333"
|
||||
quantization: wint4
|
||||
|
11
benchmarks/yaml/eb45-vl-128k-wint4-h800-tp8.yaml
Normal file
11
benchmarks/yaml/eb45-vl-128k-wint4-h800-tp8.yaml
Normal file
@@ -0,0 +1,11 @@
|
||||
enable_mm: True
|
||||
max_model_len: 131072
|
||||
max_num_seqs: 56
|
||||
gpu_memory_utilization: 0.8
|
||||
kv_cache_ratio: 0.8
|
||||
tensor_parallel_size: 8
|
||||
quantization: wint4
|
||||
limit_mm_per_prompt: '{"image": 100, "video": 100}'
|
||||
enable_chunked_prefill: True
|
||||
max_num_batched_tokens: 384
|
||||
reasoning_parser: ernie-45-vl
|
@@ -1,7 +1,7 @@
|
||||
enable_mm: True
|
||||
max_model_len: 32768
|
||||
max_num_seqs: 36
|
||||
gpu_memory_utilization: 0.95
|
||||
gpu_memory_utilization: 0.9
|
||||
kv_cache_ratio: 0.8
|
||||
tensor_parallel_size: 8
|
||||
quantization: wint8
|
||||
|
@@ -1,7 +1,7 @@
|
||||
enable_mm: True
|
||||
max_model_len: 32768
|
||||
max_num_seqs: 36
|
||||
gpu_memory_utilization: 0.8
|
||||
gpu_memory_utilization: 0.85
|
||||
kv_cache_ratio: 0.8
|
||||
tensor_parallel_size: 8
|
||||
quantization: wint8
|
||||
|
9
benchmarks/yaml/eb45-vl-lite-32k-bf16-a800-tp1.yaml
Normal file
9
benchmarks/yaml/eb45-vl-lite-32k-bf16-a800-tp1.yaml
Normal file
@@ -0,0 +1,9 @@
|
||||
enable_mm: True
|
||||
max_model_len: 32768
|
||||
max_num_seqs: 128
|
||||
gpu_memory_utilization: 0.9
|
||||
kv_cache_ratio: 0.71
|
||||
tensor_parallel_size: 1
|
||||
enable_chunked_prefill: True
|
||||
max_num_batched_tokens: 384
|
||||
reasoning_parser: ernie-45-vl
|
10
benchmarks/yaml/eb45-vl-lite-32k-wint4-a800-tp1.yaml
Normal file
10
benchmarks/yaml/eb45-vl-lite-32k-wint4-a800-tp1.yaml
Normal file
@@ -0,0 +1,10 @@
|
||||
enable_mm: True
|
||||
max_model_len: 32768
|
||||
max_num_seqs: 128
|
||||
gpu_memory_utilization: 0.9
|
||||
kv_cache_ratio: 0.71
|
||||
tensor_parallel_size: 1
|
||||
enable_chunked_prefill: True
|
||||
max_num_batched_tokens: 384
|
||||
quantization: wint4
|
||||
reasoning_parser: ernie-45-vl
|
10
benchmarks/yaml/eb45-vl-lite-32k-wint8-a800-tp1.yaml
Normal file
10
benchmarks/yaml/eb45-vl-lite-32k-wint8-a800-tp1.yaml
Normal file
@@ -0,0 +1,10 @@
|
||||
enable_mm: True
|
||||
max_model_len: 32768
|
||||
max_num_seqs: 128
|
||||
gpu_memory_utilization: 0.9
|
||||
kv_cache_ratio: 0.71
|
||||
tensor_parallel_size: 1
|
||||
enable_chunked_prefill: True
|
||||
max_num_batched_tokens: 384
|
||||
quantization: wint8
|
||||
reasoning_parser: ernie-45-vl
|
@@ -2,7 +2,7 @@ top_p: 0.95
|
||||
temperature: 0.6
|
||||
metadata:
|
||||
min_tokens: 1
|
||||
max_tokens: 65535
|
||||
max_tokens: 12288
|
||||
repetition_penalty: 1.0
|
||||
frequency_penalty: 0
|
||||
presence_penalty: 0
|
1
benchmarks/yaml/request_yaml/eb45-vl-128k.yaml
Normal file
1
benchmarks/yaml/request_yaml/eb45-vl-128k.yaml
Normal file
@@ -0,0 +1 @@
|
||||
max_tokens: 131071
|
1
benchmarks/yaml/request_yaml/eb45-vl-32k.yaml
Normal file
1
benchmarks/yaml/request_yaml/eb45-vl-32k.yaml
Normal file
@@ -0,0 +1 @@
|
||||
max_tokens: 12288
|
8
benchmarks/yaml/request_yaml/x1-128k.yaml
Normal file
8
benchmarks/yaml/request_yaml/x1-128k.yaml
Normal file
@@ -0,0 +1,8 @@
|
||||
top_p: 0.95
|
||||
temperature: 0.6
|
||||
metadata:
|
||||
min_tokens: 1
|
||||
max_tokens: 131071
|
||||
repetition_penalty: 1.0
|
||||
frequency_penalty: 0
|
||||
presence_penalty: 0
|
6
benchmarks/yaml/x1-a3b-128k-wint8-h800-tp1.yaml
Normal file
6
benchmarks/yaml/x1-a3b-128k-wint8-h800-tp1.yaml
Normal file
@@ -0,0 +1,6 @@
|
||||
tensor_parallel_size: 1
|
||||
max_model_len: 131072
|
||||
max_num_seqs: 32
|
||||
reasoning_parser: ernie_x1
|
||||
tool_call_parser: ernie_x1
|
||||
load_choices: "default_v1"
|
14
build.sh
14
build.sh
@@ -128,6 +128,12 @@ function copy_ops(){
|
||||
echo -e "MACA ops have been copy to fastdeploy"
|
||||
return
|
||||
fi
|
||||
is_intel_hpu=`$python -c "import paddle; print(paddle.is_compiled_with_custom_device('intel_hpu'))"`
|
||||
if [ "$is_intel_hpu" = "True" ]; then
|
||||
DEVICE_TYPE="intel-hpu"
|
||||
echo -e "intel_hpu ops have been copy to fastdeploy"
|
||||
return
|
||||
fi
|
||||
|
||||
DEVICE_TYPE="cpu"
|
||||
cd ../../../../
|
||||
@@ -143,9 +149,9 @@ function build_and_install_ops() {
|
||||
TMP_DIR_REAL_PATH=`readlink -f ${OPS_TMP_DIR}`
|
||||
is_xpu=`$python -c "import paddle; print(paddle.is_compiled_with_xpu())"`
|
||||
if [ "$is_xpu" = "True" ]; then
|
||||
cd xpu_ops/src
|
||||
cd xpu_ops
|
||||
bash build.sh ${TMP_DIR_REAL_PATH}
|
||||
cd ../..
|
||||
cd ..
|
||||
elif [ "$FD_CPU_USE_BF16" == "true" ]; then
|
||||
if [ "$FD_BUILDING_ARCS" == "" ]; then
|
||||
FD_CPU_USE_BF16=True ${python} setup_ops.py install --install-lib ${OPS_TMP_DIR}
|
||||
@@ -159,7 +165,9 @@ function build_and_install_ops() {
|
||||
else
|
||||
FD_BUILDING_ARCS=${FD_BUILDING_ARCS} ${python} setup_ops.py install --install-lib ${OPS_TMP_DIR}
|
||||
fi
|
||||
find ${OPS_TMP_DIR} -type f -name "*.o" -exec rm -f {} \;
|
||||
if [ -d "${OPS_TMP_DIR}" ]; then
|
||||
find ${OPS_TMP_DIR} -type f -name "*.o" -exec rm -f {} \;
|
||||
fi
|
||||
else
|
||||
echo "Error: Invalid parameter '$FD_CPU_USE_BF16'. Please use true or false."
|
||||
exit 1
|
||||
|
@@ -428,6 +428,142 @@ __global__ void append_decode_cache_T_rope_kernel(
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T, int VecSize = 1>
|
||||
__global__ void append_decode_cache_T_neox_partial_rope_kernel(
|
||||
const T* __restrict__ qkv, // [bsz, num_heads + 2 * kv_num_heads,
|
||||
// head_size]
|
||||
T* __restrict__ key_cache, // [num_blocks, kv_num_heads, block_size,
|
||||
// head_size // 2]
|
||||
T* __restrict__ value_cache, // [num_blocks, kv_num_heads, block_size,
|
||||
// head_size // 2]
|
||||
T* __restrict__ qkv_out,
|
||||
const int* __restrict__ block_tables, // [bsz, max_blocks_per_seq]
|
||||
const int* __restrict__ cu_seqlens_q,
|
||||
const int* __restrict__ seq_lens, // [bsz]
|
||||
const int* __restrict__ seq_lens_encoder, // [bsz]
|
||||
const float* __restrict__ cos_emb, // [2, 1, max_model_len, 1, rotary_dim/2]
|
||||
const float* __restrict__ sin_emb, // [2, 1, max_model_len, 1, rotary_dim/2]
|
||||
const int max_seq_len,
|
||||
const int max_blocks_per_seq,
|
||||
const int num_heads,
|
||||
const int head_size,
|
||||
const int rotary_dim,
|
||||
const int block_size,
|
||||
const uint32_t elem_cnt,
|
||||
const int kv_num_heads,
|
||||
const bool rope_3d) {
|
||||
using LoadT = AlignedVector<T, VecSize>;
|
||||
using LoadBiasT = AlignedVector<T, VecSize>;
|
||||
using LoadKVT = AlignedVector<T, VecSize>;
|
||||
constexpr int HalfVecSize = VecSize / 2;
|
||||
using LoadEmbT = AlignedVector<float, VecSize>;
|
||||
|
||||
LoadT left_vec, right_vec;
|
||||
LoadBiasT left_bias_vec, right_bias_vec;
|
||||
LoadKVT left_cache_vec, right_cache_vec;
|
||||
LoadEmbT cos_emb_vec;
|
||||
LoadEmbT sin_emb_vec;
|
||||
|
||||
int64_t global_thread_idx = blockDim.x * blockIdx.x + threadIdx.x;
|
||||
const int half_head_size = head_size / 2;
|
||||
const int half_rotary_dim = rotary_dim / 2;
|
||||
const int64_t hidden_size = (num_heads + 2 * kv_num_heads) * head_size;
|
||||
const int64_t half_hidden_size = hidden_size / 2;
|
||||
// const int64_t offset = 2 * hidden_size;
|
||||
|
||||
for (int32_t linear_index = global_thread_idx * VecSize,
|
||||
step = gridDim.x * blockDim.x * VecSize;
|
||||
linear_index < elem_cnt;
|
||||
linear_index += step) {
|
||||
const int ori_bi = linear_index / half_hidden_size;
|
||||
const int bias = linear_index % half_hidden_size;
|
||||
const int hi = bias / half_head_size; // q + k + v
|
||||
const int h_bias = bias % half_head_size;
|
||||
if (hi < num_heads && h_bias >= half_rotary_dim){
|
||||
continue;
|
||||
}
|
||||
const int start_token_idx = cu_seqlens_q[ori_bi];
|
||||
if (seq_lens_encoder[ori_bi] > 0) return;
|
||||
const int write_seq_id = seq_lens[ori_bi];
|
||||
if (write_seq_id == 0) continue;
|
||||
|
||||
const int* block_table_now = nullptr;
|
||||
|
||||
block_table_now = block_tables + ori_bi * max_blocks_per_seq;
|
||||
const int block_idx = block_table_now[write_seq_id / block_size];
|
||||
const int block_offset = write_seq_id % block_size;
|
||||
uint32_t ori_idx_left =
|
||||
start_token_idx * hidden_size + hi * head_size + h_bias;
|
||||
uint32_t ori_idx_right = ori_idx_left + half_head_size;
|
||||
if (hi < num_heads){
|
||||
ori_idx_right = ori_idx_left + half_rotary_dim;
|
||||
}else if (hi < num_heads + kv_num_heads){
|
||||
if (h_bias < half_rotary_dim){
|
||||
ori_idx_right = ori_idx_left + half_rotary_dim;
|
||||
}else{
|
||||
ori_idx_left = ori_idx_left + half_rotary_dim;
|
||||
ori_idx_right = ori_idx_left + half_rotary_dim;
|
||||
}
|
||||
}
|
||||
|
||||
Load<T, VecSize>(&qkv[ori_idx_left], &left_vec);
|
||||
Load<T, VecSize>(&qkv[ori_idx_right], &right_vec);
|
||||
|
||||
if (hi < num_heads + kv_num_heads) {
|
||||
// q k rope
|
||||
const uint32_t emb_idx = write_seq_id * half_rotary_dim + h_bias;
|
||||
uint32_t new_emb_idx = rope_3d ? emb_idx + ori_bi * max_seq_len * head_size * 2 : emb_idx;
|
||||
if (h_bias < half_rotary_dim){
|
||||
Load<float, VecSize>(&cos_emb[new_emb_idx], &cos_emb_vec);
|
||||
Load<float, VecSize>(&sin_emb[new_emb_idx], &sin_emb_vec);
|
||||
}
|
||||
}
|
||||
#pragma unroll
|
||||
for (int i = 0; i < VecSize; i++) {
|
||||
// rope
|
||||
float input_left = static_cast<float>(left_vec[i]);
|
||||
float input_right = static_cast<float>(right_vec[i]);
|
||||
if (hi < num_heads + kv_num_heads && h_bias < half_rotary_dim) {
|
||||
const float cos_tmp = cos_emb_vec[i];
|
||||
const float sin_tmp = sin_emb_vec[i];
|
||||
left_bias_vec[i] =
|
||||
static_cast<T>(input_left * cos_tmp - input_right * sin_tmp);
|
||||
right_bias_vec[i] =
|
||||
static_cast<T>(input_right * cos_tmp + input_left * sin_tmp);
|
||||
} else {
|
||||
left_bias_vec[i] = static_cast<T>(input_left);
|
||||
right_bias_vec[i] = static_cast<T>(input_right);
|
||||
}
|
||||
}
|
||||
if (hi < num_heads) {
|
||||
// write q
|
||||
Store<T, VecSize>(left_bias_vec, &qkv_out[ori_idx_left]);
|
||||
Store<T, VecSize>(right_bias_vec, &qkv_out[ori_idx_right]);
|
||||
} else {
|
||||
// write k/v
|
||||
const uint32_t kv_head_idx = (hi - num_heads) % kv_num_heads;
|
||||
uint32_t tgt_idx_left =
|
||||
block_idx * kv_num_heads * block_size * head_size +
|
||||
kv_head_idx * block_size * head_size + block_offset * head_size +
|
||||
h_bias;
|
||||
uint32_t tgt_idx_right = tgt_idx_left + half_head_size;
|
||||
if (hi < num_heads + kv_num_heads) {
|
||||
if (h_bias < half_rotary_dim) {
|
||||
tgt_idx_right = tgt_idx_left + half_rotary_dim;
|
||||
}else{
|
||||
tgt_idx_left = tgt_idx_left + half_rotary_dim;
|
||||
tgt_idx_right = tgt_idx_left + half_rotary_dim;
|
||||
}
|
||||
Store<T, VecSize>(left_bias_vec, &key_cache[tgt_idx_left]);
|
||||
Store<T, VecSize>(right_bias_vec, &key_cache[tgt_idx_right]);
|
||||
} else {
|
||||
Store<T, VecSize>(left_bias_vec, &value_cache[tgt_idx_left]);
|
||||
Store<T, VecSize>(right_bias_vec, &value_cache[tgt_idx_right]);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T, int VecSize = 1>
|
||||
__global__ void append_decode_cache_T_neox_rope_kernel(
|
||||
const T* __restrict__ qkv, // [bsz, num_heads + 2 * kv_num_heads,
|
||||
@@ -913,7 +1049,7 @@ __global__ void append_decode_cache_int8_rope_qk_norm_kernel(
|
||||
local_max = __hmax(local_max, __habs(out_vec2[i]));
|
||||
}
|
||||
#pragma unroll
|
||||
for (int m_offset = 16; m_offset > 1; m_offset /= 2) {
|
||||
for (int m_offset = 16; m_offset > 0; m_offset /= 2) {
|
||||
local_max = __hmax(local_max, __shfl_xor_sync(0xffffffff, local_max, m_offset));
|
||||
}
|
||||
|
||||
|
@@ -94,6 +94,7 @@ void append_decode_cache_rope(const QKV_TYPE* qkv,
|
||||
const int num_heads,
|
||||
const int kv_num_heads,
|
||||
const int dim_head,
|
||||
const int rotary_dim,
|
||||
const int block_size,
|
||||
const int bsz,
|
||||
const cudaStream_t& stream,
|
||||
@@ -133,7 +134,29 @@ void append_decode_cache_rope(const QKV_TYPE* qkv,
|
||||
kv_num_heads,
|
||||
rope_3d);
|
||||
} else {
|
||||
append_decode_cache_T_neox_rope_kernel<T, PackSize>
|
||||
if (rotary_dim < dim_head){
|
||||
append_decode_cache_T_neox_partial_rope_kernel<T, PackSize>
|
||||
<<<grid_size, blocksize, 0, stream>>>(reinterpret_cast<const T*>(qkv),
|
||||
key_cache,
|
||||
value_cache,
|
||||
qkv_out,
|
||||
block_tables,
|
||||
cu_seqlens_q,
|
||||
seq_lens,
|
||||
seq_lens_encoder,
|
||||
cos_emb,
|
||||
sin_emb,
|
||||
max_seq_len,
|
||||
max_blocks_per_seq,
|
||||
num_heads,
|
||||
dim_head,
|
||||
rotary_dim,
|
||||
block_size,
|
||||
elem_nums,
|
||||
kv_num_heads,
|
||||
rope_3d);
|
||||
}else{
|
||||
append_decode_cache_T_neox_rope_kernel<T, PackSize>
|
||||
<<<grid_size, blocksize, 0, stream>>>(reinterpret_cast<const T*>(qkv),
|
||||
key_cache,
|
||||
value_cache,
|
||||
@@ -152,6 +175,7 @@ void append_decode_cache_rope(const QKV_TYPE* qkv,
|
||||
elem_nums,
|
||||
kv_num_heads,
|
||||
rope_3d);
|
||||
}
|
||||
}
|
||||
} else {
|
||||
if (qkv_out_scales) {
|
||||
@@ -516,11 +540,20 @@ void DecoderWriteCacheWithRoPEKernel(
|
||||
const float* cos_emb =
|
||||
rotary_embs ? rotary_embs.get().data<float>() : nullptr;
|
||||
const float* sin_emb;
|
||||
int rotary_dim = dim_head;
|
||||
if (rotary_embs) {
|
||||
sin_emb =
|
||||
use_neox_rotary_style
|
||||
? rotary_embs.get().data<float>() + max_seq_len * dim_head
|
||||
: rotary_embs.get().data<float>() + max_seq_len * dim_head / 2;
|
||||
rotary_dim = rotary_embs.get().dims()[rotary_embs.get().dims().size()-1] * 2;
|
||||
if(rotary_dim < dim_head){
|
||||
if (!use_neox_rotary_style || qkv_out_scales || q_norm_weight || k_norm_weight|| cache_quant_type_str != "none"){
|
||||
PADDLE_THROW(phi::errors::Fatal(
|
||||
"partial_rotary_factor < 1.0 only supports neox_rotary_style=True, qkv_out_scales is None, q_norm_weight/k_norm_weight) is None, and cache_quant_type_str is 'none'."));
|
||||
}
|
||||
sin_emb = rotary_embs.get().data<float>() + max_seq_len * rotary_dim / 2;
|
||||
}
|
||||
}
|
||||
|
||||
if (q_norm_weight && k_norm_weight) {
|
||||
@@ -609,6 +642,7 @@ void DecoderWriteCacheWithRoPEKernel(
|
||||
num_heads,
|
||||
kv_num_heads,
|
||||
dim_head,
|
||||
rotary_dim,
|
||||
block_size,
|
||||
bsz,
|
||||
stream,
|
||||
|
@@ -900,6 +900,74 @@ __global__ void GQANeoxVariableLengthRotaryKernel(
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T, int VecSize = 1>
|
||||
__global__ void GQANeoxVariableLengthPartialRotaryKernel(
|
||||
const T *qkv,
|
||||
const float *cos_emb,
|
||||
const float *sin_emb,
|
||||
const int *batch_id_per_token,
|
||||
const int *cu_seqlens_q,
|
||||
const int *seq_lens,
|
||||
const int *seq_lens_decoder,
|
||||
const float *qkv_out_scales,
|
||||
const T *qkv_biases,
|
||||
T *qkv_out,
|
||||
const int64_t elem_cnt,
|
||||
const int q_num_head,
|
||||
const int kv_num_head,
|
||||
const int seq_len,
|
||||
const int head_dim,
|
||||
const int rotary_dim,
|
||||
const bool rope_3d) {
|
||||
using LoadT = AlignedVector<T, VecSize>;
|
||||
using LoadEmbT = AlignedVector<float, VecSize>;
|
||||
LoadT left_vec;
|
||||
LoadT right_vec;
|
||||
LoadEmbT cos_emb_vec;
|
||||
LoadEmbT sin_emb_vec;
|
||||
int64_t global_thread_idx = blockDim.x * blockIdx.x + threadIdx.x;
|
||||
const int rotary_dim_half = rotary_dim / 2;
|
||||
const int offset = (q_num_head + kv_num_head) * rotary_dim_half;
|
||||
for (int64_t linear_index = global_thread_idx * VecSize,
|
||||
step = gridDim.x * blockDim.x * VecSize;
|
||||
linear_index < elem_cnt;
|
||||
linear_index += step) {
|
||||
const int token_idx = linear_index / offset;
|
||||
const int ori_bi = batch_id_per_token[token_idx];
|
||||
if (seq_lens && seq_lens[ori_bi] == 0) continue;
|
||||
const int bias = linear_index % offset;
|
||||
const int hi = bias / rotary_dim_half;
|
||||
const int h_bias = bias % rotary_dim_half;
|
||||
|
||||
const int ori_seq_id = (token_idx - cu_seqlens_q[ori_bi]) + seq_lens_decoder[ori_bi];
|
||||
|
||||
const int emb_idx = ori_seq_id * rotary_dim_half + h_bias;
|
||||
int64_t new_emb_idx = rope_3d ? emb_idx + ori_bi * head_dim * seq_len * 2 : emb_idx;
|
||||
const int base_idx_left =
|
||||
token_idx * (q_num_head + 2 * kv_num_head) * head_dim + hi * head_dim +
|
||||
h_bias;
|
||||
const int base_idx_right = base_idx_left + rotary_dim_half;
|
||||
|
||||
Load<T, VecSize>(&qkv[base_idx_left], &left_vec);
|
||||
Load<T, VecSize>(&qkv[base_idx_right], &right_vec);
|
||||
Load<float, VecSize>(&cos_emb[new_emb_idx], &cos_emb_vec);
|
||||
Load<float, VecSize>(&sin_emb[new_emb_idx], &sin_emb_vec);
|
||||
#pragma unroll
|
||||
for (int i = 0; i < VecSize; i++) {
|
||||
const float input_left = static_cast<float>(left_vec[i]);
|
||||
const float input_right = static_cast<float>(right_vec[i]);
|
||||
const float cos_tmp = cos_emb_vec[i];
|
||||
const float sin_tmp = sin_emb_vec[i];
|
||||
left_vec[i] =
|
||||
static_cast<T>(input_left * cos_tmp - input_right * sin_tmp);
|
||||
right_vec[i] =
|
||||
static_cast<T>(input_right * cos_tmp + input_left * sin_tmp);
|
||||
}
|
||||
Store<T, VecSize>(left_vec, &qkv_out[base_idx_left]);
|
||||
Store<T, VecSize>(right_vec, &qkv_out[base_idx_right]);
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T, int VecSize = 1>
|
||||
__global__ void cache_kernel(
|
||||
const T *__restrict__ qkv, // [num_tokens, num_heads + 2 * kv_num_heads,
|
||||
@@ -936,7 +1004,8 @@ __global__ void cache_kernel(
|
||||
const uint32_t qkv_bias = bias % hidden_size;
|
||||
const uint32_t hi = qkv_bias / head_size;
|
||||
const uint32_t h_bias = qkv_bias % head_size;
|
||||
const uint32_t ori_bi = batch_id_per_token[token_idx];
|
||||
const int32_t ori_bi = batch_id_per_token[token_idx];
|
||||
if (ori_bi == -1) continue; // skip batch_id_per_token[token_idx]=-1
|
||||
if (seq_lens[ori_bi] == 0) continue;
|
||||
const uint32_t ori_seq_id = (token_idx - cu_seqlens_q[ori_bi]) + seq_lens_decoder[ori_bi];
|
||||
|
||||
@@ -2160,6 +2229,7 @@ void gqa_rotary_qk_variable(
|
||||
const int seq_len,
|
||||
const int input_output_len,
|
||||
const int dim_head,
|
||||
const int rotary_dim,
|
||||
const cudaStream_t &stream,
|
||||
bool use_neox_style = false,
|
||||
bool rope_3d = false) {
|
||||
@@ -2240,7 +2310,38 @@ void gqa_rotary_qk_variable(
|
||||
dim_head,
|
||||
rope_3d);
|
||||
} else {
|
||||
GQANeoxVariableLengthRotaryKernel<T, PackSize>
|
||||
if (rotary_dim < dim_head){
|
||||
PD_CHECK((rotary_dim / 2) % PackSize == 0);
|
||||
elem_nums =
|
||||
qkv_out_scales
|
||||
? token_num * (num_heads + 2 * kv_num_heads) * rotary_dim
|
||||
: token_num * (num_heads + kv_num_heads) * rotary_dim; // for all q k v
|
||||
if (use_neox_style) {
|
||||
elem_nums /= 2;
|
||||
}
|
||||
const int pack_num_new = elem_nums / PackSize;
|
||||
GetNumBlocks<128>(pack_num_new, &grid_size);
|
||||
GQANeoxVariableLengthPartialRotaryKernel<T, PackSize>
|
||||
<<<grid_size, blocksize, 0, stream>>>(
|
||||
reinterpret_cast<const T *>(qkv_input),
|
||||
cos_emb,
|
||||
rotary_emb + input_output_len * rotary_dim / 2,
|
||||
batch_id_per_token,
|
||||
cu_seqlens_q,
|
||||
seq_lens,
|
||||
seq_lens_decoder,
|
||||
qkv_out_scales,
|
||||
qkv_bias,
|
||||
qkv_out,
|
||||
elem_nums,
|
||||
num_heads,
|
||||
kv_num_heads,
|
||||
seq_len,
|
||||
dim_head,
|
||||
rotary_dim,
|
||||
rope_3d);
|
||||
}else{
|
||||
GQANeoxVariableLengthRotaryKernel<T, PackSize>
|
||||
<<<grid_size, blocksize, 0, stream>>>(
|
||||
reinterpret_cast<const T *>(qkv_input),
|
||||
cos_emb,
|
||||
@@ -2258,6 +2359,7 @@ void gqa_rotary_qk_variable(
|
||||
seq_len,
|
||||
dim_head,
|
||||
rope_3d);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@@ -55,9 +55,19 @@ void EncoderWriteCacheWithRopeKernel(
|
||||
auto kv_num_heads = meta_data.kv_num_heads;
|
||||
auto head_dim = meta_data.head_dims;
|
||||
bool is_scale_channel_wise = false;
|
||||
int rotary_dim = head_dim;
|
||||
if (cache_k_scale && cache_k_scale.get().dims()[0] == head_dim * kv_num_heads) {
|
||||
is_scale_channel_wise = true;
|
||||
}
|
||||
if (rotary_embs){
|
||||
rotary_dim = rotary_embs.get().dims()[rotary_embs.get().dims().size()-1] * 2;
|
||||
if(rotary_dim < head_dim){
|
||||
if (!use_neox_style || q_norm_weight || k_norm_weight || num_heads == kv_num_heads || is_scale_channel_wise){
|
||||
PADDLE_THROW(phi::errors::Fatal(
|
||||
"partial_rotary_factor < 1.0 only supports use_neox_rotary_style=True, q_norm_weight/k_norm_weight) is None, GQA and is_scale_channel_wise=false."));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if (q_norm_weight && k_norm_weight) {
|
||||
if (num_heads != kv_num_heads && !is_scale_channel_wise && !use_neox_style) {
|
||||
@@ -125,6 +135,7 @@ void EncoderWriteCacheWithRopeKernel(
|
||||
max_seq_len,
|
||||
rope_3d ? rotary_embs.get().dims()[3] : rotary_embs.get().dims()[2],
|
||||
head_dim,
|
||||
rotary_dim,
|
||||
stream,
|
||||
use_neox_style,
|
||||
rope_3d);
|
||||
|
@@ -11,10 +11,11 @@
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#include "cute/tensor.hpp"
|
||||
#include "helper.h"
|
||||
#include "paddle/extension.h"
|
||||
#include "paddle/phi/core/memory/memcpy.h"
|
||||
#include "utils.cuh"
|
||||
|
||||
template <int THREADBLOCK_SIZE>
|
||||
__global__ void
|
||||
@@ -116,6 +117,93 @@ void GetMaxLen(const paddle::Tensor &seq_lens_tensor,
|
||||
max_len_tensor.data<int>(), batch_size);
|
||||
}
|
||||
|
||||
template <uint32_t config_size>
|
||||
__global__ void search_chunk_size_for_mla(
|
||||
const int *__restrict__ seq_lens_q,
|
||||
const int *__restrict__ seq_lens_encoder,
|
||||
const int *__restrict__ seq_lens_decoder,
|
||||
int *__restrict__ num_blocks_x,
|
||||
int *__restrict__ res_chunk_size,
|
||||
const int bsz,
|
||||
const int set_chunk_size,
|
||||
const int block_size,
|
||||
const int sm_cout) {
|
||||
const uint32_t conf_id = threadIdx.x;
|
||||
int gridx = 0;
|
||||
if (set_chunk_size > 0 && conf_id == 0) {
|
||||
for (uint32_t bid = 0; bid < bsz; bid++) {
|
||||
int seq_len = seq_lens_q[bid];
|
||||
int seq_len_encoder = seq_lens_encoder[bid];
|
||||
int seq_len_decoder = seq_lens_decoder[bid] + seq_len;
|
||||
if (seq_len == 0 || seq_len_encoder > 0) continue;
|
||||
|
||||
int loop_times;
|
||||
loop_times = cute::ceil_div(seq_len_decoder, set_chunk_size);
|
||||
gridx += loop_times;
|
||||
}
|
||||
*num_blocks_x = gridx;
|
||||
*res_chunk_size = set_chunk_size;
|
||||
} else if (conf_id < config_size) {
|
||||
__shared__ int gridx_shared[config_size];
|
||||
// chunk_size is a multiple of 64
|
||||
const int chunk_size = block_size << conf_id;
|
||||
for (uint32_t bid = 0; bid < bsz; bid++) {
|
||||
int seq_len = seq_lens_q[bid];
|
||||
int seq_len_encoder = seq_lens_encoder[bid];
|
||||
int seq_len_decoder = seq_lens_decoder[bid] + seq_len;
|
||||
if (seq_len == 0 || seq_len_encoder > 0) continue;
|
||||
|
||||
int loop_times;
|
||||
loop_times = cute::ceil_div(seq_len_decoder, chunk_size);
|
||||
gridx += loop_times;
|
||||
}
|
||||
gridx_shared[conf_id] = gridx;
|
||||
__syncthreads();
|
||||
if (threadIdx.x == 0) {
|
||||
uint32_t res_id = 0;
|
||||
uint32_t max_last_wave_block = 0;
|
||||
for (uint32_t i = 1; i < config_size; i++) {
|
||||
uint32_t last_wave_block = gridx_shared[i] % sm_cout;
|
||||
if (last_wave_block >= max_last_wave_block) {
|
||||
res_id = i;
|
||||
max_last_wave_block = last_wave_block;
|
||||
}
|
||||
}
|
||||
*num_blocks_x = gridx_shared[res_id];
|
||||
*res_chunk_size = block_size << res_id;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
__global__ void split_block_for_mla(const int *__restrict__ seq_lens_q,
|
||||
const int *__restrict__ seq_lens_encoder,
|
||||
const int *__restrict__ seq_lens_decoder,
|
||||
int *__restrict__ batch_ids,
|
||||
int *__restrict__ tile_ids_per_batch,
|
||||
const int bsz,
|
||||
const int chunk_size) {
|
||||
if (threadIdx.x == 0) {
|
||||
int index = 0;
|
||||
for (uint32_t bid = 0; bid < bsz; bid++) {
|
||||
int seq_len = seq_lens_q[bid];
|
||||
int seq_len_encoder = seq_lens_encoder[bid];
|
||||
int seq_len_decoder = seq_lens_decoder[bid] + seq_len;
|
||||
|
||||
if (seq_len == 0) continue;
|
||||
|
||||
int loop_times;
|
||||
loop_times = cute::ceil_div(seq_len_decoder, chunk_size);
|
||||
if (seq_len_encoder > 0) {
|
||||
loop_times = 0;
|
||||
}
|
||||
for (uint32_t tile_id = 0; tile_id < loop_times; tile_id++) {
|
||||
batch_ids[index] = bid;
|
||||
tile_ids_per_batch[index++] = tile_id;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
__global__ void split_q_block(const int *__restrict__ seq_lens_q,
|
||||
const int *__restrict__ seq_lens_encoder,
|
||||
int *__restrict__ batch_ids,
|
||||
@@ -197,7 +285,9 @@ void GetBlockShapeAndSplitKVBlock(
|
||||
const paddle::Tensor &seq_lens_this_time,
|
||||
paddle::Tensor &decoder_batch_ids, // Inplace
|
||||
paddle::Tensor &decoder_tile_ids_per_batch, // Inplace
|
||||
paddle::Tensor &decoder_num_blocks_x_cpu, // Inplace, Pinned Memory
|
||||
paddle::Tensor &decoder_num_blocks_cpu, // Inplace, Pinned Memory
|
||||
paddle::Tensor &decoder_num_blocks_device, // Inplace
|
||||
paddle::Tensor &decoder_chunk_size_device, // Inplace
|
||||
paddle::Tensor &max_len_tensor_cpu, // Inplace, CPU
|
||||
paddle::Tensor &encoder_batch_ids, // Inplace
|
||||
paddle::Tensor &encoder_tile_ids_per_batch, // Inplace
|
||||
@@ -230,8 +320,6 @@ void GetBlockShapeAndSplitKVBlock(
|
||||
int max_system_len = max_len_cpu_ptr[6];
|
||||
int max_just_dec_len_without_system = max_len_cpu_ptr[7];
|
||||
|
||||
|
||||
|
||||
auto max_len_kv =
|
||||
GetEmptyTensor({1}, paddle::DataType::INT32, seq_lens_decoder.place());
|
||||
get_max_len_kv_ernel<128><<<1, 128, 0, stream>>>(
|
||||
@@ -241,6 +329,106 @@ void GetBlockShapeAndSplitKVBlock(
|
||||
|
||||
max_len_kv_cpu.copy_(max_len_kv, max_len_kv_cpu.place(), false);
|
||||
|
||||
// decoder
|
||||
if (max_dec_len_this_time > 0) {
|
||||
const bool mla_use_tensorcore = GetMlaUseTensorcore();
|
||||
if (mla_use_tensorcore && group_size <= 64) {
|
||||
const int set_chunk_size = get_mla_dec_chunk_size(bsz);
|
||||
|
||||
PADDLE_ENFORCE_GPU_SUCCESS(cudaMemsetAsync(
|
||||
decoder_chunk_size_device.data<int>(), 64, sizeof(int32_t), stream));
|
||||
|
||||
PADDLE_ENFORCE_GPU_SUCCESS(cudaMemsetAsync(
|
||||
decoder_num_blocks_device.data<int>(), 0, sizeof(int32_t), stream));
|
||||
|
||||
int device;
|
||||
cudaGetDevice(&device);
|
||||
int sm_cout;
|
||||
cudaDeviceGetAttribute(&sm_cout, cudaDevAttrMultiProcessorCount, device);
|
||||
constexpr int config_size =
|
||||
12; // search space for chunk size:[64, 128, 256, ... 131072]
|
||||
|
||||
search_chunk_size_for_mla<config_size>
|
||||
<<<1, 32, 0, stream>>>(seq_lens_this_time.data<int>(),
|
||||
seq_lens_encoder.data<int>(),
|
||||
seq_lens_decoder.data<int>(),
|
||||
decoder_num_blocks_device.data<int>(),
|
||||
decoder_chunk_size_device.data<int>(),
|
||||
bsz,
|
||||
set_chunk_size,
|
||||
block_size,
|
||||
sm_cout);
|
||||
|
||||
decoder_num_blocks_cpu.copy_(
|
||||
decoder_num_blocks_device, decoder_num_blocks_cpu.place(), false);
|
||||
auto decoder_chunk_size_cpu =
|
||||
decoder_chunk_size_device.copy_to(paddle::CPUPlace(), false);
|
||||
const int chunk_size = decoder_chunk_size_cpu.data<int>()[0];
|
||||
|
||||
// NOTE: (changwenbin) When using auto_chunk,
|
||||
// decode_max_tile_size must take into account the maximum case, where * 1024 can cover 128K.
|
||||
// const uint32_t decoder_batch_shape = seq_lens_decoder.dims()[0] * 1024;
|
||||
|
||||
const uint32_t decoder_max_tile_size_per_bs_q =
|
||||
div_up((decoder_step_token_num * group_size), decoder_block_shape_q);
|
||||
const uint32_t decoder_batch_shape =
|
||||
bsz * 1024 * decoder_max_tile_size_per_bs_q;
|
||||
|
||||
PADDLE_ENFORCE_GPU_SUCCESS(
|
||||
cudaMemsetAsync(decoder_batch_ids.data<int>(),
|
||||
0,
|
||||
decoder_batch_shape * sizeof(int32_t),
|
||||
stream));
|
||||
PADDLE_ENFORCE_GPU_SUCCESS(
|
||||
cudaMemsetAsync(decoder_tile_ids_per_batch.data<int>(),
|
||||
0,
|
||||
decoder_batch_shape * sizeof(int32_t),
|
||||
stream));
|
||||
|
||||
|
||||
split_block_for_mla<<<1, 32, 0, stream>>>(
|
||||
seq_lens_this_time.data<int>(),
|
||||
seq_lens_encoder.data<int>(),
|
||||
seq_lens_decoder.data<int>(),
|
||||
decoder_batch_ids.data<int>(),
|
||||
decoder_tile_ids_per_batch.data<int>(),
|
||||
bsz,
|
||||
chunk_size);
|
||||
|
||||
} else {
|
||||
// Note:(changwenbin)In order to adapt to cudagraph, the maximum value should be taken here
|
||||
const uint32_t decoder_max_tile_size_per_bs_q = div_up((decoder_step_token_num * group_size), decoder_block_shape_q);
|
||||
const uint32_t decoder_batch_shape = bsz * 1024 * decoder_max_tile_size_per_bs_q;
|
||||
|
||||
PADDLE_ENFORCE_GPU_SUCCESS(cudaMemsetAsync(decoder_batch_ids.data<int>(), 0, decoder_batch_shape * sizeof(int32_t), stream));
|
||||
PADDLE_ENFORCE_GPU_SUCCESS(cudaMemsetAsync(decoder_tile_ids_per_batch.data<int>(), 0, decoder_batch_shape * sizeof(int32_t), stream));
|
||||
PADDLE_ENFORCE_GPU_SUCCESS(cudaMemsetAsync(decoder_num_blocks_device.data<int>(), 0, sizeof(int32_t), stream));
|
||||
|
||||
split_q_block<<<1, 32, 0, stream>>>(
|
||||
seq_lens_this_time.data<int>(),
|
||||
seq_lens_encoder.data<int>(),
|
||||
decoder_batch_ids.data<int>(),
|
||||
decoder_tile_ids_per_batch.data<int>(),
|
||||
decoder_num_blocks_device.data<int>(),
|
||||
bsz,
|
||||
decoder_block_shape_q,
|
||||
group_size);
|
||||
|
||||
decoder_num_blocks_cpu.copy_(
|
||||
decoder_num_blocks_device, decoder_num_blocks_cpu.place(), false);
|
||||
PADDLE_ENFORCE_GPU_SUCCESS(cudaMemsetAsync(
|
||||
decoder_chunk_size_device.data<int>(), 64, sizeof(int32_t), stream));
|
||||
}
|
||||
} else {
|
||||
PADDLE_ENFORCE_GPU_SUCCESS(cudaMemsetAsync(
|
||||
decoder_chunk_size_device.data<int>(), 64, sizeof(int32_t), stream));
|
||||
PADDLE_ENFORCE_GPU_SUCCESS(cudaMemsetAsync(
|
||||
decoder_num_blocks_device.data<int>(), 0, sizeof(int32_t), stream));
|
||||
decoder_num_blocks_cpu.copy_(
|
||||
decoder_num_blocks_device, decoder_num_blocks_cpu.place(), false);
|
||||
}
|
||||
|
||||
// encoder
|
||||
if (max_enc_len_this_time > 0) {
|
||||
const uint32_t max_tile_size_per_bs_kv = div_up(max_enc_dec_len_this_time, block_size);
|
||||
const uint32_t kv_batch_shape = bsz * max_tile_size_per_bs_kv;
|
||||
@@ -272,28 +460,6 @@ void GetBlockShapeAndSplitKVBlock(
|
||||
encoder_num_blocks_x_cpu.copy_(encoder_num_blocks_x, encoder_num_blocks_x_cpu.place(), false);
|
||||
}
|
||||
|
||||
if (max_just_dec_len_this_time > 0) {
|
||||
// Clear buffer
|
||||
const uint32_t decoder_max_tile_size_per_bs_q = div_up((decoder_step_token_num * group_size), decoder_block_shape_q);
|
||||
const uint32_t decoder_batch_shape = bsz * decoder_max_tile_size_per_bs_q;
|
||||
PADDLE_ENFORCE_GPU_SUCCESS(cudaMemsetAsync(decoder_batch_ids.data<int>(), 0, decoder_batch_shape * sizeof(int32_t), stream));
|
||||
PADDLE_ENFORCE_GPU_SUCCESS(cudaMemsetAsync(decoder_tile_ids_per_batch.data<int>(), 0, decoder_batch_shape * sizeof(int32_t), stream));
|
||||
PADDLE_ENFORCE_GPU_SUCCESS(cudaMemsetAsync(decoder_num_blocks_x_cpu.data<int>(), 0, sizeof(int32_t), stream));
|
||||
|
||||
auto decoder_num_blocks_x =
|
||||
GetEmptyTensor({1}, paddle::DataType::INT32, seq_lens_encoder.place());
|
||||
split_q_block<<<1, 32, 0, stream>>>(
|
||||
seq_lens_this_time.data<int>(),
|
||||
seq_lens_encoder.data<int>(),
|
||||
decoder_batch_ids.data<int>(),
|
||||
decoder_tile_ids_per_batch.data<int>(),
|
||||
decoder_num_blocks_x.data<int>(),
|
||||
bsz,
|
||||
decoder_block_shape_q,
|
||||
group_size);
|
||||
decoder_num_blocks_x_cpu.copy_(decoder_num_blocks_x, decoder_num_blocks_x_cpu.place(), false);
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
PD_BUILD_STATIC_OP(get_block_shape_and_split_kv_block)
|
||||
@@ -303,7 +469,9 @@ PD_BUILD_STATIC_OP(get_block_shape_and_split_kv_block)
|
||||
"seq_lens_this_time",
|
||||
"decoder_batch_ids",
|
||||
"decoder_tile_ids_per_batch",
|
||||
"decoder_num_blocks_x_cpu",
|
||||
"decoder_num_blocks_cpu",
|
||||
"decoder_num_blocks_device",
|
||||
"decoder_chunk_size_device",
|
||||
"max_len_tensor_cpu",
|
||||
"encoder_batch_ids",
|
||||
"encoder_tile_ids_per_batch",
|
||||
|
@@ -46,7 +46,8 @@ __global__ void append_speculate_cache_T_rope_qk_norm_kernel(
|
||||
const int gqa_group_size,
|
||||
const float* q_norm_weight,
|
||||
const float* k_norm_weight,
|
||||
const float rms_norm_eps) {
|
||||
const float rms_norm_eps,
|
||||
const bool rope_3d) {
|
||||
using LoadT = AlignedVector<T, VecSize>;
|
||||
using LoadFloat = AlignedVector<float, VecSize>;
|
||||
using LoadInT = AlignedVector<InT, VecSize>;
|
||||
@@ -109,8 +110,9 @@ __global__ void append_speculate_cache_T_rope_qk_norm_kernel(
|
||||
if (hi < num_heads + gqa_group_size) {
|
||||
// q k rope
|
||||
const int64_t emb_idx = write_seq_id * half_head_size + h_bias / 2;
|
||||
Load<float, HalfVecSize>(&cos_emb[emb_idx], &cos_emb_vec);
|
||||
Load<float, HalfVecSize>(&sin_emb[emb_idx], &sin_emb_vec);
|
||||
uint32_t new_emb_idx = rope_3d ? emb_idx + ori_bi * max_seq_len * head_size : emb_idx;
|
||||
Load<float, HalfVecSize>(&cos_emb[new_emb_idx], &cos_emb_vec);
|
||||
Load<float, HalfVecSize>(&sin_emb[new_emb_idx], &sin_emb_vec);
|
||||
}
|
||||
float thread_m2 = 0.0f;
|
||||
float warp_m2 = 0.0f;
|
||||
|
@@ -41,7 +41,8 @@ void append_speculate_cache_rope_qk_norm(const QKV_TYPE* qkv,
|
||||
const bool use_neox_style,
|
||||
const float* q_norm_weight,
|
||||
const float* k_norm_weight,
|
||||
const float rms_norm_eps) {
|
||||
const float rms_norm_eps,
|
||||
const bool rope_3d) {
|
||||
int output_inner_dim = num_heads + 2 * kv_num_heads;
|
||||
const uint32_t elem_nums =
|
||||
use_neox_style ? token_num * (num_heads + 2 * kv_num_heads) * dim_head / 2
|
||||
@@ -53,7 +54,6 @@ void append_speculate_cache_rope_qk_norm(const QKV_TYPE* qkv,
|
||||
const int blocksize = 128;
|
||||
int grid_size = 1;
|
||||
GetNumBlocks<128>(pack_num, &grid_size);
|
||||
|
||||
if (use_neox_style) {
|
||||
PD_THROW(
|
||||
"append_speculate_cache_rope_qk_norm not support neox rope yet");
|
||||
@@ -82,7 +82,8 @@ void append_speculate_cache_rope_qk_norm(const QKV_TYPE* qkv,
|
||||
kv_num_heads,
|
||||
q_norm_weight,
|
||||
k_norm_weight,
|
||||
rms_norm_eps);
|
||||
rms_norm_eps,
|
||||
rope_3d);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -426,7 +427,6 @@ void SpeculateWriteCacheWithRoPEKernel(
|
||||
? rotary_embs.get().data<float>() + max_seq_len * dim_head
|
||||
: rotary_embs.get().data<float>() + max_seq_len * dim_head / 2;
|
||||
}
|
||||
|
||||
if (q_norm_weight && k_norm_weight) {
|
||||
if (cache_quant_type_str == "none") {
|
||||
append_speculate_cache_rope_qk_norm(
|
||||
@@ -457,11 +457,13 @@ void SpeculateWriteCacheWithRoPEKernel(
|
||||
use_neox_rotary_style,
|
||||
reinterpret_cast<const float*>(q_norm_weight.get().data<float>()),
|
||||
reinterpret_cast<const float*>(k_norm_weight.get().data<float>()),
|
||||
rms_norm_eps);
|
||||
rms_norm_eps,
|
||||
rope_3d);
|
||||
} else {
|
||||
PD_THROW(
|
||||
"append_decode_cache_rope_qk_norm not support cachekv quant yet");
|
||||
}
|
||||
|
||||
} else {
|
||||
if (cache_quant_type_str == "none") {
|
||||
append_speculate_cache_rope(
|
||||
|
@@ -63,7 +63,7 @@ std::vector<paddle::Tensor> AppendAttention(
|
||||
const paddle::Tensor &kv_num_blocks,
|
||||
const paddle::Tensor &decoder_batch_ids,
|
||||
const paddle::Tensor &decoder_tile_ids_per_batch,
|
||||
const paddle::Tensor &decoder_num_blocks,
|
||||
const paddle::Tensor &decoder_num_blocks_cpu,
|
||||
const paddle::Tensor &set_max_lengths, const paddle::Tensor &max_len_kv,
|
||||
const paddle::optional<paddle::Tensor> &rotary_embs,
|
||||
const paddle::optional<paddle::Tensor> &attn_mask,
|
||||
@@ -105,7 +105,7 @@ void AppendAttentionWithOutput(
|
||||
const paddle::Tensor &kv_num_blocks,
|
||||
const paddle::Tensor &decoder_batch_ids,
|
||||
const paddle::Tensor &decoder_tile_ids_per_batch,
|
||||
const paddle::Tensor &decoder_num_blocks,
|
||||
const paddle::Tensor &decoder_num_blocks_cpu,
|
||||
const paddle::Tensor &set_max_lengths, const paddle::Tensor &max_len_kv,
|
||||
paddle::Tensor &fmha_out,
|
||||
const paddle::optional<paddle::Tensor> &rotary_embs,
|
||||
@@ -305,7 +305,9 @@ void GetBlockShapeAndSplitKVBlock(
|
||||
const paddle::Tensor &seq_lens_this_time,
|
||||
paddle::Tensor &decoder_batch_ids, // Inplace
|
||||
paddle::Tensor &decoder_tile_ids_per_batch, // Inplace
|
||||
paddle::Tensor &decoder_num_blocks_x_cpu, // Inplace, Pinned Memory
|
||||
paddle::Tensor &decoder_num_blocks_cpu, // Inplace, Pinned Memory
|
||||
paddle::Tensor &decoder_num_blocks_device, // Inplace
|
||||
paddle::Tensor &decoder_chunk_size_device, // Inplace
|
||||
paddle::Tensor &max_len_tensor_cpu, // Inplace, Pinned Memory
|
||||
paddle::Tensor &encoder_batch_ids, // Inplace
|
||||
paddle::Tensor &encoder_tile_ids_per_batch, // Inplace
|
||||
@@ -414,8 +416,8 @@ std::vector<paddle::Tensor> MoEDeepGEMMDePermute(
|
||||
const paddle::Tensor &topk_idx, const paddle::Tensor &topk_weights);
|
||||
|
||||
void TextImageIndexOut(const paddle::Tensor &token_type_ids,
|
||||
const paddle::Tensor &text_input,
|
||||
const paddle::Tensor &image_input);
|
||||
paddle::Tensor &text_input,
|
||||
paddle::Tensor &image_input);
|
||||
|
||||
void TextImageGatherScatter(paddle::Tensor &input, paddle::Tensor &text_input,
|
||||
paddle::Tensor &image_input,
|
||||
@@ -473,23 +475,18 @@ std::vector<paddle::Tensor> MultiHeadLatentAttention(
|
||||
const paddle::Tensor& query,
|
||||
const paddle::Tensor& key_cache,
|
||||
const paddle::Tensor& value_cache,
|
||||
const paddle::Tensor& seq_lens_encoder,
|
||||
const paddle::Tensor& seq_lens_decoder,
|
||||
const paddle::Tensor& seq_lens_this_time,
|
||||
const paddle::Tensor& cu_seqlens_q,
|
||||
const paddle::Tensor& batch_id_per_token,
|
||||
const paddle::Tensor& block_tables,
|
||||
const paddle::Tensor& encoder_batch_ids,
|
||||
const paddle::Tensor& encoder_tile_ids_per_batch,
|
||||
const paddle::Tensor& encoder_num_blocks,
|
||||
const paddle::Tensor& kv_batch_ids,
|
||||
const paddle::Tensor& kv_tile_ids_per_batch,
|
||||
const paddle::Tensor& kv_num_blocks,
|
||||
const paddle::Tensor& decoder_batch_ids,
|
||||
const paddle::Tensor& decoder_tile_ids_per_batch,
|
||||
const paddle::Tensor& decoder_num_blocks,
|
||||
const paddle::Tensor& decoder_num_blocks_cpu,
|
||||
const paddle::Tensor& max_enc_len_this_time,
|
||||
const paddle::Tensor& decoder_num_blocks_device,
|
||||
const paddle::Tensor& decoder_chunk_size_device,
|
||||
const paddle::Tensor& max_dec_len_this_time,
|
||||
const paddle::Tensor& max_len_kv,
|
||||
const paddle::optional<paddle::Tensor>& attn_mask,
|
||||
@@ -574,6 +571,7 @@ std::vector<paddle::Tensor> NoauxTc(
|
||||
int n_group,
|
||||
int topk_group,
|
||||
int topk,
|
||||
bool renormalize,
|
||||
float routed_scaling_factor);
|
||||
|
||||
#ifdef ENABLE_FP8
|
||||
@@ -625,6 +623,8 @@ int64_t open_mem_handle(paddle::Tensor& mem_handle);
|
||||
|
||||
void free_shared_buffer(int64_t buffer);
|
||||
|
||||
void clear_ipc_handles(int64_t _fa);
|
||||
|
||||
// speculative decoding Kernel
|
||||
std::vector<paddle::Tensor> SpeculateGetPaddingOffset(
|
||||
const paddle::Tensor& input_ids,
|
||||
@@ -1231,6 +1231,8 @@ PYBIND11_MODULE(fastdeploy_ops, m) {
|
||||
|
||||
m.def("free_shared_buffer", &free_shared_buffer, "free_shared_buffer");
|
||||
|
||||
m.def("clear_ipc_handles", &clear_ipc_handles, "clear_ipc_handles");
|
||||
|
||||
m.def("open_mem_handle", &open_mem_handle, "open_mem_handle");
|
||||
|
||||
m.def("get_graph_buffer_ipc_meta", &get_graph_buffer_ipc_meta, "get_graph_buffer_ipc_meta");
|
||||
|
@@ -122,10 +122,14 @@ void register_graph_buffers(fptr_t _fa,
|
||||
for (int i = 0; i < handles.size(); i++) {
|
||||
bytes.emplace_back(handles[i].begin(), handles[i].end());
|
||||
}
|
||||
bytes.reserve(handles.size());
|
||||
fa->register_graph_buffers(bytes, offsets);
|
||||
}
|
||||
|
||||
void clear_ipc_handles(fptr_t _fa) {
|
||||
auto fa = reinterpret_cast<paddle::CustomAllreduce*>(_fa);
|
||||
fa->clear_ipc_handles();
|
||||
}
|
||||
|
||||
std::tuple<fptr_t, paddle::Tensor> allocate_shared_buffer_and_handle(
|
||||
int64_t size) {
|
||||
|
||||
|
@@ -303,7 +303,7 @@ class CustomAllreduce {
|
||||
bool full_nvlink_;
|
||||
|
||||
RankSignals sg_;
|
||||
// Stores an map from a pointer to its peer pointters from all ranks.
|
||||
// Stores an map from a pointer to its peer pointers from all ranks.
|
||||
std::unordered_map<void*, RankData*> buffers_;
|
||||
Signal* self_sg_;
|
||||
|
||||
@@ -517,10 +517,15 @@ class CustomAllreduce {
|
||||
#undef KL
|
||||
}
|
||||
|
||||
~CustomAllreduce() {
|
||||
void clear_ipc_handles(){
|
||||
for (auto [_, ptr] : ipc_handles_) {
|
||||
CUDACHECK(cudaIpcCloseMemHandle(ptr));
|
||||
}
|
||||
ipc_handles_.clear();
|
||||
}
|
||||
|
||||
~CustomAllreduce() {
|
||||
clear_ipc_handles();
|
||||
}
|
||||
};
|
||||
} // namespace paddle
|
||||
|
@@ -59,6 +59,15 @@ inline uint32_t get_cascade_attention_num_threads() {
|
||||
inline bool get_mla_use_tensorcore() {
|
||||
static const char* mla_use_tensorcore_env = std::getenv("FLAGS_mla_use_tensorcore");
|
||||
static const uint32_t mla_use_tensorcore =
|
||||
mla_use_tensorcore_env == nullptr ? 1 : std::stoul(std::string(mla_use_tensorcore_env));
|
||||
mla_use_tensorcore_env == nullptr ? 0 : std::stoul(std::string(mla_use_tensorcore_env));
|
||||
return mla_use_tensorcore != 0 ? true : false;
|
||||
}
|
||||
inline int get_mla_dec_chunk_size(int bsz) {
|
||||
static const char* mla_dec_chunk_size_env =
|
||||
std::getenv("FLAGS_mla_dec_chunk_size");
|
||||
static const int mla_dec_chunk_size =
|
||||
mla_dec_chunk_size_env == nullptr
|
||||
? -1
|
||||
: std::stoi(std::string(mla_dec_chunk_size_env));
|
||||
return bsz > 1 ? mla_dec_chunk_size : 64;
|
||||
}
|
||||
|
@@ -39,9 +39,6 @@ void GetOutputTopK(const paddle::Tensor& x,
|
||||
int k,
|
||||
int64_t rank_id,
|
||||
bool wait_flag) {
|
||||
if (rank_id > 0) {
|
||||
return;
|
||||
}
|
||||
|
||||
static struct msgdata msg_rcv;
|
||||
int msg_queue_id = 1;
|
||||
|
@@ -132,7 +132,7 @@ std::vector<paddle::DataType> GetPaddingOffsetInferDtype(
|
||||
}
|
||||
|
||||
PD_BUILD_STATIC_OP(get_padding_offset)
|
||||
.Inputs({"input_ids", "token_num", "cum_offsets", "seq_len"})
|
||||
.Inputs({"input_ids", "cum_offsets", "token_num", "seq_len"})
|
||||
.Outputs({"x_remove_padding",
|
||||
"batch_id_per_token",
|
||||
"cu_seqlens_q",
|
||||
|
@@ -14,6 +14,8 @@
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <cuda_fp8.h>
|
||||
|
||||
#ifndef PADDLE_WITH_COREX
|
||||
#include "glog/logging.h"
|
||||
#endif
|
||||
@@ -151,6 +153,34 @@ inline int GetGPUComputeCapability(int id) {
|
||||
|
||||
#endif
|
||||
|
||||
#ifndef FP8_E4M3_MAX
|
||||
#define FP8_E4M3_MAX 448.0
|
||||
#endif
|
||||
|
||||
#ifndef DISPATCH_FLOAT_FP6_DTYPE
|
||||
#define DISPATCH_FLOAT_FP6_DTYPE(pd_dtype, c_type, ...) \
|
||||
switch (pd_dtype) { \
|
||||
case phi::DataType::FLOAT32: { \
|
||||
using c_type = float; \
|
||||
__VA_ARGS__ \
|
||||
break; \
|
||||
} \
|
||||
case phi::DataType::BFLOAT16: { \
|
||||
using c_type = phi::dtype::bfloat16; \
|
||||
__VA_ARGS__ \
|
||||
break; \
|
||||
} \
|
||||
case phi::DataType::FLOAT16: { \
|
||||
using c_type = phi::dtype::float16; \
|
||||
__VA_ARGS__ \
|
||||
break; \
|
||||
} \
|
||||
default: { \
|
||||
PD_THROW("Only supported attr of input type in [fp32, fp16, bf16]."); \
|
||||
} \
|
||||
}
|
||||
#endif
|
||||
|
||||
inline constexpr uint32_t next_pow_2(uint32_t const num) {
|
||||
if (num <= 1)
|
||||
return num;
|
||||
@@ -193,11 +223,13 @@ public:
|
||||
typedef uint8_t data_t;
|
||||
};
|
||||
|
||||
#ifndef PADDLE_WITH_COREX
|
||||
template <> class PDTraits<paddle::DataType::FLOAT8_E4M3FN> {
|
||||
public:
|
||||
typedef __nv_fp8_e4m3 DataType;
|
||||
typedef paddle::float8_e4m3fn data_t;
|
||||
};
|
||||
#endif
|
||||
|
||||
template <typename T, int Size> struct alignas(sizeof(T) * Size) AlignedVector {
|
||||
T val[Size];
|
||||
@@ -563,3 +595,36 @@ inline int GetSMVersion() {
|
||||
return sm_version;
|
||||
|
||||
}
|
||||
|
||||
inline bool GetMlaUseTensorcore() {
|
||||
static const bool flags_mla_use_tensorcore = get_mla_use_tensorcore();
|
||||
static const bool enable_mla_tensorcore = GetSMVersion() >= 90 ? true : false;
|
||||
const bool mla_use_tensorcore =
|
||||
flags_mla_use_tensorcore && enable_mla_tensorcore;
|
||||
return mla_use_tensorcore;
|
||||
}
|
||||
|
||||
__device__ __forceinline__ float warpReduceMax(float value) {
|
||||
value = fmaxf(value, __shfl_xor_sync(0xffffffff, value, 16));
|
||||
value = fmaxf(value, __shfl_xor_sync(0xffffffff, value, 8));
|
||||
value = fmaxf(value, __shfl_xor_sync(0xffffffff, value, 4));
|
||||
value = fmaxf(value, __shfl_xor_sync(0xffffffff, value, 2));
|
||||
value = fmaxf(value, __shfl_xor_sync(0xffffffff, value, 1));
|
||||
return value;
|
||||
}
|
||||
|
||||
__device__ __forceinline__ float blockReduceMax(float value) {
|
||||
static __shared__ float warpLevelMaxs[WARP_SIZE];
|
||||
const int laneId = threadIdx.x % WARP_SIZE;
|
||||
const int warpId = threadIdx.x / WARP_SIZE;
|
||||
|
||||
value = warpReduceMax(value);
|
||||
|
||||
if (laneId == 0) warpLevelMaxs[warpId] = value;
|
||||
__syncthreads();
|
||||
|
||||
value = (threadIdx.x < blockDim.x / WARP_SIZE) ? warpLevelMaxs[laneId] : 0;
|
||||
if (warpId == 0) value = warpReduceMax(value);
|
||||
|
||||
return value;
|
||||
}
|
||||
|
@@ -18,7 +18,6 @@
|
||||
#include "iomanip"
|
||||
#include <nvml.h>
|
||||
#include <iostream>
|
||||
#include <nvml.h>
|
||||
// #define PRINT_GPU_MEMORY
|
||||
// 函数用于获取 NVIDIA GPU 显存信息
|
||||
bool getNvidiaGPUMemoryUsage(int callLine) {
|
||||
|
@@ -30,10 +30,12 @@ paddle::Tensor mm(paddle::Tensor const& A, paddle::Tensor const& B,
|
||||
std::optional<paddle::Tensor> const& maybe_token_scales,
|
||||
std::string maybe_schedule) {
|
||||
machete::ScalarType const b_type = machete::ScalarType::from_id(b_type_id);
|
||||
std::optional<int64_t> maybe_group_size_opt;
|
||||
std::optional<int64_t> maybe_group_size_opt = std::optional<int64_t>(maybe_group_size);
|
||||
std::optional<std::string> maybe_schedule_opt;
|
||||
if (maybe_schedule == "") {
|
||||
maybe_schedule_opt = std::nullopt;
|
||||
} else {
|
||||
maybe_schedule_opt = std::optional<std::string>(maybe_schedule);
|
||||
}
|
||||
return machete::mm_dispatch({.A = A,
|
||||
.B = B,
|
||||
@@ -63,6 +65,8 @@ std::vector<paddle::Tensor> MacheteMMKernel(
|
||||
paddle::DataType maybe_out_type;
|
||||
if (b_type_str == "uint4b8") {
|
||||
b_type_id = machete::kU4B8.id();
|
||||
} else if (b_type_str == "uint8b128") {
|
||||
b_type_id = machete::kU8B128.id();
|
||||
} else {
|
||||
PADDLE_ENFORCE(false, "b_type_str not supported!");
|
||||
}
|
||||
|
@@ -51,6 +51,8 @@ std::vector<paddle::Tensor> MachetePrepackBKernel(
|
||||
|
||||
if (b_type_str == "uint4b8") {
|
||||
b_type_id = machete::kU4B8.id();
|
||||
} else if (b_type_str == "uint8b128") {
|
||||
b_type_id = machete::kU8B128.id();
|
||||
} else {
|
||||
PADDLE_ENFORCE(false, "b_type_str not supported!");
|
||||
}
|
||||
|
@@ -70,7 +70,6 @@ void BatchMLAWithPagedKVCacheKernel(
|
||||
const paddle::optional<paddle::Tensor>& smooth_weight, // [num_kv_heads, head_dim]
|
||||
const paddle::Tensor& seq_lens_this_time,
|
||||
const paddle::Tensor& seq_lens_decoder,
|
||||
const paddle::Tensor& seq_lens_encoder,
|
||||
const paddle::Tensor& cu_seqlens_q,
|
||||
const paddle::Tensor& batch_id_per_token,
|
||||
const paddle::Tensor& block_tables,
|
||||
@@ -78,9 +77,8 @@ void BatchMLAWithPagedKVCacheKernel(
|
||||
const paddle::Tensor& tile_ids_per_batch,
|
||||
const paddle::Tensor& num_blocks_x_device,
|
||||
const std::string& cache_quant_type_str,
|
||||
const int num_blocks_x,
|
||||
const paddle::Tensor& decoder_chunk_size_device,
|
||||
const int max_seq_len,
|
||||
const int max_dec_len,
|
||||
const float softmax_scale,
|
||||
const float quant_max_bound,
|
||||
const float quant_min_bound,
|
||||
@@ -97,14 +95,12 @@ void BatchMLAWithPagedKVCacheKernel(
|
||||
const auto q_head_num = meta_data.q_num_heads;
|
||||
const auto max_block_num_per_seq = meta_data.max_blocks_per_seq;
|
||||
const auto max_block_num = bsz * max_block_num_per_seq;
|
||||
const uint32_t chunk_size = get_max_partition_size(bsz);
|
||||
|
||||
|
||||
int q_head_dim = meta_data.head_dims;
|
||||
int k_head_dim = meta_data.head_dims;
|
||||
int v_head_dim = meta_data.head_dims_v;
|
||||
// int num_chunks = max_dec_len / chunk_size;
|
||||
int num_chunks = div_up(max_dec_len, chunk_size);
|
||||
int num_chunks = div_up(max_seq_len, 64);
|
||||
|
||||
auto *allocator = paddle::GetAllocator(q.place());
|
||||
phi::Allocator::AllocationPtr O_tmp, m_tmp, d_tmp;
|
||||
@@ -127,14 +123,14 @@ void BatchMLAWithPagedKVCacheKernel(
|
||||
params.d = reinterpret_cast<float*>(d_tmp->ptr());
|
||||
params.block_tables = const_cast<int*>(block_tables.data<int>());
|
||||
params.seq_lens_this_time = const_cast<int*>(seq_lens_this_time.data<int>());
|
||||
params.seq_lens_encoder = const_cast<int*>(seq_lens_encoder.data<int>());
|
||||
params.seq_lens_decoder = const_cast<int*>(seq_lens_decoder.data<int>());
|
||||
params.cumsum_q_seqlens = const_cast<int*>(cu_seqlens_q.data<int>());
|
||||
params.batch_id_per_token = const_cast<int*>(batch_id_per_token.data<int>());
|
||||
params.batch_ids = const_cast<int*>(batch_ids.data<int>());
|
||||
params.tile_ids_per_batch = const_cast<int*>(tile_ids_per_batch.data<int>());
|
||||
params.num_blocks_x = const_cast<int*>(num_blocks_x_device.data<int>());
|
||||
params.num_blocks_x_int = num_blocks_x;
|
||||
params.chunk_size_device =
|
||||
const_cast<int*>(decoder_chunk_size_device.data<int>());
|
||||
params.q_stride_bsz = q_head_num * q_head_dim;
|
||||
params.q_stride_head_num = q_head_dim;
|
||||
params.kv_stride_block_num = block_size * k_head_dim;
|
||||
@@ -151,7 +147,6 @@ void BatchMLAWithPagedKVCacheKernel(
|
||||
params.block_size = block_size;
|
||||
params.max_draft_token_num = draft_token_num;
|
||||
params.sm_scale = softmax_scale;
|
||||
params.chunk_size = chunk_size;
|
||||
params.chunk_num = num_chunks;
|
||||
|
||||
if (q_head_dim == 576) {
|
||||
@@ -176,7 +171,6 @@ template void BatchMLAWithPagedKVCacheKernel<paddle::bfloat16>(
|
||||
const paddle::optional<paddle::Tensor>& smooth_weight, // [num_kv_heads, head_dim]
|
||||
const paddle::Tensor& seq_lens_this_time,
|
||||
const paddle::Tensor& seq_lens_decoder,
|
||||
const paddle::Tensor& seq_lens_encoder,
|
||||
const paddle::Tensor& cu_seqlens_q,
|
||||
const paddle::Tensor& batch_id_per_token,
|
||||
const paddle::Tensor& block_tables,
|
||||
@@ -184,9 +178,8 @@ template void BatchMLAWithPagedKVCacheKernel<paddle::bfloat16>(
|
||||
const paddle::Tensor& tile_ids_per_batch,
|
||||
const paddle::Tensor& num_blocks_x_device,
|
||||
const std::string& cache_quant_type_str,
|
||||
const int num_blocks_x,
|
||||
const paddle::Tensor& decoder_chunk_size_device,
|
||||
const int max_seq_len,
|
||||
const int max_dec_len,
|
||||
const float softmax_scale,
|
||||
const float quant_max_bound,
|
||||
const float quant_min_bound,
|
||||
@@ -210,7 +203,6 @@ template void BatchMLAWithPagedKVCacheKernel<paddle::float16>(
|
||||
const paddle::optional<paddle::Tensor>& smooth_weight, // [num_kv_heads, head_dim]
|
||||
const paddle::Tensor& seq_lens_this_time,
|
||||
const paddle::Tensor& seq_lens_decoder,
|
||||
const paddle::Tensor& seq_lens_encoder,
|
||||
const paddle::Tensor& cu_seqlens_q,
|
||||
const paddle::Tensor& batch_id_per_token,
|
||||
const paddle::Tensor& block_tables,
|
||||
@@ -218,9 +210,8 @@ template void BatchMLAWithPagedKVCacheKernel<paddle::float16>(
|
||||
const paddle::Tensor& tile_ids_per_batch,
|
||||
const paddle::Tensor& num_blocks_x_device,
|
||||
const std::string& cache_quant_type_str,
|
||||
const int num_blocks_x,
|
||||
const paddle::Tensor& decoder_chunk_size_device,
|
||||
const int max_seq_len,
|
||||
const int max_dec_len,
|
||||
const float softmax_scale,
|
||||
const float quant_max_bound,
|
||||
const float quant_min_bound,
|
||||
|
@@ -47,7 +47,6 @@ void BatchMLAWithPagedKVCacheKernel(
|
||||
const paddle::optional<paddle::Tensor>& smooth_weight, // [num_kv_heads, head_dim]
|
||||
const paddle::Tensor& seq_lens_this_time,
|
||||
const paddle::Tensor& seq_lens_decoder,
|
||||
const paddle::Tensor& seq_lens_encoder,
|
||||
const paddle::Tensor& cu_seqlens_q,
|
||||
const paddle::Tensor& batch_id_per_token,
|
||||
const paddle::Tensor& block_tables,
|
||||
@@ -55,9 +54,8 @@ void BatchMLAWithPagedKVCacheKernel(
|
||||
const paddle::Tensor& tile_ids_per_batch,
|
||||
const paddle::Tensor& num_blocks_x_device,
|
||||
const std::string& cache_quant_type_str,
|
||||
const int num_blocks_x,
|
||||
const paddle::Tensor& decoder_chunk_size_device,
|
||||
const int max_seq_len,
|
||||
const int max_dec_len,
|
||||
const float softmax_scale,
|
||||
const float quant_max_bound,
|
||||
const float quant_min_bound,
|
||||
|
@@ -128,12 +128,13 @@ struct CollectiveMainloop {
|
||||
DTypeMD const* d_ptr;
|
||||
IdType const* kv_block_tables;
|
||||
IdType const* seq_lens_this_time;
|
||||
IdType const* seq_lens_encoder;
|
||||
// IdType const* seq_lens_encoder;
|
||||
IdType const* seq_lens_decoder;
|
||||
IdType const* cumsum_q_seqlens;
|
||||
IdType const* batch_ids;
|
||||
IdType const* tile_ids_per_batch;
|
||||
IdType const* num_blocks_x;
|
||||
IdType const* chunk_size_device;
|
||||
float sm_scale;
|
||||
int bsz;
|
||||
int max_block_num;
|
||||
@@ -144,7 +145,7 @@ struct CollectiveMainloop {
|
||||
int kv_stride_block_size;
|
||||
int o_stride_bsz;
|
||||
int o_stride_head_num;
|
||||
int chunk_size;
|
||||
// int chunk_size;
|
||||
int chunk_num;
|
||||
int max_draft_token_num;
|
||||
};
|
||||
@@ -160,12 +161,13 @@ struct CollectiveMainloop {
|
||||
DTypeMD* d_ptr;
|
||||
IdType* kv_block_tables;
|
||||
IdType* seq_lens_this_time;
|
||||
IdType* seq_lens_encoder;
|
||||
// IdType* seq_lens_encoder;
|
||||
IdType* seq_lens_decoder;
|
||||
IdType* cumsum_q_seqlens;
|
||||
IdType* batch_ids;
|
||||
IdType* tile_ids_per_batch;
|
||||
IdType* num_blocks_x;
|
||||
IdType* chunk_size_device;
|
||||
float sm_scale;
|
||||
int bsz;
|
||||
int max_block_num;
|
||||
@@ -176,7 +178,7 @@ struct CollectiveMainloop {
|
||||
int kv_stride_block_size;
|
||||
int o_stride_bsz;
|
||||
int o_stride_head_num;
|
||||
int chunk_size;
|
||||
// int chunk_size;
|
||||
int chunk_num;
|
||||
int max_draft_token_num;
|
||||
TMA_KV tma_load_KV;
|
||||
@@ -198,12 +200,13 @@ struct CollectiveMainloop {
|
||||
const_cast<DTypeMD*>(args.d_ptr),
|
||||
const_cast<IdType*>(args.kv_block_tables),
|
||||
const_cast<IdType*>(args.seq_lens_this_time),
|
||||
const_cast<IdType*>(args.seq_lens_encoder),
|
||||
// const_cast<IdType*>(args.seq_lens_encoder),
|
||||
const_cast<IdType*>(args.seq_lens_decoder),
|
||||
const_cast<IdType*>(args.cumsum_q_seqlens),
|
||||
const_cast<IdType*>(args.batch_ids),
|
||||
const_cast<IdType*>(args.tile_ids_per_batch),
|
||||
const_cast<IdType*>(args.num_blocks_x),
|
||||
const_cast<IdType*>(args.chunk_size_device),
|
||||
args.sm_scale,
|
||||
args.bsz,
|
||||
args.max_block_num,
|
||||
@@ -214,7 +217,7 @@ struct CollectiveMainloop {
|
||||
args.kv_stride_block_size,
|
||||
args.o_stride_bsz,
|
||||
args.o_stride_head_num,
|
||||
args.chunk_size,
|
||||
// args.chunk_size,
|
||||
args.chunk_num,
|
||||
args.max_draft_token_num,
|
||||
tma_load_KV
|
||||
@@ -281,9 +284,9 @@ struct CollectiveMainloop {
|
||||
auto gmem_thr_copy_kv = gmem_tiled_copy_kv.get_slice(thread_idx);
|
||||
|
||||
static constexpr int BLOCK_SHAPE_KV = get<1>(TileShape_QKD{});
|
||||
const int start_len = tile_idx * mainloop_params.chunk_size;
|
||||
const int start_len = tile_idx * mainloop_params.chunk_size_device[0];
|
||||
const int start_tile_idx = start_len / BLOCK_SHAPE_KV;
|
||||
const int end_tile_idx = cute::ceil_div(min(start_len + mainloop_params.chunk_size, kv_len), BLOCK_SHAPE_KV) - 1;
|
||||
const int end_tile_idx = cute::ceil_div(min(start_len + mainloop_params.chunk_size_device[0], kv_len), BLOCK_SHAPE_KV) - 1;
|
||||
|
||||
auto kv_block_tables = make_tensor(make_gmem_ptr(mainloop_params.kv_block_tables), make_layout(make_shape(mainloop_params.bsz, mainloop_params.max_block_num_per_seq), make_stride(mainloop_params.max_block_num_per_seq, 1)));
|
||||
|
||||
@@ -322,9 +325,9 @@ struct CollectiveMainloop {
|
||||
group_modes<0, 2>(sK), group_modes<0, 2>(gKV));
|
||||
|
||||
static constexpr int BLOCK_SHAPE_KV = get<1>(TileShape_QKD{});
|
||||
const int start_len = tile_idx * mainloop_params.chunk_size;
|
||||
const int start_len = tile_idx * mainloop_params.chunk_size_device[0];
|
||||
const int start_tile_idx = start_len / BLOCK_SHAPE_KV;
|
||||
const int end_tile_idx = cute::ceil_div(min(start_len + mainloop_params.chunk_size, kv_len), BLOCK_SHAPE_KV) - 1;
|
||||
const int end_tile_idx = cute::ceil_div(min(start_len + mainloop_params.chunk_size_device[0], kv_len), BLOCK_SHAPE_KV) - 1;
|
||||
|
||||
auto kv_block_tables = make_tensor(make_gmem_ptr(mainloop_params.kv_block_tables), make_layout(make_shape(mainloop_params.bsz, mainloop_params.max_block_num_per_seq), make_stride(mainloop_params.max_block_num_per_seq, 1)));
|
||||
|
||||
|
@@ -57,7 +57,7 @@ CUTLASS_DEVICE void mma_f16(const Params& mainloop_params,
|
||||
using SmemLayoutVtOneStage = typename Ktraits::SmemLayoutVtOneStage;
|
||||
static_assert(is_rmem<FrgTensorO>::value, "O tensor must be rmem resident.");
|
||||
|
||||
const int chunk_num_this_seq = cute::ceil_div(kv_len, mainloop_params.chunk_size);
|
||||
const int chunk_num_this_seq = cute::ceil_div(kv_len, mainloop_params.chunk_size_device[0]);
|
||||
|
||||
static constexpr int BLOCK_SHAPE_Q = get<0>(TileShape_QKD{});
|
||||
static constexpr int BLOCK_SHAPE_KV = get<1>(TileShape_QKD{});
|
||||
@@ -84,9 +84,9 @@ CUTLASS_DEVICE void mma_f16(const Params& mainloop_params,
|
||||
Tensor tOrV2 = threadMmaPVSS.partition_fragment_B(sVt_s2);
|
||||
Tensor tOrP_CS2 = threadMmaPVSS.partition_fragment_A(sPSS);
|
||||
|
||||
const int start_len = tile_idx * mainloop_params.chunk_size;
|
||||
const int start_len = tile_idx * mainloop_params.chunk_size_device[0];
|
||||
const int start_tile_idx = start_len / BLOCK_SHAPE_KV;
|
||||
const int end_tile_idx =cute::ceil_div(min(start_len + mainloop_params.chunk_size, kv_len), BLOCK_SHAPE_KV) - 1;
|
||||
const int end_tile_idx =cute::ceil_div(min(start_len + mainloop_params.chunk_size_device[0], kv_len), BLOCK_SHAPE_KV) - 1;
|
||||
int kv_tile_idx = end_tile_idx;
|
||||
|
||||
auto consumer_wait = [](auto& pipeline, auto& smem_pipe_read) {
|
||||
@@ -263,7 +263,7 @@ CUTLASS_DEVICE void mma_f16_two_stages(const Params& mainloop_params,
|
||||
using SmemLayoutVtOneStage = typename Ktraits::SmemLayoutVtOneStage;
|
||||
static_assert(is_rmem<FrgTensorO>::value, "O tensor must be rmem resident.");
|
||||
|
||||
const int chunk_num_this_seq = cute::ceil_div(kv_len, mainloop_params.chunk_size);
|
||||
const int chunk_num_this_seq = cute::ceil_div(kv_len, mainloop_params.chunk_size_device[0]);
|
||||
|
||||
static constexpr int BLOCK_SHAPE_Q = get<0>(TileShape_QKD{});
|
||||
static constexpr int BLOCK_SHAPE_KV = get<1>(TileShape_QKD{});
|
||||
@@ -295,9 +295,9 @@ CUTLASS_DEVICE void mma_f16_two_stages(const Params& mainloop_params,
|
||||
Tensor tOrV4 = threadMmaPVSS.partition_fragment_B(sVt_s4);
|
||||
Tensor tOrP_CS2 = threadMmaPVSS.partition_fragment_A(sPSS);
|
||||
|
||||
const int start_len = tile_idx * mainloop_params.chunk_size;
|
||||
const int start_len = tile_idx * mainloop_params.chunk_size_device[0];
|
||||
const int start_tile_idx = start_len / BLOCK_SHAPE_KV;
|
||||
const int end_tile_idx = cute::ceil_div(min(start_len + mainloop_params.chunk_size, kv_len), BLOCK_SHAPE_KV) - 1;
|
||||
const int end_tile_idx = cute::ceil_div(min(start_len + mainloop_params.chunk_size_device[0], kv_len), BLOCK_SHAPE_KV) - 1;
|
||||
int kv_tile_idx = end_tile_idx;
|
||||
|
||||
auto consumer_wait = [](auto& pipeline, auto& smem_pipe_read) {
|
||||
|
@@ -62,13 +62,12 @@ struct Params {
|
||||
alignas(16) DTypeQ *Q; // [token_num, head_num, dim_head]
|
||||
alignas(16) DTypeKV *KV; // [max_block_num, block_size, dim_head]
|
||||
alignas(16) DTypeO *O; // [token_num, head_num, dim_head]
|
||||
alignas(16) DTypeO *O_tmp; // [num_chunks, bsz, head_num, dim_head]
|
||||
alignas(16) float *m; // [num_chunks, bsz * max_draft_token_num * head_num]
|
||||
alignas(16) float *d; // [num_chunks, bsz * max_draft_token_num * head_num]
|
||||
alignas(16) DTypeO *O_tmp; // [max_num_chunks, bsz, head_num, dim_head]
|
||||
alignas(16) float *m; // [max_num_chunks, bsz * max_draft_token_num * head_num]
|
||||
alignas(16) float *d; // [max_num_chunks, bsz * max_draft_token_num * head_num]
|
||||
|
||||
alignas(16) IdType *block_tables;
|
||||
alignas(16) IdType *seq_lens_this_time;
|
||||
alignas(16) IdType *seq_lens_encoder;
|
||||
alignas(16) IdType *seq_lens_decoder;
|
||||
alignas(16) IdType *cumsum_q_seqlens;
|
||||
alignas(16) IdType *batch_id_per_token;
|
||||
@@ -76,7 +75,7 @@ struct Params {
|
||||
alignas(16) IdType *batch_ids;
|
||||
alignas(16) IdType *tile_ids_per_batch;
|
||||
alignas(16) IdType *num_blocks_x;
|
||||
|
||||
alignas(16) IdType *chunk_size_device;
|
||||
|
||||
uint32_t q_stride_bsz;
|
||||
uint32_t q_stride_head_num;
|
||||
@@ -96,9 +95,7 @@ struct Params {
|
||||
int vo_head_dim;
|
||||
int block_size;
|
||||
int max_draft_token_num;
|
||||
int chunk_size;
|
||||
int chunk_num;
|
||||
int num_blocks_x_int;
|
||||
|
||||
float sm_scale;
|
||||
};
|
||||
@@ -118,7 +115,7 @@ struct Params {
|
||||
return cudaErrorNotSupported; \
|
||||
}
|
||||
|
||||
template <typename CollectiveMainloop, typename CollectiveEpilogue, typename Ktraits, bool CAUSAL, int SM_COUNT = 132, bool USE_REG_EALLOC=false, bool USE_FIXED_BLOCK=false>
|
||||
template <typename CollectiveMainloop, typename CollectiveEpilogue, typename Ktraits, bool CAUSAL, int SM_COUNT = 132, bool USE_REG_EALLOC=false, bool USE_FIXED_BLOCK=true>
|
||||
__global__ void __launch_bounds__(Ktraits::NUM_WARPS * cutlass::NumThreadsPerWarp, 1)
|
||||
MLAWithKVCacheKernel(CUTE_GRID_CONSTANT
|
||||
typename CollectiveMainloop::Params const mainloop_params,
|
||||
@@ -137,6 +134,7 @@ MLAWithKVCacheKernel(CUTE_GRID_CONSTANT
|
||||
static constexpr int BLOCK_SHAPE_Q = Ktraits::BLOCK_SHAPE_Q;
|
||||
static constexpr int BLOCK_SHAPE_KV = Ktraits::BLOCK_SHAPE_KV;
|
||||
const int num_blocks_x = mainloop_params.num_blocks_x[0];
|
||||
const int chunk_size = mainloop_params.chunk_size_device[0];
|
||||
|
||||
static constexpr bool use_tma_load_kv = CollectiveMainloop::USE_TMA_LOAD_KV;
|
||||
|
||||
@@ -205,58 +203,10 @@ MLAWithKVCacheKernel(CUTE_GRID_CONSTANT
|
||||
|
||||
PipelineStateQ smem_pipe_write_q = cutlass::make_producer_start_state<MainloopPipelineQ>();
|
||||
PipelineState smem_pipe_write_kv = cutlass::make_producer_start_state<MainloopPipeline>();
|
||||
if constexpr(USE_FIXED_BLOCK) {
|
||||
for (int i = blockIdx.x; i < num_blocks_x; i += SM_COUNT) {
|
||||
const int bid = mainloop_params.batch_ids[i];
|
||||
const int tile_id = mainloop_params.tile_ids_per_batch[i];
|
||||
const int seq_len_now = mainloop_params.seq_lens_this_time[bid];
|
||||
const int seq_len_encoder_now = mainloop_params.seq_lens_encoder[bid];
|
||||
const int seq_len_decoder_now = mainloop_params.seq_lens_decoder[bid] + seq_len_now;
|
||||
const int start_token_idx = mainloop_params.cumsum_q_seqlens[bid];
|
||||
cutlass::arch::NamedBarrier::sync(Ktraits::NUM_THREADS,
|
||||
/*id=*/static_cast<int>(NamedBarriers::kWG0WG1WG2Sync));
|
||||
|
||||
// load Q
|
||||
collective_mainloop.load_q(
|
||||
mainloop_params,
|
||||
pipeline_q,
|
||||
smem_pipe_write_q,
|
||||
shared_storage,
|
||||
threadIdx.x,
|
||||
bid);
|
||||
|
||||
if constexpr (!use_tma_load_kv) {
|
||||
// load kv
|
||||
collective_mainloop.load_kv(
|
||||
mainloop_params,
|
||||
pipeline_kv,
|
||||
smem_pipe_write_kv,
|
||||
shared_storage,
|
||||
bid,
|
||||
seq_len_decoder_now,
|
||||
tile_id
|
||||
);
|
||||
} else {
|
||||
if (warp_idx_in_warpgroup == 0) {
|
||||
// load kv tma
|
||||
collective_mainloop.load_kv_tma(
|
||||
mainloop_params,
|
||||
pipeline_kv,
|
||||
smem_pipe_write_kv,
|
||||
shared_storage,
|
||||
bid,
|
||||
seq_len_decoder_now,
|
||||
tile_id
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
} else {
|
||||
const int block_id = blockIdx.x;
|
||||
const int bid = mainloop_params.batch_ids[block_id];
|
||||
const int tile_id = mainloop_params.tile_ids_per_batch[block_id];
|
||||
for (int i = blockIdx.x; i < num_blocks_x; i += SM_COUNT) {
|
||||
const int bid = mainloop_params.batch_ids[i];
|
||||
const int tile_id = mainloop_params.tile_ids_per_batch[i];
|
||||
const int seq_len_now = mainloop_params.seq_lens_this_time[bid];
|
||||
const int seq_len_encoder_now = mainloop_params.seq_lens_encoder[bid];
|
||||
const int seq_len_decoder_now = mainloop_params.seq_lens_decoder[bid] + seq_len_now;
|
||||
const int start_token_idx = mainloop_params.cumsum_q_seqlens[bid];
|
||||
cutlass::arch::NamedBarrier::sync(Ktraits::NUM_THREADS,
|
||||
@@ -309,76 +259,12 @@ MLAWithKVCacheKernel(CUTE_GRID_CONSTANT
|
||||
Tensor tOrO = partition_fragment_C(tiled_mma_pv, select<0, 1>(TileShape_PDV{}));
|
||||
|
||||
auto attention_updater = OnlineSoftmax<2 * size<1>(tOrO), /*WITH_SCALE=*/true>(mainloop_params.sm_scale);
|
||||
if constexpr(USE_FIXED_BLOCK) {
|
||||
for (int i = blockIdx.x; i < num_blocks_x; i += SM_COUNT) {
|
||||
clear(tOrO);
|
||||
clear(attention_updater.scores_scale);
|
||||
const int bid = mainloop_params.batch_ids[i];
|
||||
const int tile_id = mainloop_params.tile_ids_per_batch[i];
|
||||
const int seq_len_now = mainloop_params.seq_lens_this_time[bid];
|
||||
const int seq_len_encoder_now = mainloop_params.seq_lens_encoder[bid];
|
||||
const int seq_len_decoder_now = mainloop_params.seq_lens_decoder[bid] + seq_len_now;
|
||||
const int start_token_idx = mainloop_params.cumsum_q_seqlens[bid];
|
||||
cutlass::arch::NamedBarrier::sync(Ktraits::NUM_THREADS,
|
||||
/*id=*/static_cast<int>(NamedBarriers::kWG0WG1WG2Sync));
|
||||
|
||||
if constexpr (BLOCK_SHAPE_KV == 64) {
|
||||
mma_f16<Ktraits, CAUSAL>(
|
||||
mainloop_params,
|
||||
pipeline_q,
|
||||
smem_pipe_read_q,
|
||||
pipeline_kv,
|
||||
smem_pipe_read_kv,
|
||||
tOrO,
|
||||
attention_updater,
|
||||
threadIdx.x - NUM_COPY_THREADS,
|
||||
bid,
|
||||
seq_len_decoder_now,
|
||||
seq_len_now,
|
||||
tile_id,
|
||||
shared_storage);
|
||||
} else if (BLOCK_SHAPE_KV == 32) {
|
||||
mma_f16_two_stages<Ktraits, CAUSAL>(
|
||||
mainloop_params,
|
||||
pipeline_q,
|
||||
smem_pipe_read_q,
|
||||
pipeline_kv,
|
||||
smem_pipe_read_kv,
|
||||
tOrO,
|
||||
attention_updater,
|
||||
threadIdx.x - NUM_COPY_THREADS,
|
||||
bid,
|
||||
seq_len_decoder_now,
|
||||
seq_len_now,
|
||||
tile_id,
|
||||
shared_storage);
|
||||
}
|
||||
|
||||
collective_epilogue.store(
|
||||
epilogue_params,
|
||||
tOrO,
|
||||
attention_updater.get_lse(),
|
||||
shared_storage,
|
||||
tiled_mma_pv,
|
||||
threadIdx.x - NUM_COPY_THREADS,
|
||||
bid,
|
||||
mainloop_params.bsz,
|
||||
seq_len_now,
|
||||
start_token_idx,
|
||||
tile_id,
|
||||
seq_len_decoder_now,
|
||||
mainloop_params.chunk_size,
|
||||
mainloop_params.max_draft_token_num,
|
||||
mainloop_params.o_stride_bsz);
|
||||
}
|
||||
} else {
|
||||
const int block_id = blockIdx.x;
|
||||
for (int i = blockIdx.x; i < num_blocks_x; i += SM_COUNT) {
|
||||
clear(tOrO);
|
||||
clear(attention_updater.scores_scale);
|
||||
const int bid = mainloop_params.batch_ids[block_id];
|
||||
const int tile_id = mainloop_params.tile_ids_per_batch[block_id];
|
||||
const int bid = mainloop_params.batch_ids[i];
|
||||
const int tile_id = mainloop_params.tile_ids_per_batch[i];
|
||||
const int seq_len_now = mainloop_params.seq_lens_this_time[bid];
|
||||
const int seq_len_encoder_now = mainloop_params.seq_lens_encoder[bid];
|
||||
const int seq_len_decoder_now = mainloop_params.seq_lens_decoder[bid] + seq_len_now;
|
||||
const int start_token_idx = mainloop_params.cumsum_q_seqlens[bid];
|
||||
cutlass::arch::NamedBarrier::sync(Ktraits::NUM_THREADS,
|
||||
@@ -429,15 +315,15 @@ MLAWithKVCacheKernel(CUTE_GRID_CONSTANT
|
||||
start_token_idx,
|
||||
tile_id,
|
||||
seq_len_decoder_now,
|
||||
mainloop_params.chunk_size,
|
||||
chunk_size,
|
||||
mainloop_params.max_draft_token_num,
|
||||
mainloop_params.o_stride_bsz);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
template <typename KernelTraits, bool CAUSAL, typename Params, bool USE_REG_EALLOC=false, bool USE_FIXED_BLOCK=false>
|
||||
template <typename KernelTraits, bool CAUSAL, typename Params, bool USE_REG_EALLOC=false, bool USE_FIXED_BLOCK=true>
|
||||
cudaError_t BatchMLAWithPagedKVCacheKernelTraitsDispatched(Params& params,
|
||||
cudaStream_t stream) {
|
||||
using DTypeQ = typename KernelTraits::DTypeQ;
|
||||
@@ -460,12 +346,12 @@ cudaError_t BatchMLAWithPagedKVCacheKernelTraitsDispatched(Params& params,
|
||||
params.d,
|
||||
params.block_tables,
|
||||
params.seq_lens_this_time,
|
||||
params.seq_lens_encoder,
|
||||
params.seq_lens_decoder,
|
||||
params.cumsum_q_seqlens,
|
||||
params.batch_ids,
|
||||
params.tile_ids_per_batch,
|
||||
params.num_blocks_x,
|
||||
params.chunk_size_device,
|
||||
params.sm_scale,
|
||||
params.bsz,
|
||||
params.max_block_num,
|
||||
@@ -476,7 +362,6 @@ cudaError_t BatchMLAWithPagedKVCacheKernelTraitsDispatched(Params& params,
|
||||
params.kv_stride_block_size,
|
||||
params.o_stride_bsz,
|
||||
params.o_stride_head_num,
|
||||
params.chunk_size,
|
||||
params.chunk_num,
|
||||
params.max_draft_token_num
|
||||
});
|
||||
@@ -500,13 +385,9 @@ cudaError_t BatchMLAWithPagedKVCacheKernelTraitsDispatched(Params& params,
|
||||
cudaOccupancyMaxActiveBlocksPerMultiprocessor(
|
||||
&act_blocks_per_sm, kernel, KernelTraits::NUM_WARPS * 32, smem_size);
|
||||
|
||||
int gridx;
|
||||
if constexpr(USE_FIXED_BLOCK) {
|
||||
gridx = multiprocessor_count;
|
||||
} else {
|
||||
gridx = params.num_blocks_x_int;
|
||||
}
|
||||
dim3 grid_dims = {gridx, 1, 1};
|
||||
// NOTE: (changwenbin) Here the grid size is fixed so that MLA can be captured
|
||||
// by the graph.
|
||||
dim3 grid_dims = {multiprocessor_count, 1, 1};
|
||||
static constexpr int ctaSize = KernelTraits::NUM_WARPS * 32;
|
||||
dim3 block_dims(ctaSize, 1, 1);
|
||||
kernel<<<grid_dims, block_dims, smem_size, stream>>>(
|
||||
@@ -517,37 +398,38 @@ cudaError_t BatchMLAWithPagedKVCacheKernelTraitsDispatched(Params& params,
|
||||
constexpr int merge_block_size = 256;
|
||||
constexpr int blockx = KernelTraits::HEAD_DIM_VO / vec_size;
|
||||
constexpr int blocky = (merge_block_size + blockx - 1) / blockx;
|
||||
dim3 grids_merge(min(multiprocessor_count, params.token_num), params.q_num_head); // 128k is too large
|
||||
dim3 grids_merge(multiprocessor_count, params.q_num_head); // 128k is too large
|
||||
dim3 blocks_merge(blockx, blocky);
|
||||
merge_multi_chunks_kernel<NV_TYPE, vec_size, blocky, KernelTraits::HEAD_DIM_VO><<<grids_merge, blocks_merge, 0, stream>>>(
|
||||
reinterpret_cast<NV_TYPE*>(params.O_tmp),
|
||||
params.m,
|
||||
params.d,
|
||||
params.seq_lens_this_time,
|
||||
params.seq_lens_decoder,
|
||||
params.seq_lens_encoder,
|
||||
params.cumsum_q_seqlens,
|
||||
params.batch_id_per_token,
|
||||
reinterpret_cast<NV_TYPE*>(params.O),
|
||||
params.chunk_num,
|
||||
params.q_num_head,
|
||||
params.chunk_size,
|
||||
params.vo_head_dim,
|
||||
params.token_num,
|
||||
params.bsz,
|
||||
params.max_draft_token_num
|
||||
);
|
||||
merge_multi_chunks_kernel<NV_TYPE,
|
||||
vec_size,
|
||||
blocky,
|
||||
KernelTraits::HEAD_DIM_VO>
|
||||
<<<grids_merge, blocks_merge, 0, stream>>>(
|
||||
reinterpret_cast<NV_TYPE *>(params.O_tmp),
|
||||
params.m,
|
||||
params.d,
|
||||
params.seq_lens_this_time,
|
||||
params.seq_lens_decoder,
|
||||
params.cumsum_q_seqlens,
|
||||
params.batch_id_per_token,
|
||||
params.chunk_size_device,
|
||||
reinterpret_cast<NV_TYPE *>(params.O),
|
||||
params.q_num_head,
|
||||
params.vo_head_dim,
|
||||
params.token_num,
|
||||
params.bsz,
|
||||
params.max_draft_token_num);
|
||||
}
|
||||
return cudaSuccess;
|
||||
}
|
||||
|
||||
template <uint32_t HEAD_DIM_QK, uint32_t HEAD_DIM_VO, typename NV_TYPE, typename Params, bool USE_REG_EALLOC=false, bool USE_FIXED_BLOCK=false>
|
||||
template <uint32_t HEAD_DIM_QK, uint32_t HEAD_DIM_VO, typename NV_TYPE, typename Params, bool USE_REG_EALLOC=false, bool USE_FIXED_BLOCK=true>
|
||||
cudaError_t BatchMLAWithPagedKVCacheDispatched(Params& params, cudaStream_t stream) {
|
||||
constexpr bool CAUSAL = true;
|
||||
if constexpr (HEAD_DIM_QK == 576) {
|
||||
DISPATCH_GROUP_SIZE(params.q_num_head, GROUP_SIZE,
|
||||
BatchMLAWithPagedKVCacheKernelTraitsDispatched<
|
||||
AttentionKernelTraits</*USE_TMA_LOAD_KV=*/false,
|
||||
AttentionKernelTraits</*USE_TMA_LOAD_KV=*/true,
|
||||
HEAD_DIM_QK,
|
||||
HEAD_DIM_VO,
|
||||
GROUP_SIZE,
|
||||
|
@@ -249,18 +249,16 @@ struct prefill_softmax_state_t {
|
||||
};
|
||||
|
||||
template <typename T, int vec_size, uint32_t bdy, uint32_t HEAD_DIM>
|
||||
__global__ void merge_multi_chunks_kernel(const T * __restrict__ multi_out, // [num_chunks, bsz, max_draft_token, num_heads, head_dim]
|
||||
const float * __restrict__ multi_m, // [num_chunks, bsz, max_draft_token, num_heads]
|
||||
const float * __restrict__ multi_d, // [num_chunks, bsz, max_draft_token, num_heads]
|
||||
__global__ void merge_multi_chunks_kernel(const T * __restrict__ multi_out, // [max_num_chunks, bsz, max_draft_token, num_heads, head_dim]
|
||||
const float * __restrict__ multi_m, // [max_num_chunks, bsz, max_draft_token, num_heads]
|
||||
const float * __restrict__ multi_d, // [max_num_chunks, bsz, max_draft_token, num_heads]
|
||||
const int * __restrict__ seq_lens_this_time,
|
||||
const int * __restrict__ seq_lens_decoder,
|
||||
const int * __restrict__ seq_lens_encoder,
|
||||
const int *__restrict__ cu_seqlens_q,
|
||||
const int * __restrict__ batch_id_per_token,
|
||||
const int * __restrict__ chunk_size_device,
|
||||
T * __restrict__ out, // [token_num, num_heads, head_dim]
|
||||
const int num_chunks,
|
||||
const int num_heads,
|
||||
const int chunk_size,
|
||||
const int head_dim,
|
||||
const int token_num,
|
||||
const int bsz,
|
||||
@@ -271,13 +269,15 @@ __global__ void merge_multi_chunks_kernel(const T * __restrict__ multi_out, // [
|
||||
__shared__ float md_smem[bdy * 2];
|
||||
for (int qid = blockIdx.x; qid < token_num; qid += gridDim.x) {
|
||||
const uint32_t bid = batch_id_per_token[qid];
|
||||
// NOTE : (changwenbin) Batch_id_per_token is initialized to [:]=-1, Marking meaningless batch IDs.
|
||||
if (bid == -1) continue;
|
||||
const int seq_len_q = seq_lens_this_time[bid];
|
||||
if (seq_len_q == 0) continue;
|
||||
const uint32_t local_seq_id = qid - cu_seqlens_q[bid];
|
||||
int seq_len_kv = seq_lens_decoder[bid];
|
||||
if (seq_len_kv == 0) continue;
|
||||
seq_len_kv += seq_len_q;
|
||||
const int num_chunks_this_seq = cute::ceil_div(seq_len_kv, chunk_size);
|
||||
const int num_chunks_this_seq = cute::ceil_div(seq_len_kv, chunk_size_device[0]);
|
||||
if (num_chunks_this_seq <= 1) {
|
||||
// not need merge
|
||||
continue;
|
||||
|
@@ -33,6 +33,11 @@
|
||||
__VA_ARGS__ \
|
||||
break; \
|
||||
} \
|
||||
case 3: { \
|
||||
constexpr size_t NUM_EXPERTS_PER_RANK = 3; \
|
||||
__VA_ARGS__ \
|
||||
break; \
|
||||
} \
|
||||
case 6: { \
|
||||
constexpr size_t NUM_EXPERTS_PER_RANK = 6; \
|
||||
__VA_ARGS__ \
|
||||
@@ -448,137 +453,71 @@ void EPMoeDispatchKernel(const paddle::Tensor& input,
|
||||
auto place = input.place();
|
||||
const int gridx = min(132 * 8, num_rows);
|
||||
if (moe_quant_type == "w4a8") {
|
||||
if (num_experts_per_rank == 8) {
|
||||
permute_x_kernel<data_t, int8_t, 8><<<gridx, 512, 0, stream>>>(
|
||||
input.data<data_t>(),
|
||||
topk_ids.data<int64_t>(),
|
||||
topk_weights.data<float>(),
|
||||
token_nums_per_expert.data<int>(),
|
||||
up_gate_proj_in_scale ? up_gate_proj_in_scale.get().data<float>() : nullptr,
|
||||
moe_topk,
|
||||
num_rows,
|
||||
token_nums_this_rank,
|
||||
hidden_size,
|
||||
permute_input->data<int8_t>(),
|
||||
permute_indices_per_token->data<int>(),
|
||||
dst_weights->data<float>(),
|
||||
dst_indices->data<int>(),
|
||||
cumsum_idx_gpu->data<int>(),
|
||||
token_nums_per_expert_cumsum->data<int64_t>(),
|
||||
expert_idx_per_token->data<int64_t>(),
|
||||
127.0,
|
||||
-127.0
|
||||
);
|
||||
} else if (num_experts_per_rank == 16) {
|
||||
permute_x_kernel<data_t, int8_t, 16><<<gridx, 512, 0, stream>>>(
|
||||
input.data<data_t>(),
|
||||
topk_ids.data<int64_t>(),
|
||||
topk_weights.data<float>(),
|
||||
token_nums_per_expert.data<int>(),
|
||||
up_gate_proj_in_scale ? up_gate_proj_in_scale.get().data<float>() : nullptr,
|
||||
moe_topk,
|
||||
num_rows,
|
||||
token_nums_this_rank,
|
||||
hidden_size,
|
||||
permute_input->data<int8_t>(),
|
||||
permute_indices_per_token->data<int>(),
|
||||
dst_weights->data<float>(),
|
||||
dst_indices->data<int>(),
|
||||
cumsum_idx_gpu->data<int>(),
|
||||
token_nums_per_expert_cumsum->data<int64_t>(),
|
||||
expert_idx_per_token->data<int64_t>(),
|
||||
127.0,
|
||||
-127.0
|
||||
);
|
||||
}
|
||||
DISPATCH_NUM_EXPERTS_PER_RANK(num_experts_per_rank, NUM_EXPERTS_PER_RANK,
|
||||
permute_x_kernel<data_t, int8_t, NUM_EXPERTS_PER_RANK><<<gridx, 512, 0, stream>>>(
|
||||
input.data<data_t>(),
|
||||
topk_ids.data<int64_t>(),
|
||||
topk_weights.data<float>(),
|
||||
token_nums_per_expert.data<int>(),
|
||||
up_gate_proj_in_scale ? up_gate_proj_in_scale.get().data<float>() : nullptr,
|
||||
moe_topk,
|
||||
num_rows,
|
||||
token_nums_this_rank,
|
||||
hidden_size,
|
||||
permute_input->data<int8_t>(),
|
||||
permute_indices_per_token->data<int>(),
|
||||
dst_weights->data<float>(),
|
||||
dst_indices->data<int>(),
|
||||
cumsum_idx_gpu->data<int>(),
|
||||
token_nums_per_expert_cumsum->data<int64_t>(),
|
||||
expert_idx_per_token->data<int64_t>(),
|
||||
127.0,
|
||||
-127.0
|
||||
);)
|
||||
} else if (moe_quant_type == "w4afp8") {
|
||||
if (num_experts_per_rank == 8) {
|
||||
permute_x_kernel<data_t, data_t_fp8, 8, 512><<<gridx, 512, 0, stream>>>(
|
||||
input.data<data_t>(),
|
||||
topk_ids.data<int64_t>(),
|
||||
topk_weights.data<float>(),
|
||||
token_nums_per_expert.data<int>(),
|
||||
up_gate_proj_in_scale ? up_gate_proj_in_scale.get().data<float>() : nullptr,
|
||||
moe_topk,
|
||||
num_rows,
|
||||
token_nums_this_rank,
|
||||
hidden_size,
|
||||
permute_input->data<data_t_fp8>(),
|
||||
permute_indices_per_token->data<int>(),
|
||||
dst_weights->data<float>(),
|
||||
dst_indices->data<int>(),
|
||||
cumsum_idx_gpu->data<int>(),
|
||||
token_nums_per_expert_cumsum->data<int64_t>(),
|
||||
expert_idx_per_token->data<int64_t>(),
|
||||
448.0f,
|
||||
-448.0f
|
||||
);
|
||||
} else if (num_experts_per_rank == 16) {
|
||||
permute_x_kernel<data_t, data_t_fp8, 16, 512><<<gridx, 512, 0, stream>>>(
|
||||
input.data<data_t>(),
|
||||
topk_ids.data<int64_t>(),
|
||||
topk_weights.data<float>(),
|
||||
token_nums_per_expert.data<int>(),
|
||||
up_gate_proj_in_scale ? up_gate_proj_in_scale.get().data<float>() : nullptr,
|
||||
moe_topk,
|
||||
num_rows,
|
||||
token_nums_this_rank,
|
||||
hidden_size,
|
||||
permute_input->data<data_t_fp8>(),
|
||||
permute_indices_per_token->data<int>(),
|
||||
dst_weights->data<float>(),
|
||||
dst_indices->data<int>(),
|
||||
cumsum_idx_gpu->data<int>(),
|
||||
token_nums_per_expert_cumsum->data<int64_t>(),
|
||||
expert_idx_per_token->data<int64_t>(),
|
||||
448.0f,
|
||||
-448.0f
|
||||
);
|
||||
}
|
||||
DISPATCH_NUM_EXPERTS_PER_RANK(num_experts_per_rank, NUM_EXPERTS_PER_RANK,
|
||||
permute_x_kernel<data_t, data_t_fp8, NUM_EXPERTS_PER_RANK, 512><<<gridx, 512, 0, stream>>>(
|
||||
input.data<data_t>(),
|
||||
topk_ids.data<int64_t>(),
|
||||
topk_weights.data<float>(),
|
||||
token_nums_per_expert.data<int>(),
|
||||
up_gate_proj_in_scale ? up_gate_proj_in_scale.get().data<float>() : nullptr,
|
||||
moe_topk,
|
||||
num_rows,
|
||||
token_nums_this_rank,
|
||||
hidden_size,
|
||||
permute_input->data<data_t_fp8>(),
|
||||
permute_indices_per_token->data<int>(),
|
||||
dst_weights->data<float>(),
|
||||
dst_indices->data<int>(),
|
||||
cumsum_idx_gpu->data<int>(),
|
||||
token_nums_per_expert_cumsum->data<int64_t>(),
|
||||
expert_idx_per_token->data<int64_t>(),
|
||||
448.0f,
|
||||
-448.0f
|
||||
);)
|
||||
} else {
|
||||
if (num_experts_per_rank == 8) {
|
||||
permute_x_kernel<data_t, data_t, 8><<<gridx, 512, 0, stream>>>(
|
||||
input.data<data_t>(),
|
||||
topk_ids.data<int64_t>(),
|
||||
topk_weights.data<float>(),
|
||||
token_nums_per_expert.data<int>(),
|
||||
up_gate_proj_in_scale ? up_gate_proj_in_scale.get().data<float>() : nullptr,
|
||||
moe_topk,
|
||||
num_rows,
|
||||
token_nums_this_rank,
|
||||
hidden_size,
|
||||
permute_input->data<data_t>(),
|
||||
permute_indices_per_token->data<int>(),
|
||||
dst_weights->data<float>(),
|
||||
dst_indices->data<int>(),
|
||||
cumsum_idx_gpu->data<int>(),
|
||||
token_nums_per_expert_cumsum->data<int64_t>(),
|
||||
expert_idx_per_token->data<int64_t>(),
|
||||
127.0,
|
||||
-127.0
|
||||
);
|
||||
} else if (num_experts_per_rank == 16) {
|
||||
permute_x_kernel<data_t, data_t, 16><<<gridx, 512, 0, stream>>>(
|
||||
input.data<data_t>(),
|
||||
topk_ids.data<int64_t>(),
|
||||
topk_weights.data<float>(),
|
||||
token_nums_per_expert.data<int>(),
|
||||
up_gate_proj_in_scale ? up_gate_proj_in_scale.get().data<float>() : nullptr,
|
||||
moe_topk,
|
||||
num_rows,
|
||||
token_nums_this_rank,
|
||||
hidden_size,
|
||||
permute_input->data<data_t>(),
|
||||
permute_indices_per_token->data<int>(),
|
||||
dst_weights->data<float>(),
|
||||
dst_indices->data<int>(),
|
||||
cumsum_idx_gpu->data<int>(),
|
||||
token_nums_per_expert_cumsum->data<int64_t>(),
|
||||
expert_idx_per_token->data<int64_t>(),
|
||||
127.0,
|
||||
-127.0
|
||||
);
|
||||
}
|
||||
DISPATCH_NUM_EXPERTS_PER_RANK(num_experts_per_rank, NUM_EXPERTS_PER_RANK,
|
||||
permute_x_kernel<data_t, data_t, NUM_EXPERTS_PER_RANK><<<gridx, 512, 0, stream>>>(
|
||||
input.data<data_t>(),
|
||||
topk_ids.data<int64_t>(),
|
||||
topk_weights.data<float>(),
|
||||
token_nums_per_expert.data<int>(),
|
||||
up_gate_proj_in_scale ? up_gate_proj_in_scale.get().data<float>() : nullptr,
|
||||
moe_topk,
|
||||
num_rows,
|
||||
token_nums_this_rank,
|
||||
hidden_size,
|
||||
permute_input->data<data_t>(),
|
||||
permute_indices_per_token->data<int>(),
|
||||
dst_weights->data<float>(),
|
||||
dst_indices->data<int>(),
|
||||
cumsum_idx_gpu->data<int>(),
|
||||
token_nums_per_expert_cumsum->data<int64_t>(),
|
||||
expert_idx_per_token->data<int64_t>(),
|
||||
127.0,
|
||||
-127.0
|
||||
);)
|
||||
}
|
||||
}
|
||||
|
||||
|
@@ -236,7 +236,7 @@ public:
|
||||
num_experts, k, stream);
|
||||
}
|
||||
|
||||
topk_gating_softmax_kernelLauncher<float, int>::run(
|
||||
topk_gating_softmax_kernelLauncher<float, int>(
|
||||
gating_output, nullptr, expert_scales_float, softmax_out_,
|
||||
expert_for_source_row, source_rows_, softmax_max_prob, num_rows,
|
||||
num_experts, k, group_moe, stream);
|
||||
@@ -248,7 +248,7 @@ public:
|
||||
permuted_experts_, source_rows_, permuted_rows_, k * num_rows,
|
||||
false, stream);
|
||||
|
||||
initialize_moe_routing_kernelLauncher<T>::run(
|
||||
initialize_moe_routing_kernelLauncher(
|
||||
input_activations, permuted_data_, permuted_rows_, nullptr, nullptr,
|
||||
expanded_source_row_to_expanded_dest_row, num_rows, num_rows,
|
||||
hidden_size, k, stream);
|
||||
@@ -335,14 +335,14 @@ public:
|
||||
num_experts, down_proj_quant_args, stream);
|
||||
}
|
||||
|
||||
finalize_moe_routing_kernelLauncher<T>::run(
|
||||
finalize_moe_routing_kernelLauncher(
|
||||
fc2_result, output_, fc2_expert_biases,
|
||||
reinterpret_cast<float *>(expert_scales_float),
|
||||
expanded_source_row_to_expanded_dest_row, expert_for_source_row,
|
||||
num_rows, hidden_size, k, static_cast<int>(1), norm_topk_prob,
|
||||
routed_scaling_factor, stream);
|
||||
} else {
|
||||
finalize_moe_routing_kernelLauncher<T>::run(
|
||||
finalize_moe_routing_kernelLauncher(
|
||||
// fc2_result,
|
||||
fc1_out, output_,
|
||||
fc1_expert_biases, // fc2_expert_biases,
|
||||
|
@@ -1139,9 +1139,7 @@ void topk_gating_softmax_launcher_helper(const T* input,
|
||||
}
|
||||
|
||||
template <typename T, typename IdxT = int>
|
||||
struct topk_gating_softmax_kernelLauncher{
|
||||
|
||||
static void run(const T* input,
|
||||
void topk_gating_softmax_kernelLauncher(const T* input,
|
||||
const T* gating_correction_bias,
|
||||
T* output,
|
||||
T* softmax,
|
||||
@@ -1221,7 +1219,6 @@ static void run(const T* input,
|
||||
}
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
// ========================== Permutation things
|
||||
// =======================================
|
||||
@@ -1316,9 +1313,7 @@ __global__ void initialize_moe_routing_kernel(
|
||||
}
|
||||
|
||||
template <typename T, typename OutT = T>
|
||||
struct initialize_moe_routing_kernelLauncher{
|
||||
|
||||
static void run(
|
||||
void initialize_moe_routing_kernelLauncher(
|
||||
const T* unpermuted_input,
|
||||
OutT* permuted_output,
|
||||
const int* expanded_dest_row_to_expanded_source_row,
|
||||
@@ -1361,7 +1356,6 @@ static void run(
|
||||
num_rows * k);
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
// ============================== Infer GEMM sizes
|
||||
// =================================
|
||||
@@ -1472,8 +1466,7 @@ __global__ void finalize_moe_routing_kernel(
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
struct finalize_moe_routing_kernelLauncher{
|
||||
static void run(
|
||||
void finalize_moe_routing_kernelLauncher(
|
||||
const T* expanded_permuted_rows,
|
||||
T* reduced_unpermuted_output,
|
||||
const T* bias,
|
||||
@@ -1505,5 +1498,4 @@ static void run(
|
||||
routed_scaling_factor,
|
||||
num_rows);
|
||||
}
|
||||
};
|
||||
} // namespace phi
|
||||
|
@@ -36,6 +36,9 @@ void MoeDispatchKernel(
|
||||
paddle::Tensor *topk_idx, paddle::Tensor *expert_idx_per_token) {
|
||||
using namespace phi;
|
||||
|
||||
if (num_rows == 0){
|
||||
return;
|
||||
}
|
||||
typedef PDTraits<T> traits_;
|
||||
typedef typename traits_::DataType DataType_;
|
||||
typedef typename traits_::data_t data_t;
|
||||
@@ -100,7 +103,7 @@ void MoeDispatchKernel(
|
||||
softmax_out_ = nullptr;
|
||||
}
|
||||
|
||||
topk_gating_softmax_kernelLauncher<float, int>::run(
|
||||
topk_gating_softmax_kernelLauncher(
|
||||
gating_output.data<float>(),
|
||||
gating_correction_bias ? gating_correction_bias.get().data<float>()
|
||||
: nullptr,
|
||||
@@ -114,13 +117,13 @@ void MoeDispatchKernel(
|
||||
|
||||
if (w4a8_in_scale) {
|
||||
if (permute_input->dtype() == paddle::DataType::INT8) {
|
||||
initialize_moe_routing_kernelLauncher<data_t, int8_t>::run(
|
||||
initialize_moe_routing_kernelLauncher(
|
||||
input.data<data_t>(), permute_input->data<int8_t>(), permuted_rows_,
|
||||
expert_idx_per_token->data<int32_t>(), w4a8_in_scale->data<float>(),
|
||||
permute_indices_per_token->data<int32_t>(), num_rows, num_rows,
|
||||
hidden_size, moe_topk, stream);
|
||||
} else if (permute_input->dtype() == paddle::DataType::FLOAT8_E4M3FN) {
|
||||
initialize_moe_routing_kernelLauncher<data_t, float8_e4m3fn>::run(
|
||||
initialize_moe_routing_kernelLauncher(
|
||||
input.data<data_t>(), permute_input->data<float8_e4m3fn>(),
|
||||
permuted_rows_, expert_idx_per_token->data<int32_t>(),
|
||||
w4a8_in_scale->data<float>(),
|
||||
@@ -128,7 +131,7 @@ void MoeDispatchKernel(
|
||||
hidden_size, moe_topk, stream);
|
||||
}
|
||||
} else {
|
||||
initialize_moe_routing_kernelLauncher<data_t>::run(
|
||||
initialize_moe_routing_kernelLauncher(
|
||||
input.data<data_t>(), permute_input->data<data_t>(), permuted_rows_,
|
||||
expert_idx_per_token->data<int32_t>(), nullptr,
|
||||
permute_indices_per_token->data<int32_t>(), num_rows, num_rows,
|
||||
@@ -185,6 +188,15 @@ std::vector<paddle::Tensor> MoeExpertDispatch(
|
||||
auto expert_idx_per_token =
|
||||
GetEmptyTensor({num_rows * moe_topk}, paddle::DataType::INT32, place);
|
||||
|
||||
if (token_rows == 0){
|
||||
return {permute_input,
|
||||
tokens_expert_prefix_sum,
|
||||
permute_indices_per_token,
|
||||
topk_weight,
|
||||
topk_idx,
|
||||
expert_idx_per_token};
|
||||
}
|
||||
|
||||
switch (input_type) {
|
||||
case paddle::DataType::BFLOAT16:
|
||||
MoeDispatchKernel<paddle::DataType::BFLOAT16>(
|
||||
|
@@ -412,7 +412,9 @@ const auto t_type = (quant_method == "w4a8") ? up_gate_proj_scale.get().dtype()
|
||||
(quant_method == "w4afp8") ? paddle::DataType::BFLOAT16 :
|
||||
permute_input.dtype();
|
||||
auto ffn_out = paddle::empty_like(permute_input, t_type);
|
||||
|
||||
if(permute_input.numel() == 0){
|
||||
return ffn_out;
|
||||
}
|
||||
switch (t_type) {
|
||||
case paddle::DataType::BFLOAT16:
|
||||
MoeFFNKernel<paddle::DataType::BFLOAT16>(permute_input,
|
||||
|
@@ -36,7 +36,7 @@ void MoeReduceKernel(const paddle::Tensor &ffn_out,
|
||||
typedef typename traits_::data_t data_t;
|
||||
auto stream = ffn_out.stream();
|
||||
|
||||
finalize_moe_routing_kernelLauncher<data_t>::run(
|
||||
finalize_moe_routing_kernelLauncher(
|
||||
ffn_out.data<data_t>(), output->data<data_t>(),
|
||||
down_proj_bias ? down_proj_bias->data<data_t>() : nullptr,
|
||||
top_k_weight.data<float>(), permute_indices_per_token.data<int32_t>(),
|
||||
@@ -59,6 +59,10 @@ paddle::Tensor MoeExpertReduceFunc(
|
||||
|
||||
auto output = GetEmptyTensor({num_rows, hidden_size}, input_type, place);
|
||||
|
||||
if(num_rows == 0){
|
||||
return output;
|
||||
}
|
||||
|
||||
switch (input_type) {
|
||||
case paddle::DataType::BFLOAT16:
|
||||
MoeReduceKernel<paddle::DataType::BFLOAT16>(
|
||||
|
@@ -22,23 +22,18 @@ std::vector<paddle::Tensor> MultiHeadLatentAttentionKernel(
|
||||
const paddle::Tensor& query,
|
||||
const paddle::Tensor& key_cache,
|
||||
const paddle::Tensor& value_cache,
|
||||
const paddle::Tensor& seq_lens_encoder,
|
||||
const paddle::Tensor& seq_lens_decoder,
|
||||
const paddle::Tensor& seq_lens_this_time,
|
||||
const paddle::Tensor& cu_seqlens_q,
|
||||
const paddle::Tensor& batch_id_per_token,
|
||||
const paddle::Tensor& block_tables,
|
||||
const paddle::Tensor& encoder_batch_ids,
|
||||
const paddle::Tensor& encoder_tile_ids_per_batch,
|
||||
const paddle::Tensor& encoder_num_blocks,
|
||||
const paddle::Tensor& kv_batch_ids,
|
||||
const paddle::Tensor& kv_tile_ids_per_batch,
|
||||
const paddle::Tensor& kv_num_blocks,
|
||||
const paddle::Tensor& decoder_batch_ids,
|
||||
const paddle::Tensor& decoder_tile_ids_per_batch,
|
||||
const paddle::Tensor& decoder_num_blocks,
|
||||
const paddle::Tensor& decoder_num_blocks_cpu,
|
||||
const paddle::Tensor& max_enc_len_this_time,
|
||||
const paddle::Tensor& decoder_chunk_size_device,
|
||||
const paddle::Tensor& max_dec_len_this_time,
|
||||
const paddle::Tensor& max_len_kv,
|
||||
const paddle::optional<paddle::Tensor>& attn_mask,
|
||||
@@ -64,9 +59,12 @@ std::vector<paddle::Tensor> MultiHeadLatentAttentionKernel(
|
||||
typedef PDTraits<D> traits_;
|
||||
typedef typename traits_::data_t data_t;
|
||||
|
||||
int decoder_num_blocks_data = decoder_num_blocks_cpu.data<int>()[0];
|
||||
// NOTE: (changwenbin) In cuda graph, it will be fixed in the capture stage
|
||||
// int decoder_num_blocks_data = decoder_num_blocks_cpu.data<int>()[0];
|
||||
int max_dec_len_this_time_data = max_dec_len_this_time.data<int>()[0];
|
||||
int max_len_kv_data = max_len_kv.data<int>()[0];
|
||||
// int chunk_size = decoder_chunk_size_cpu.data<int>()[0];
|
||||
//
|
||||
|
||||
const bool mla_use_tensorcore = get_mla_use_tensorcore();
|
||||
auto sm_version = GetSMVersion();
|
||||
@@ -96,7 +94,6 @@ std::vector<paddle::Tensor> MultiHeadLatentAttentionKernel(
|
||||
out_linear_smooths,
|
||||
seq_lens_this_time,
|
||||
seq_lens_decoder,
|
||||
seq_lens_encoder,
|
||||
cu_seqlens_q,
|
||||
batch_id_per_token,
|
||||
block_tables,
|
||||
@@ -104,9 +101,8 @@ std::vector<paddle::Tensor> MultiHeadLatentAttentionKernel(
|
||||
decoder_tile_ids_per_batch,
|
||||
decoder_num_blocks,
|
||||
cache_quant_type_str,
|
||||
decoder_num_blocks_data,
|
||||
decoder_chunk_size_device,
|
||||
max_input_length,
|
||||
max_len_kv_data,
|
||||
softmax_scale,
|
||||
quant_max_bound,
|
||||
quant_min_bound,
|
||||
@@ -145,23 +141,18 @@ std::vector<paddle::Tensor> MultiHeadLatentAttention(
|
||||
const paddle::Tensor& query,
|
||||
const paddle::Tensor& key_cache,
|
||||
const paddle::Tensor& value_cache,
|
||||
const paddle::Tensor& seq_lens_encoder,
|
||||
const paddle::Tensor& seq_lens_decoder,
|
||||
const paddle::Tensor& seq_lens_this_time,
|
||||
const paddle::Tensor& cu_seqlens_q,
|
||||
const paddle::Tensor& batch_id_per_token,
|
||||
const paddle::Tensor& block_tables,
|
||||
const paddle::Tensor& encoder_batch_ids,
|
||||
const paddle::Tensor& encoder_tile_ids_per_batch,
|
||||
const paddle::Tensor& encoder_num_blocks,
|
||||
const paddle::Tensor& kv_batch_ids,
|
||||
const paddle::Tensor& kv_tile_ids_per_batch,
|
||||
const paddle::Tensor& kv_num_blocks,
|
||||
const paddle::Tensor& decoder_batch_ids,
|
||||
const paddle::Tensor& decoder_tile_ids_per_batch,
|
||||
const paddle::Tensor& decoder_num_blocks,
|
||||
const paddle::Tensor& decoder_num_blocks_cpu,
|
||||
const paddle::Tensor& max_enc_len_this_time,
|
||||
const paddle::Tensor& decoder_chunk_size_device,
|
||||
const paddle::Tensor& max_dec_len_this_time,
|
||||
const paddle::Tensor& max_len_kv,
|
||||
const paddle::optional<paddle::Tensor>& attn_mask,
|
||||
@@ -208,23 +199,18 @@ std::vector<paddle::Tensor> MultiHeadLatentAttention(
|
||||
query,
|
||||
key_cache,
|
||||
value_cache,
|
||||
seq_lens_encoder,
|
||||
seq_lens_decoder,
|
||||
seq_lens_this_time,
|
||||
cu_seqlens_q,
|
||||
batch_id_per_token,
|
||||
block_tables,
|
||||
encoder_batch_ids,
|
||||
encoder_tile_ids_per_batch,
|
||||
encoder_num_blocks,
|
||||
kv_batch_ids,
|
||||
kv_tile_ids_per_batch,
|
||||
kv_num_blocks,
|
||||
decoder_batch_ids,
|
||||
decoder_tile_ids_per_batch,
|
||||
decoder_num_blocks,
|
||||
decoder_num_blocks_cpu,
|
||||
max_enc_len_this_time,
|
||||
decoder_chunk_size_device,
|
||||
max_dec_len_this_time,
|
||||
max_len_kv,
|
||||
attn_mask,
|
||||
@@ -254,23 +240,18 @@ std::vector<paddle::Tensor> MultiHeadLatentAttention(
|
||||
query,
|
||||
key_cache,
|
||||
value_cache,
|
||||
seq_lens_encoder,
|
||||
seq_lens_decoder,
|
||||
seq_lens_this_time,
|
||||
cu_seqlens_q,
|
||||
batch_id_per_token,
|
||||
block_tables,
|
||||
encoder_batch_ids,
|
||||
encoder_tile_ids_per_batch,
|
||||
encoder_num_blocks,
|
||||
kv_batch_ids,
|
||||
kv_tile_ids_per_batch,
|
||||
kv_num_blocks,
|
||||
decoder_batch_ids,
|
||||
decoder_tile_ids_per_batch,
|
||||
decoder_num_blocks,
|
||||
decoder_num_blocks_cpu,
|
||||
max_enc_len_this_time,
|
||||
decoder_chunk_size_device,
|
||||
max_dec_len_this_time,
|
||||
max_len_kv,
|
||||
attn_mask,
|
||||
@@ -307,23 +288,18 @@ std::vector<std::vector<int64_t>> MultiHeadLatentAttentionInferShape(
|
||||
const std::vector<int64_t>& query_shape,
|
||||
const std::vector<int64_t>& key_cache_shape,
|
||||
const std::vector<int64_t>& value_cache_shape,
|
||||
const std::vector<int64_t>& seq_lens_encoder_shape,
|
||||
const std::vector<int64_t>& seq_lens_decoder_shape,
|
||||
const std::vector<int64_t>& seq_lens_this_time_shape,
|
||||
const std::vector<int64_t>& cu_seqlens_q_shape,
|
||||
const std::vector<int64_t>& batch_id_per_token_shape,
|
||||
const std::vector<int64_t>& block_tables_shape,
|
||||
const std::vector<int64_t>& encoder_batch_ids_shape,
|
||||
const std::vector<int64_t>& encoder_tile_ids_per_batch_shape,
|
||||
const std::vector<int64_t>& encoder_num_blocks_shape,
|
||||
const std::vector<int64_t>& kv_batch_ids_shape,
|
||||
const std::vector<int64_t>& kv_tile_ids_per_batch_shape,
|
||||
const std::vector<int64_t>& kv_num_blocks_shape,
|
||||
const std::vector<int64_t>& decoder_batch_ids_shape,
|
||||
const std::vector<int64_t>& decoder_tile_ids_per_batch_shape,
|
||||
const std::vector<int64_t>& decoder_num_blocks_shape,
|
||||
const std::vector<int64_t>& decoder_num_blocks_cpu_shape,
|
||||
const std::vector<int64_t>& max_enc_len_this_time_shape,
|
||||
const std::vector<int64_t>& decoder_chunk_size_device_shape,
|
||||
const std::vector<int64_t>& max_dec_len_this_time_shape,
|
||||
const std::vector<int64_t>& max_len_kv_shape,
|
||||
const paddle::optional<std::vector<int64_t>>& attn_mask_shape,
|
||||
@@ -361,23 +337,18 @@ std::vector<paddle::DataType> MultiHeadLatentAttentionInferDtype(
|
||||
const paddle::DataType& query_dtype,
|
||||
const paddle::DataType& key_cache_dtype,
|
||||
const paddle::DataType& value_cache_dtype,
|
||||
const paddle::DataType& seq_lens_encoder_dtype,
|
||||
const paddle::DataType& seq_lens_decoder_dtype,
|
||||
const paddle::DataType& seq_lens_this_time_dtype,
|
||||
const paddle::DataType& cu_seqlens_q_dtype,
|
||||
const paddle::DataType& batch_id_per_token_dtype,
|
||||
const paddle::DataType& block_tables_dtype,
|
||||
const paddle::DataType& encoder_batch_ids_dtype,
|
||||
const paddle::DataType& encoder_tile_ids_per_batch_dtype,
|
||||
const paddle::DataType& encoder_num_blocks_dtype,
|
||||
const paddle::DataType& kv_batch_ids_dtype,
|
||||
const paddle::DataType& kv_tile_ids_per_batch_dtype,
|
||||
const paddle::DataType& kv_num_blocks_dtype,
|
||||
const paddle::DataType& decoder_batch_ids_dtype,
|
||||
const paddle::DataType& decoder_tile_ids_per_batch_dtype,
|
||||
const paddle::DataType& decoder_num_blocks_dtype,
|
||||
const paddle::DataType& decoder_num_blocks_cpu_dtype,
|
||||
const paddle::DataType& max_enc_len_this_time_dtype,
|
||||
const paddle::DataType& decoder_chunk_size_device_dtype,
|
||||
const paddle::DataType& max_dec_len_this_time_dtype,
|
||||
const paddle::DataType& max_len_kv_dtype,
|
||||
const paddle::optional<paddle::DataType>& attn_mask_dtype,
|
||||
@@ -415,23 +386,18 @@ PD_BUILD_STATIC_OP(multi_head_latent_attention)
|
||||
.Inputs({"query",
|
||||
"key_cache",
|
||||
"value_cache",
|
||||
"seq_lens_encoder",
|
||||
"seq_lens_decoder",
|
||||
"seq_lens_this_time",
|
||||
"cu_seqlens_q",
|
||||
"batch_id_per_token",
|
||||
"block_tables",
|
||||
"encoder_batch_ids",
|
||||
"encoder_tile_ids_per_batch",
|
||||
"encoder_num_blocks",
|
||||
"kv_batch_ids",
|
||||
"kv_tile_ids_per_batch",
|
||||
"kv_num_blocks",
|
||||
"decoder_batch_ids",
|
||||
"decoder_tile_ids_per_batch",
|
||||
"decoder_num_blocks",
|
||||
"decoder_num_blocks_cpu",
|
||||
"max_enc_len_this_time",
|
||||
"decoder_chunk_size_device",
|
||||
"max_dec_len_this_time",
|
||||
"max_len_kv",
|
||||
paddle::Optional("attn_mask"),
|
||||
|
@@ -26,6 +26,7 @@ std::vector<paddle::Tensor> NoauxTc(paddle::Tensor& scores,
|
||||
int n_group,
|
||||
int topk_group,
|
||||
int topk,
|
||||
bool renormalize,
|
||||
float routed_scaling_factor) {
|
||||
auto input_shape = scores_with_bias.shape();
|
||||
PD_CHECK(input_shape.size() == 2);
|
||||
@@ -48,6 +49,7 @@ std::vector<paddle::Tensor> NoauxTc(paddle::Tensor& scores,
|
||||
n_group,
|
||||
topk_group,
|
||||
topk,
|
||||
renormalize,
|
||||
routed_scaling_factor,
|
||||
stream);
|
||||
|
||||
@@ -76,6 +78,7 @@ PD_BUILD_STATIC_OP(noaux_tc)
|
||||
.Attrs({"n_group: int",
|
||||
"topk_group: int",
|
||||
"topk:int",
|
||||
"renormalize: bool",
|
||||
"routed_scaling_factor: float"})
|
||||
.SetKernelFn(PD_KERNEL(NoauxTc))
|
||||
.SetInferShapeFn(PD_INFER_SHAPE(NoauxTcInferShape))
|
||||
|
@@ -25,6 +25,23 @@ constexpr unsigned FULL_WARP_MASK = 0xffffffff;
|
||||
constexpr int32_t BLOCK_SIZE = 512;
|
||||
constexpr int32_t NUM_WARPS_PER_BLOCK = BLOCK_SIZE / WARP_SIZE;
|
||||
|
||||
template <typename T_OUT, typename T_IN>
|
||||
__device__ inline T_OUT cuda_cast(T_IN val) {
|
||||
return val;
|
||||
}
|
||||
|
||||
template <>
|
||||
__device__ inline float cuda_cast<float, __nv_bfloat16>(__nv_bfloat16 val) {
|
||||
return __bfloat162float(val);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
__device__ inline T neg_inf() {
|
||||
// cuda::std::numeric_limits<T>::infinity() returns `0` for [T=bf16 or fp16]
|
||||
// so we need to cast from fp32
|
||||
return cuda_cast<T, float>(-cuda::std::numeric_limits<float>::infinity());
|
||||
}
|
||||
|
||||
namespace warp_topk {
|
||||
|
||||
template <int size, typename T>
|
||||
@@ -41,10 +58,21 @@ constexpr __host__ __device__ bool isPowerOf2(T v) {
|
||||
}
|
||||
|
||||
template <bool greater, typename T>
|
||||
__device__ bool is_better_than(T val, T baseline) {
|
||||
__forceinline__ __device__ bool is_better_than(T val, T baseline) {
|
||||
return (val > baseline && greater) || (val < baseline && !greater);
|
||||
}
|
||||
|
||||
template <bool greater, typename T, typename idxT>
|
||||
__forceinline__ __device__ bool is_better_than(T val, T baseline, idxT index,
|
||||
idxT baseline_index) {
|
||||
bool res = (val > baseline && greater) || (val < baseline && !greater);
|
||||
if (val == baseline) {
|
||||
res = (index < baseline_index && greater) ||
|
||||
(index < baseline_index && !greater);
|
||||
}
|
||||
return res;
|
||||
}
|
||||
|
||||
template <typename T, typename idxT>
|
||||
int calc_smem_size_for_block_wide(int num_of_warp, int64_t k) {
|
||||
int64_t cache_topk = (sizeof(T) + sizeof(idxT)) * num_of_warp * k;
|
||||
@@ -53,7 +81,8 @@ int calc_smem_size_for_block_wide(int num_of_warp, int64_t k) {
|
||||
round_up_to_multiple_of<256>(n * sizeof(T)) + n * sizeof(idxT));
|
||||
}
|
||||
|
||||
template <int size, bool ascending, typename T, typename idxT>
|
||||
template <int size, bool ascending, bool reverse, typename T, typename idxT,
|
||||
bool is_stable>
|
||||
struct BitonicMerge {
|
||||
// input should be a bitonic sequence, and sort it to be a monotonic sequence
|
||||
__device__ static void merge(T* __restrict__ val_arr,
|
||||
@@ -67,7 +96,15 @@ struct BitonicMerge {
|
||||
int const other_i = i + stride;
|
||||
T& val = val_arr[i];
|
||||
T& other_val = val_arr[other_i];
|
||||
if ((val > other_val && ascending) || (val < other_val && !ascending)) {
|
||||
bool is_better;
|
||||
if constexpr (is_stable) {
|
||||
is_better = is_better_than<ascending>(val, other_val, idx_arr[i],
|
||||
idx_arr[other_i]);
|
||||
} else {
|
||||
is_better = is_better_than<ascending>(val, other_val);
|
||||
}
|
||||
|
||||
if (is_better) {
|
||||
T tmp = val;
|
||||
val = other_val;
|
||||
other_val = tmp;
|
||||
@@ -78,13 +115,14 @@ struct BitonicMerge {
|
||||
}
|
||||
}
|
||||
|
||||
BitonicMerge<size / 2, ascending, T, idxT>::merge(val_arr, idx_arr);
|
||||
BitonicMerge<size / 2, ascending, T, idxT>::merge(val_arr + arr_len / 2,
|
||||
idx_arr + arr_len / 2);
|
||||
BitonicMerge<size / 2, ascending, reverse, T, idxT, is_stable>::merge(
|
||||
val_arr, idx_arr);
|
||||
BitonicMerge<size / 2, ascending, reverse, T, idxT, is_stable>::merge(
|
||||
val_arr + arr_len / 2, idx_arr + arr_len / 2);
|
||||
}
|
||||
};
|
||||
|
||||
template <int size, bool ascending, typename T, typename idxT>
|
||||
template <int size, bool ascending, typename T, typename idxT, bool is_stable>
|
||||
struct BitonicSort {
|
||||
__device__ static void sort(T* __restrict__ val_arr,
|
||||
idxT* __restrict__ idx_arr) {
|
||||
@@ -92,15 +130,16 @@ struct BitonicSort {
|
||||
static_assert(size >= 2 * WARP_SIZE);
|
||||
constexpr int arr_len = size / WARP_SIZE;
|
||||
|
||||
BitonicSort<size / 2, true, T, idxT>::sort(val_arr, idx_arr);
|
||||
BitonicSort<size / 2, false, T, idxT>::sort(val_arr + arr_len / 2,
|
||||
idx_arr + arr_len / 2);
|
||||
BitonicMerge<size, ascending, T, idxT>::merge(val_arr, idx_arr);
|
||||
BitonicSort<size / 2, true, T, idxT, is_stable>::sort(val_arr, idx_arr);
|
||||
BitonicSort<size / 2, false, T, idxT, is_stable>::sort(
|
||||
val_arr + arr_len / 2, idx_arr + arr_len / 2);
|
||||
BitonicMerge<size, ascending, ascending, T, idxT, is_stable>::merge(
|
||||
val_arr, idx_arr);
|
||||
}
|
||||
};
|
||||
|
||||
template <bool ascending, typename T, typename idxT>
|
||||
struct BitonicSort<32, ascending, T, idxT> {
|
||||
template <bool ascending, typename T, typename idxT, bool is_stable>
|
||||
struct BitonicSort<32, ascending, T, idxT, is_stable> {
|
||||
__device__ static void sort(T* __restrict__ val_arr,
|
||||
idxT* __restrict__ idx_arr) {
|
||||
int const lane = threadIdx.x % WARP_SIZE;
|
||||
@@ -114,19 +153,37 @@ struct BitonicSort<32, ascending, T, idxT> {
|
||||
|
||||
T other = __shfl_xor_sync(FULL_WARP_MASK, *val_arr, stride);
|
||||
idxT other_idx = __shfl_xor_sync(FULL_WARP_MASK, *idx_arr, stride);
|
||||
if (*val_arr != other && (*val_arr > other) != (reverse != is_second)) {
|
||||
|
||||
bool is_better;
|
||||
if constexpr (is_stable) {
|
||||
if constexpr (ascending) {
|
||||
is_better = ((*val_arr > other) ||
|
||||
((*val_arr == other) && (*idx_arr < other_idx))) !=
|
||||
(reverse != is_second);
|
||||
} else {
|
||||
is_better = ((*val_arr > other) ||
|
||||
((*val_arr == other) && (*idx_arr > other_idx))) !=
|
||||
(reverse != is_second);
|
||||
}
|
||||
} else {
|
||||
is_better = (*val_arr != other &&
|
||||
(*val_arr > other) != (reverse != is_second));
|
||||
}
|
||||
if (is_better) {
|
||||
*val_arr = other;
|
||||
*idx_arr = other_idx;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
BitonicMerge<32, ascending, T, idxT>::merge(val_arr, idx_arr);
|
||||
BitonicMerge<32, ascending, ascending, T, idxT, is_stable>::merge(val_arr,
|
||||
idx_arr);
|
||||
}
|
||||
};
|
||||
|
||||
template <bool ascending, typename T, typename idxT>
|
||||
struct BitonicMerge<32, ascending, T, idxT> {
|
||||
template <bool ascending, bool reverse, typename T, typename idxT,
|
||||
bool is_stable>
|
||||
struct BitonicMerge<32, ascending, reverse, T, idxT, is_stable> {
|
||||
__device__ static void merge(T* __restrict__ val_arr,
|
||||
idxT* __restrict__ idx_arr) {
|
||||
int const lane = threadIdx.x % WARP_SIZE;
|
||||
@@ -136,7 +193,24 @@ struct BitonicMerge<32, ascending, T, idxT> {
|
||||
T other = __shfl_xor_sync(FULL_WARP_MASK, val, stride);
|
||||
idxT& idx = *idx_arr;
|
||||
idxT other_idx = __shfl_xor_sync(FULL_WARP_MASK, idx, stride);
|
||||
if (val != other && ((val > other) == (ascending != is_second))) {
|
||||
|
||||
bool is_better;
|
||||
if constexpr (is_stable) {
|
||||
if constexpr (ascending) {
|
||||
is_better = ((*val_arr > other) ||
|
||||
((*val_arr == other) && (*idx_arr < other_idx))) ==
|
||||
(reverse != is_second); // for min
|
||||
} else {
|
||||
is_better = ((*val_arr > other) ||
|
||||
((*val_arr == other) && (*idx_arr > other_idx))) ==
|
||||
(reverse != is_second); // for max
|
||||
}
|
||||
} else {
|
||||
is_better =
|
||||
(val != other && ((val > other) == (ascending != is_second)));
|
||||
}
|
||||
|
||||
if (is_better) {
|
||||
val = other;
|
||||
idx = other_idx;
|
||||
}
|
||||
@@ -144,34 +218,42 @@ struct BitonicMerge<32, ascending, T, idxT> {
|
||||
}
|
||||
};
|
||||
|
||||
template <int capacity, bool greater, typename T, typename idxT>
|
||||
template <int capacity, bool greater, typename T, typename idxT, bool is_stable>
|
||||
class WarpSort {
|
||||
public:
|
||||
public:
|
||||
__device__ WarpSort(idxT k, T dummy)
|
||||
: lane_(threadIdx.x % WARP_SIZE), k_(k), dummy_(dummy) {
|
||||
static_assert(capacity >= WARP_SIZE && isPowerOf2(capacity));
|
||||
|
||||
for (int i = 0; i < max_arr_len_; ++i) {
|
||||
val_arr_[i] = dummy_;
|
||||
idx_arr_[i] = 0;
|
||||
}
|
||||
}
|
||||
|
||||
// load and merge k sorted values
|
||||
__device__ void load_sorted(T const* __restrict__ in,
|
||||
idxT const* __restrict__ in_idx,
|
||||
idxT start) {
|
||||
idxT const* __restrict__ in_idx, idxT start) {
|
||||
idxT idx = start + WARP_SIZE - 1 - lane_;
|
||||
for (int i = max_arr_len_ - 1; i >= 0; --i, idx += WARP_SIZE) {
|
||||
if (idx < start + k_) {
|
||||
T t = in[idx];
|
||||
if (is_better_than<greater>(t, val_arr_[i])) {
|
||||
bool is_better;
|
||||
if constexpr (is_stable) {
|
||||
is_better =
|
||||
is_better_than<greater>(t, val_arr_[i], in_idx[idx], idx_arr_[i]);
|
||||
} else {
|
||||
is_better = is_better_than<greater>(t, val_arr_[i]);
|
||||
}
|
||||
if (is_better) {
|
||||
val_arr_[i] = t;
|
||||
idx_arr_[i] = in_idx[idx];
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
BitonicMerge<capacity, !greater, T, idxT>::merge(val_arr_, idx_arr_);
|
||||
BitonicMerge<capacity, greater, !greater, T, idxT, is_stable>::merge(
|
||||
val_arr_, idx_arr_);
|
||||
}
|
||||
|
||||
__device__ void dump(T* __restrict__ out, idxT* __restrict__ out_idx) const {
|
||||
@@ -193,7 +275,7 @@ public:
|
||||
}
|
||||
}
|
||||
|
||||
protected:
|
||||
protected:
|
||||
static constexpr int max_arr_len_ = capacity / WARP_SIZE;
|
||||
|
||||
T val_arr_[max_arr_len_];
|
||||
@@ -205,11 +287,11 @@ protected:
|
||||
|
||||
}; // end class WarpSort
|
||||
|
||||
template <int capacity, bool greater, typename T, typename idxT>
|
||||
class WarpSelect : public WarpSort<capacity, greater, T, idxT> {
|
||||
public:
|
||||
template <int capacity, bool greater, typename T, typename idxT, bool is_stable>
|
||||
class WarpSelect : public WarpSort<capacity, greater, T, idxT, is_stable> {
|
||||
public:
|
||||
__device__ WarpSelect(idxT k, T dummy)
|
||||
: WarpSort<capacity, greater, T, idxT>(k, dummy),
|
||||
: WarpSort<capacity, greater, T, idxT, is_stable>(k, dummy),
|
||||
k_th_(dummy),
|
||||
k_th_lane_((k - 1) % WARP_SIZE) {
|
||||
extern __shared__ char smem_buf[]; // extern __shared__ T smem_buf[];
|
||||
@@ -234,7 +316,13 @@ public:
|
||||
}
|
||||
|
||||
__device__ void add(T val, idxT idx) {
|
||||
bool do_add = is_better_than<greater>(val, k_th_);
|
||||
bool do_add;
|
||||
if constexpr (is_stable) {
|
||||
do_add = is_better_than<greater>(val, k_th_, idx, k_th_idx_);
|
||||
} else {
|
||||
do_add = is_better_than<greater>(val, k_th_);
|
||||
}
|
||||
|
||||
uint32_t mask = __ballot_sync(FULL_WARP_MASK, do_add);
|
||||
if (mask == 0) {
|
||||
return;
|
||||
@@ -271,37 +359,52 @@ public:
|
||||
__syncthreads();
|
||||
}
|
||||
|
||||
private:
|
||||
private:
|
||||
__device__ void set_k_th_() {
|
||||
k_th_ = __shfl_sync(FULL_WARP_MASK, val_arr_[max_arr_len_ - 1], k_th_lane_);
|
||||
if constexpr (is_stable) {
|
||||
k_th_idx_ =
|
||||
__shfl_sync(FULL_WARP_MASK, idx_arr_[max_arr_len_ - 1], k_th_lane_);
|
||||
}
|
||||
}
|
||||
|
||||
__device__ void merge_buf_(T val, idxT idx) {
|
||||
BitonicSort<WARP_SIZE, greater, T, idxT>::sort(&val, &idx);
|
||||
BitonicSort<WARP_SIZE, greater, T, idxT, is_stable>::sort(&val, &idx);
|
||||
|
||||
T& old = val_arr_[max_arr_len_ - 1];
|
||||
if (is_better_than<greater>(val, old)) {
|
||||
|
||||
bool is_better;
|
||||
if constexpr (is_stable) {
|
||||
is_better =
|
||||
is_better_than<greater>(val, old, idx, idx_arr_[max_arr_len_ - 1]);
|
||||
} else {
|
||||
is_better = is_better_than<greater>(val, old);
|
||||
}
|
||||
|
||||
if (is_better) {
|
||||
old = val;
|
||||
idx_arr_[max_arr_len_ - 1] = idx;
|
||||
}
|
||||
|
||||
BitonicMerge<capacity, !greater, T, idxT>::merge(val_arr_, idx_arr_);
|
||||
BitonicMerge<capacity, greater, !greater, T, idxT, is_stable>::merge(
|
||||
val_arr_, idx_arr_);
|
||||
|
||||
set_k_th_();
|
||||
}
|
||||
|
||||
using WarpSort<capacity, greater, T, idxT>::max_arr_len_;
|
||||
using WarpSort<capacity, greater, T, idxT>::val_arr_;
|
||||
using WarpSort<capacity, greater, T, idxT>::idx_arr_;
|
||||
using WarpSort<capacity, greater, T, idxT>::lane_;
|
||||
using WarpSort<capacity, greater, T, idxT>::k_;
|
||||
using WarpSort<capacity, greater, T, idxT>::dummy_;
|
||||
using WarpSort<capacity, greater, T, idxT, is_stable>::max_arr_len_;
|
||||
using WarpSort<capacity, greater, T, idxT, is_stable>::val_arr_;
|
||||
using WarpSort<capacity, greater, T, idxT, is_stable>::idx_arr_;
|
||||
using WarpSort<capacity, greater, T, idxT, is_stable>::lane_;
|
||||
using WarpSort<capacity, greater, T, idxT, is_stable>::k_;
|
||||
using WarpSort<capacity, greater, T, idxT, is_stable>::dummy_;
|
||||
|
||||
T* val_smem_;
|
||||
idxT* idx_smem_;
|
||||
int smem_buf_len_ = 0;
|
||||
|
||||
T k_th_;
|
||||
idxT k_th_idx_;
|
||||
int const k_th_lane_;
|
||||
}; // end class WarpSelect
|
||||
} // namespace warp_topk
|
||||
@@ -313,8 +416,8 @@ __device__ void topk_with_k2(T* output,
|
||||
int32_t const lane_id,
|
||||
int const num_experts_per_group) {
|
||||
// Get the top2 per thread
|
||||
T largest = cuda::std::numeric_limits<T>::min();
|
||||
T second_largest = cuda::std::numeric_limits<T>::min();
|
||||
T largest = neg_inf<T>();
|
||||
T second_largest = neg_inf<T>();
|
||||
|
||||
if (num_experts_per_group > WARP_SIZE) {
|
||||
for (int i = lane_id; i < num_experts_per_group; i += WARP_SIZE) {
|
||||
@@ -368,8 +471,14 @@ __global__ void topk_with_k2_kernel(T* output,
|
||||
cg::thread_block block = cg::this_thread_block();
|
||||
cg::thread_block_tile<32> tile = cg::tiled_partition<32>(block);
|
||||
|
||||
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
|
||||
asm volatile("griddepcontrol.wait;");
|
||||
#endif
|
||||
topk_with_k2(output, input, tile, lane_id, num_experts_per_group);
|
||||
}
|
||||
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
|
||||
asm volatile("griddepcontrol.launch_dependents;");
|
||||
#endif
|
||||
}
|
||||
|
||||
template <typename T, typename IdxT>
|
||||
@@ -385,6 +494,7 @@ __global__ void group_idx_and_topk_idx_kernel(
|
||||
int64_t const topk,
|
||||
int64_t const num_experts,
|
||||
int64_t const num_experts_per_group,
|
||||
bool const renormalize,
|
||||
double routed_scaling_factor) {
|
||||
int32_t warp_id = threadIdx.x / WARP_SIZE;
|
||||
int32_t lane_id = threadIdx.x % WARP_SIZE;
|
||||
@@ -403,19 +513,29 @@ __global__ void group_idx_and_topk_idx_kernel(
|
||||
|
||||
extern __shared__ char smem_buf[]; // NOTE: reuse the shared memory here to
|
||||
// store the target topk idx
|
||||
int32_t* s_topk_idx = reinterpret_cast<int32_t*>(smem_buf) + warp_id * topk;
|
||||
int32_t* s_topk_idx = reinterpret_cast<int32_t*>(smem_buf);
|
||||
T* s_topk_value =
|
||||
reinterpret_cast<T*>(s_topk_idx + NUM_WARPS_PER_BLOCK * topk) +
|
||||
warp_id * topk;
|
||||
s_topk_idx += warp_id * topk;
|
||||
|
||||
T value = cuda::std::numeric_limits<T>::min();
|
||||
T topk_group_value = cuda::std::numeric_limits<T>::min();
|
||||
T value = neg_inf<T>();
|
||||
T topk_group_value = neg_inf<T>();
|
||||
int32_t num_equalto_topkth_group;
|
||||
|
||||
if ((n_group > topk_group) && (case_id < num_tokens)) {
|
||||
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
|
||||
asm volatile("griddepcontrol.wait;"); // I think all prolog can be put before
|
||||
// acqbulk because it's ptr arithmetic
|
||||
#endif
|
||||
|
||||
if (case_id < num_tokens) {
|
||||
// calculate group_idx
|
||||
int32_t target_num_min = WARP_SIZE - n_group + topk_group;
|
||||
if (lane_id < n_group) {
|
||||
if (lane_id < n_group &&
|
||||
(isfinite(cuda_cast<float, T>(
|
||||
group_scores[lane_id])))) // The check is necessary to avoid
|
||||
// abnormal input
|
||||
{
|
||||
value = group_scores[lane_id];
|
||||
}
|
||||
|
||||
@@ -426,22 +546,23 @@ __global__ void group_idx_and_topk_idx_kernel(
|
||||
__syncwarp(); // Ensure all threads have valid data before reduction
|
||||
topk_group_value = cg::reduce(tile, value, cg::greater<T>());
|
||||
if (value == topk_group_value) {
|
||||
value = cuda::std::numeric_limits<T>::min();
|
||||
value = neg_inf<T>();
|
||||
}
|
||||
pre_count_equal_to_top_value = count_equal_to_top_value;
|
||||
count_equal_to_top_value = __popc(__ballot_sync(
|
||||
FULL_WARP_MASK, (value == cuda::std::numeric_limits<T>::min())));
|
||||
FULL_WARP_MASK, (value == neg_inf<T>())));
|
||||
}
|
||||
num_equalto_topkth_group = target_num_min - pre_count_equal_to_top_value;
|
||||
}
|
||||
__syncthreads();
|
||||
|
||||
warp_topk::WarpSelect</*capability*/ WARP_SIZE, /*greater*/ true, T, int32_t>
|
||||
queue((int32_t)topk, cuda::std::numeric_limits<T>::min());
|
||||
warp_topk::WarpSelect</*capability*/ WARP_SIZE, /*greater*/ true, T, int32_t,
|
||||
/* is_stable */ true>
|
||||
queue((int32_t)topk, neg_inf<T>());
|
||||
|
||||
int count_equalto_topkth_group = 0;
|
||||
bool if_proceed_next_topk = (topk_group_value != cuda::std::numeric_limits<T>::min());
|
||||
if (case_id < num_tokens) {
|
||||
bool if_proceed_next_topk = (topk_group_value != neg_inf<T>());
|
||||
if (case_id < num_tokens && if_proceed_next_topk) {
|
||||
for (int i_group = 0; i_group < n_group; i_group++) {
|
||||
if ((group_scores[i_group] > topk_group_value) ||
|
||||
((group_scores[i_group] == topk_group_value) &&
|
||||
@@ -449,9 +570,11 @@ __global__ void group_idx_and_topk_idx_kernel(
|
||||
int32_t offset = i_group * num_experts_per_group;
|
||||
for (int32_t i = lane_id; i < align_num_experts_per_group;
|
||||
i += WARP_SIZE) {
|
||||
T candidates = i < num_experts_per_group
|
||||
? scores_with_bias[offset + i]
|
||||
: cuda::std::numeric_limits<T>::min();
|
||||
T candidates =
|
||||
(i < num_experts_per_group) && isfinite(cuda_cast<float, T>(
|
||||
scores_with_bias[offset + i]))
|
||||
? scores_with_bias[offset + i]
|
||||
: neg_inf<T>();
|
||||
queue.add(candidates, offset + i);
|
||||
}
|
||||
if (group_scores[i_group] == topk_group_value) {
|
||||
@@ -469,7 +592,7 @@ __global__ void group_idx_and_topk_idx_kernel(
|
||||
// Load the valid score value
|
||||
// Calculate the summation
|
||||
float topk_sum = 1e-20;
|
||||
if (case_id < num_tokens) {
|
||||
if (case_id < num_tokens && if_proceed_next_topk) {
|
||||
for (int i = lane_id;
|
||||
i < warp_topk::round_up_to_multiple_of<WARP_SIZE>(topk);
|
||||
i += WARP_SIZE) {
|
||||
@@ -478,33 +601,45 @@ __global__ void group_idx_and_topk_idx_kernel(
|
||||
if (i < topk) {
|
||||
s_topk_value[i] = value;
|
||||
}
|
||||
topk_sum += reduce(tile, value, cg::plus<float>());
|
||||
topk_sum += reduce(tile, cuda_cast<float, T>(value), cg::plus<float>());
|
||||
}
|
||||
}
|
||||
|
||||
__syncthreads();
|
||||
if (case_id < num_tokens) {
|
||||
|
||||
if (case_id < num_tokens && if_proceed_next_topk) {
|
||||
for (int i = lane_id; i < num_experts; i += WARP_SIZE) {
|
||||
scores[i] = 0;
|
||||
}
|
||||
}
|
||||
__threadfence();
|
||||
__syncthreads();
|
||||
__syncwarp();
|
||||
|
||||
if (case_id < num_tokens) {
|
||||
for (int i = lane_id; i < topk; i += WARP_SIZE) {
|
||||
float value = s_topk_value[i] / topk_sum * routed_scaling_factor;
|
||||
scores[s_topk_idx[i]] = value;
|
||||
if (if_proceed_next_topk) {
|
||||
if (if_proceed_next_topk) {
|
||||
for (int i = lane_id; i < topk; i += WARP_SIZE) {
|
||||
float value;
|
||||
if (renormalize) {
|
||||
value = cuda_cast<float, T>(s_topk_value[i]) / topk_sum *
|
||||
routed_scaling_factor;
|
||||
} else {
|
||||
value = cuda_cast<float, T>(s_topk_value[i]) * routed_scaling_factor;
|
||||
}
|
||||
scores[s_topk_idx[i]] = value;
|
||||
topk_indices[i] = s_topk_idx[i];
|
||||
topk_values[i] = static_cast<T>(value);
|
||||
topk_values[i] = cuda_cast<T, float>(value);
|
||||
}
|
||||
else {
|
||||
} else {
|
||||
for (int i = lane_id; i < topk; i += WARP_SIZE) {
|
||||
topk_indices[i] = i;
|
||||
topk_values[i] = static_cast<float>(1.0f / topk);
|
||||
topk_values[i] = cuda_cast<T, float>(1.0f / topk);
|
||||
}
|
||||
}
|
||||
// Note: when if_proceed_next_topk==false, choose the first 8 experts as the
|
||||
// default result.
|
||||
}
|
||||
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
|
||||
asm volatile("griddepcontrol.launch_dependents;");
|
||||
#endif
|
||||
}
|
||||
|
||||
template <typename T, typename IdxT>
|
||||
@@ -518,17 +653,24 @@ void invokeNoAuxTc(T* scores,
|
||||
int64_t const n_group,
|
||||
int64_t const topk_group,
|
||||
int64_t const topk,
|
||||
bool const renormalize,
|
||||
double const routed_scaling_factor,
|
||||
cudaStream_t const stream) {
|
||||
int64_t num_cases = num_tokens * n_group;
|
||||
int64_t topk_with_k2_num_blocks = (num_cases - 1) / NUM_WARPS_PER_BLOCK + 1;
|
||||
topk_with_k2_kernel<T><<<topk_with_k2_num_blocks, BLOCK_SIZE, 0, stream>>>(
|
||||
group_scores,
|
||||
scores_with_bias,
|
||||
num_tokens,
|
||||
num_cases,
|
||||
n_group,
|
||||
num_experts / n_group);
|
||||
auto* kernel_instance1 = &topk_with_k2_kernel<T>;
|
||||
cudaLaunchConfig_t config;
|
||||
config.gridDim = topk_with_k2_num_blocks;
|
||||
config.blockDim = BLOCK_SIZE;
|
||||
config.dynamicSmemBytes = 0;
|
||||
config.stream = stream;
|
||||
cudaLaunchAttribute attrs[1];
|
||||
attrs[0].id = cudaLaunchAttributeProgrammaticStreamSerialization;
|
||||
attrs[0].val.programmaticStreamSerializationAllowed = false;
|
||||
config.numAttrs = 1;
|
||||
config.attrs = attrs;
|
||||
cudaLaunchKernelEx(&config, kernel_instance1, group_scores, scores_with_bias,
|
||||
num_tokens, num_cases, n_group, num_experts / n_group);
|
||||
|
||||
int64_t topk_with_k_group_num_blocks =
|
||||
(num_tokens - 1) / NUM_WARPS_PER_BLOCK + 1;
|
||||
@@ -536,21 +678,19 @@ void invokeNoAuxTc(T* scores,
|
||||
warp_topk::calc_smem_size_for_block_wide<T, int32_t>(NUM_WARPS_PER_BLOCK,
|
||||
topk);
|
||||
|
||||
group_idx_and_topk_idx_kernel<T><<<topk_with_k_group_num_blocks,
|
||||
BLOCK_SIZE,
|
||||
dynamic_smem_in_bytes,
|
||||
stream>>>(scores,
|
||||
group_scores,
|
||||
topk_values,
|
||||
topk_indices,
|
||||
scores_with_bias,
|
||||
num_tokens,
|
||||
n_group,
|
||||
topk_group,
|
||||
topk,
|
||||
num_experts,
|
||||
num_experts / n_group,
|
||||
routed_scaling_factor);
|
||||
auto* kernel_instance2 = &group_idx_and_topk_idx_kernel<T, IdxT>;
|
||||
config.gridDim = topk_with_k_group_num_blocks;
|
||||
config.blockDim = BLOCK_SIZE;
|
||||
config.dynamicSmemBytes = dynamic_smem_in_bytes;
|
||||
config.stream = stream;
|
||||
attrs[0].id = cudaLaunchAttributeProgrammaticStreamSerialization;
|
||||
attrs[0].val.programmaticStreamSerializationAllowed = false;
|
||||
config.numAttrs = 1;
|
||||
config.attrs = attrs;
|
||||
cudaLaunchKernelEx(&config, kernel_instance2, scores, group_scores,
|
||||
topk_values, topk_indices, scores_with_bias, num_tokens,
|
||||
n_group, topk_group, topk, num_experts,
|
||||
num_experts / n_group, renormalize, routed_scaling_factor);
|
||||
}
|
||||
|
||||
#define INSTANTIATE_NOAUX_TC(T, IdxT) \
|
||||
@@ -564,6 +704,7 @@ void invokeNoAuxTc(T* scores,
|
||||
int64_t const n_group, \
|
||||
int64_t const topk_group, \
|
||||
int64_t const topk, \
|
||||
bool const renormalize, \
|
||||
double const routed_scaling_factor, \
|
||||
cudaStream_t const stream);
|
||||
|
||||
|
@@ -3,6 +3,158 @@
|
||||
|
||||
#include "quantization/common.cuh"
|
||||
|
||||
// adapted from: https://github.com/sgl-project/sglang/blob/v0.5.2rc2/sgl-kernel/csrc/gemm/per_token_quant_fp8.cu
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// 1. Warp‑local, no shared memory
|
||||
// • One warp handles one token.
|
||||
// • Eight tokens per 256‑thread CTA.
|
||||
// ---------------------------------------------------------------------------
|
||||
template <typename T, typename DST_DTYPE, int kTokensPerCTA = 8, int kVecSize = 16>
|
||||
__global__ void per_token_quant_fp8_kernel(
|
||||
const T* __restrict__ input,
|
||||
DST_DTYPE* __restrict__ output_q,
|
||||
float* __restrict__ output_s,
|
||||
const float scale_ub,
|
||||
const int64_t hidden_size,
|
||||
const int64_t num_tokens) {
|
||||
const int warp_id = threadIdx.x / WARP_SIZE; // 0‑7 (8 warps)
|
||||
const int lane_id = threadIdx.x & (WARP_SIZE - 1); // 0‑31
|
||||
const int token_id = blockIdx.x * kTokensPerCTA + warp_id;
|
||||
if (token_id >= num_tokens) return;
|
||||
|
||||
// Global tensors for this token
|
||||
const T* token_input = input + token_id * hidden_size;
|
||||
DST_DTYPE* token_output = output_q + token_id * hidden_size;
|
||||
float* token_scale = output_s + token_id;
|
||||
|
||||
//
|
||||
// Pass-1: Perform a warp reduce to find the max_value of a token's hidden_size
|
||||
//
|
||||
float max_value = 0.f;
|
||||
using vec_t = AlignedVector<T, kVecSize>;
|
||||
const int32_t num_vec_elems = hidden_size / kVecSize;
|
||||
|
||||
for (int32_t i = lane_id; i < num_vec_elems; i += WARP_SIZE) {
|
||||
vec_t input_vec;
|
||||
Load(token_input + i * kVecSize, &input_vec);
|
||||
|
||||
#pragma unroll
|
||||
for (uint32_t j = 0; j < kVecSize; ++j) {
|
||||
max_value = fmaxf(max_value, fabsf(static_cast<float>(input_vec[j])));
|
||||
}
|
||||
}
|
||||
|
||||
float warp_max = warpReduceMax(max_value);
|
||||
if (scale_ub > 0){
|
||||
warp_max = fminf(warp_max, scale_ub);
|
||||
}
|
||||
float scale;
|
||||
scale = warp_max / FP8_E4M3_MAX;
|
||||
// Broadcast scale
|
||||
if (lane_id == 0) {
|
||||
token_scale[0] = scale;
|
||||
}
|
||||
float scale_inv = (scale == 0.f) ? 0.f : 1.0f / scale;
|
||||
|
||||
//
|
||||
// Pass-2: quantize and write back
|
||||
//
|
||||
for (int i = lane_id; i < num_vec_elems; i += WARP_SIZE) {
|
||||
vec_t input_vec;
|
||||
Load(token_input + i * kVecSize, &input_vec);
|
||||
DST_DTYPE output_arr[kVecSize];
|
||||
#pragma unroll
|
||||
for (uint32_t j = 0; j < kVecSize; ++j) {
|
||||
float val = static_cast<float>(input_vec[j]) * scale_inv;
|
||||
val = fmaxf(fminf(val, FP8_E4M3_MAX), -FP8_E4M3_MAX);
|
||||
output_arr[j] = static_cast<DST_DTYPE>(val);
|
||||
}
|
||||
if constexpr (kVecSize == 16) {
|
||||
*(uint4*)(token_output + i * kVecSize) = *(uint4*)output_arr;
|
||||
} else {
|
||||
// Use element-wise copy for vector size 8 to ensure correctness
|
||||
for (int k = 0; k < kVecSize; ++k) {
|
||||
token_output[i * kVecSize + k] = output_arr[k];
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// 2. Baseline kernel (1 token / CTA, CUB block reduce)
|
||||
// ---------------------------------------------------------------------------
|
||||
template <typename T, typename DST_DTYPE, int kVecSize = 16>
|
||||
__global__ void per_token_quant_fp8_small_batch_kernel(
|
||||
const T* __restrict__ input,
|
||||
DST_DTYPE* __restrict__ output_q,
|
||||
float* __restrict__ output_s,
|
||||
const float scale_ub,
|
||||
const int64_t hidden_size,
|
||||
const int64_t num_tokens) {
|
||||
const int token_idx = blockIdx.x;
|
||||
if (token_idx >= num_tokens) return;
|
||||
|
||||
const int tid = threadIdx.x;
|
||||
const int block_dim = blockDim.x;
|
||||
|
||||
const T* token_input = input + token_idx * hidden_size;
|
||||
DST_DTYPE* token_output = output_q + token_idx * hidden_size;
|
||||
|
||||
float max_value = 0.0f;
|
||||
|
||||
// Use template parameter for vector size
|
||||
using vec_t = AlignedVector<T, kVecSize>;
|
||||
const int32_t num_vec_elems = hidden_size / kVecSize;
|
||||
|
||||
// Find max using vectorized loads
|
||||
for (int32_t i = tid; i < num_vec_elems; i += block_dim) {
|
||||
vec_t input_vec;
|
||||
Load(token_input + i * kVecSize, &input_vec);
|
||||
|
||||
#pragma unroll
|
||||
for (uint32_t j = 0; j < kVecSize; ++j) {
|
||||
float val = static_cast<float>(input_vec[j]);
|
||||
max_value = fmaxf(max_value, fabsf(val));
|
||||
}
|
||||
}
|
||||
|
||||
max_value = blockReduceMax(max_value);
|
||||
if (scale_ub > 0){
|
||||
max_value = fminf(max_value, scale_ub);
|
||||
}
|
||||
__shared__ float scale;
|
||||
if (tid == 0) {
|
||||
scale = max_value / FP8_E4M3_MAX;
|
||||
output_s[token_idx] = scale;
|
||||
}
|
||||
__syncthreads();
|
||||
|
||||
const float scale_inv = 1.0f / scale;
|
||||
|
||||
// Quantize using vectorized loads
|
||||
for (int32_t i = tid; i < num_vec_elems; i += block_dim) {
|
||||
vec_t input_vec;
|
||||
Load(token_input + i * kVecSize, &input_vec);
|
||||
|
||||
DST_DTYPE output_arr[kVecSize];
|
||||
#pragma unroll
|
||||
for (uint32_t j = 0; j < kVecSize; ++j) {
|
||||
float val = fmaxf(fminf(static_cast<float>(input_vec[j]) * scale_inv, FP8_E4M3_MAX), -FP8_E4M3_MAX);
|
||||
output_arr[j] = static_cast<DST_DTYPE>(val);
|
||||
}
|
||||
|
||||
if constexpr (kVecSize == 16) {
|
||||
*(uint4*)(token_output + i * kVecSize) = *(uint4*)output_arr;
|
||||
} else {
|
||||
// Use element-wise copy for vector size 8 to ensure correctness
|
||||
for (int k = 0; k < kVecSize; ++k) {
|
||||
token_output[i * kVecSize + k] = output_arr[k];
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
namespace fastdeploy {
|
||||
|
||||
template <typename scalar_t, typename fp8_type>
|
||||
@@ -179,39 +331,78 @@ void DynamicPerTokenScaledFp8Quant(paddle::Tensor &out, // [..., d]
|
||||
auto rank = input.dims().size();
|
||||
int const hidden_size = input.dims()[rank - 1];
|
||||
int const num_tokens = input.numel() / hidden_size;
|
||||
cudaStream_t stream = input.stream();
|
||||
|
||||
if (hidden_size % 8 == 0){
|
||||
int device = 0;
|
||||
cudaGetDevice(&device);
|
||||
int sm_count = 0;
|
||||
cudaDeviceGetAttribute(&sm_count, cudaDevAttrMultiProcessorCount, device);
|
||||
const int TOKENS_PER_CTA = 8;
|
||||
const bool use_warp_kernel = (num_tokens >= sm_count * 2 * TOKENS_PER_CTA);
|
||||
const bool use_vec16 = (hidden_size % 16 == 0);
|
||||
DISPATCH_FLOAT_FP6_DTYPE(input.dtype(), scalar_t, {
|
||||
if (use_warp_kernel) {
|
||||
// -------- warp‑local ---------------------------------------------------
|
||||
constexpr int THREADS = TOKENS_PER_CTA * WARP_SIZE; // 256
|
||||
dim3 grid((num_tokens + TOKENS_PER_CTA - 1) / TOKENS_PER_CTA);
|
||||
dim3 block(THREADS);
|
||||
|
||||
if (use_vec16) {
|
||||
per_token_quant_fp8_kernel<scalar_t, __nv_fp8_e4m3, TOKENS_PER_CTA, 16><<<grid, block, 0, stream>>>(
|
||||
reinterpret_cast<const scalar_t*>(input.data<scalar_t>()),
|
||||
reinterpret_cast<__nv_fp8_e4m3*>(out.data<fp8_t>()),
|
||||
reinterpret_cast<float*>(scales.data<float>()),
|
||||
scale_ub,
|
||||
hidden_size,
|
||||
num_tokens);
|
||||
} else {
|
||||
per_token_quant_fp8_kernel<scalar_t, __nv_fp8_e4m3, TOKENS_PER_CTA, 8><<<grid, block, 0, stream>>>(
|
||||
reinterpret_cast<const scalar_t*>(input.data<scalar_t>()),
|
||||
reinterpret_cast<__nv_fp8_e4m3*>(out.data<fp8_t>()),
|
||||
reinterpret_cast<float*>(scales.data<float>()),
|
||||
scale_ub,
|
||||
hidden_size,
|
||||
num_tokens);
|
||||
}
|
||||
} else {
|
||||
// -------- baseline -----------------------------------------------------
|
||||
constexpr int THREADS = 256;
|
||||
dim3 grid(num_tokens);
|
||||
dim3 block(THREADS);
|
||||
|
||||
if (use_vec16) {
|
||||
per_token_quant_fp8_small_batch_kernel<scalar_t, __nv_fp8_e4m3, 16><<<grid, block, 0, stream>>>(
|
||||
reinterpret_cast<const scalar_t*>(input.data<scalar_t>()),
|
||||
reinterpret_cast<__nv_fp8_e4m3*>(out.data<fp8_t>()),
|
||||
reinterpret_cast<float*>(scales.data<float>()),
|
||||
scale_ub,
|
||||
hidden_size,
|
||||
num_tokens);
|
||||
} else {
|
||||
per_token_quant_fp8_small_batch_kernel<scalar_t, __nv_fp8_e4m3, 8><<<grid, block, 0, stream>>>(
|
||||
reinterpret_cast<const scalar_t*>(input.data<scalar_t>()),
|
||||
reinterpret_cast<__nv_fp8_e4m3*>(out.data<fp8_t>()),
|
||||
reinterpret_cast<float*>(scales.data<float>()),
|
||||
scale_ub,
|
||||
hidden_size,
|
||||
num_tokens);
|
||||
}
|
||||
}
|
||||
});
|
||||
return;
|
||||
}
|
||||
|
||||
dim3 const grid(num_tokens);
|
||||
dim3 const block(std::min(hidden_size, 1024));
|
||||
|
||||
cudaStream_t stream = input.stream();
|
||||
DISPATCH_FLOAT_FP6_DTYPE(input.dtype(), scalar_t, {
|
||||
fastdeploy::dynamic_per_token_scaled_fp8_quant_kernel<scalar_t, fp8_t>
|
||||
<<<grid, block, 0, stream>>>(out.data<fp8_t>(), scales.data<float>(),
|
||||
input.data<scalar_t>(), scale_ub,
|
||||
hidden_size);
|
||||
});
|
||||
|
||||
switch (input.dtype()) {
|
||||
case paddle::DataType::FLOAT32: {
|
||||
using scalar_t = float;
|
||||
fastdeploy::dynamic_per_token_scaled_fp8_quant_kernel<scalar_t, fp8_t>
|
||||
<<<grid, block, 0, stream>>>(out.data<fp8_t>(), scales.data<float>(),
|
||||
input.data<scalar_t>(), scale_ub,
|
||||
hidden_size);
|
||||
break;
|
||||
}
|
||||
case paddle::DataType::FLOAT16: {
|
||||
using scalar_t = phi::dtype::float16;
|
||||
fastdeploy::dynamic_per_token_scaled_fp8_quant_kernel<scalar_t, fp8_t>
|
||||
<<<grid, block, 0, stream>>>(out.data<fp8_t>(), scales.data<float>(),
|
||||
input.data<scalar_t>(), scale_ub,
|
||||
hidden_size);
|
||||
break;
|
||||
}
|
||||
case paddle::DataType::BFLOAT16: {
|
||||
using scalar_t = phi::dtype::bfloat16;
|
||||
fastdeploy::dynamic_per_token_scaled_fp8_quant_kernel<scalar_t, fp8_t>
|
||||
<<<grid, block, 0, stream>>>(out.data<fp8_t>(), scales.data<float>(),
|
||||
input.data<scalar_t>(), scale_ub,
|
||||
hidden_size);
|
||||
break;
|
||||
}
|
||||
default:
|
||||
PD_THROW("Only supported attr of input type in [fp32, fp16, bf16].");
|
||||
}
|
||||
}
|
||||
|
||||
PD_BUILD_STATIC_OP(static_scaled_fp8_quant)
|
||||
|
@@ -59,7 +59,7 @@ __global__ void text_image_scatter_kernel(
|
||||
constexpr int HalfVecSize = VecSize / 2;
|
||||
using T_Vec = AlignedVector<T, VecSize>;
|
||||
T_Vec input_ptr_vec;
|
||||
T_Vec text_imgaes_vec;
|
||||
T_Vec text_images_vec;
|
||||
|
||||
int64_t global_thread_id = blockIdx.x * blockDim.x + threadIdx.x;
|
||||
const int64_t step = blockDim.x * gridDim.x * VecSize;
|
||||
@@ -76,16 +76,20 @@ __global__ void text_image_scatter_kernel(
|
||||
Load<T, VecSize>(input_ptr + input_load_offset, &input_ptr_vec);
|
||||
#pragma unroll
|
||||
for(int vi = 0; vi < VecSize; ++vi) {
|
||||
text_imgaes_vec[vi] = input_ptr_vec[vi];
|
||||
text_images_vec[vi] = input_ptr_vec[vi];
|
||||
}
|
||||
|
||||
if (token_type_ids_num == 0) {
|
||||
int64_t text_load_offset = text_index[token_idx] * hidden_size + hidden_offset;
|
||||
Store<T,VecSize>(text_imgaes_vec, text_gather_ptr + text_load_offset);
|
||||
Store<T,VecSize>(text_images_vec, text_gather_ptr + text_load_offset);
|
||||
|
||||
} else if(token_type_ids_num == 1){
|
||||
int64_t image_load_offset = image_index[token_idx] * hidden_size + hidden_offset;
|
||||
Store<T,VecSize>(text_images_vec, image_gather_ptr + image_load_offset);
|
||||
|
||||
} else {
|
||||
int64_t image_load_offset = image_index[token_idx] * hidden_size + hidden_offset;
|
||||
Store<T,VecSize>(text_imgaes_vec, image_gather_ptr + image_load_offset);
|
||||
// skip cuda graph padding value
|
||||
continue;
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -120,9 +124,12 @@ __global__ void text_image_gather_kernel(
|
||||
int64_t text_load_offset = text_index[token_idx] * hidden_size + hidden_offset;
|
||||
Load<T,VecSize>(text_gather_ptr + text_load_offset, &text_imgaes_vec);
|
||||
|
||||
} else {
|
||||
} else if (token_type_ids_num == 1){
|
||||
int64_t image_load_offset = image_index[token_idx] * hidden_size + hidden_offset;
|
||||
Load<T,VecSize>(image_gather_ptr + image_load_offset, &text_imgaes_vec);
|
||||
} else {
|
||||
// skip cuda graph padding value
|
||||
continue;
|
||||
}
|
||||
|
||||
#pragma unroll
|
||||
@@ -154,7 +161,6 @@ void LaunchTextImageGatherScatter(
|
||||
const int64_t token_num = in_dims[0];
|
||||
const int64_t hidden_size = in_dims[1];
|
||||
|
||||
|
||||
const int VecSize = 16 / sizeof(data_t);
|
||||
const int64_t tot_element_num = token_num * hidden_size;
|
||||
|
||||
@@ -168,7 +174,7 @@ void LaunchTextImageGatherScatter(
|
||||
PADDLE_ENFORCE_GPU_SUCCESS(GetGridSize(tot_pack_num, block_size, kNumWaves, &grid_size_x));
|
||||
dim3 grid_dim = dim3(grid_size_x, 1, 1);
|
||||
if (is_scatter) {
|
||||
text_image_scatter_kernel<DataType_, 8><<<grid_dim, block_size>>>(
|
||||
text_image_scatter_kernel<DataType_, VecSize><<<grid_dim, block_size, 0, stream>>>(
|
||||
reinterpret_cast<DataType_*>(input.data<data_t>()),
|
||||
reinterpret_cast<DataType_*>(text_input.data<data_t>()),
|
||||
reinterpret_cast<DataType_*>(image_input.data<data_t>()),
|
||||
@@ -179,7 +185,7 @@ void LaunchTextImageGatherScatter(
|
||||
tot_element_num
|
||||
);
|
||||
} else {
|
||||
text_image_gather_kernel<DataType_, 8><<<grid_dim, block_size>>>(
|
||||
text_image_gather_kernel<DataType_, VecSize><<<grid_dim, block_size, 0, stream>>>(
|
||||
reinterpret_cast<DataType_*>(input.data<data_t>()),
|
||||
reinterpret_cast<DataType_*>(text_input.data<data_t>()),
|
||||
reinterpret_cast<DataType_*>(image_input.data<data_t>()),
|
||||
|
@@ -16,7 +16,7 @@
|
||||
|
||||
template <int VecSize>
|
||||
__global__ void text_image_index_out_kernel(
|
||||
int32_t* token_type_ids,
|
||||
const int32_t* token_type_ids,
|
||||
int32_t* text_index,
|
||||
int32_t* image_index,
|
||||
const int64_t token_num
|
||||
@@ -31,23 +31,27 @@ __global__ void text_image_index_out_kernel(
|
||||
if (token_type_ids[i] == 0) {
|
||||
text_index[i] = text_count;
|
||||
text_count += 1;
|
||||
} else {
|
||||
} else if (token_type_ids[i] == 1) {
|
||||
image_index[i] = images_count;
|
||||
images_count += 1;
|
||||
} else {
|
||||
// skip cuda graph padding value
|
||||
continue;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void TextImageIndexOut(
|
||||
const paddle::Tensor& token_type_ids,
|
||||
const paddle::Tensor& text_index,
|
||||
const paddle::Tensor& image_index) {
|
||||
paddle::Tensor& text_index,
|
||||
paddle::Tensor& image_index) {
|
||||
|
||||
const int64_t token_num = token_type_ids.shape()[0];
|
||||
text_image_index_out_kernel<1><<<1, 1>>>(
|
||||
const_cast<int32_t*>(token_type_ids.data<int32_t>()),
|
||||
const_cast<int32_t*>(text_index.data<int32_t>()),
|
||||
const_cast<int32_t*>(image_index.data<int32_t>()),
|
||||
auto stream = token_type_ids.stream();
|
||||
text_image_index_out_kernel<1><<<1, 1, 0, stream>>>(
|
||||
token_type_ids.data<int32_t>(),
|
||||
text_index.data<int32_t>(),
|
||||
image_index.data<int32_t>(),
|
||||
token_num
|
||||
);
|
||||
}
|
||||
|
71
custom_ops/gpu_ops/unset_data_ipc.cu
Normal file
71
custom_ops/gpu_ops/unset_data_ipc.cu
Normal file
@@ -0,0 +1,71 @@
|
||||
// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#include "helper.h"
|
||||
#include "cuda_multiprocess.h"
|
||||
|
||||
#if !defined(_WIN32)
|
||||
#include <errno.h>
|
||||
#include <string.h>
|
||||
#include <fcntl.h>
|
||||
#include <sys/mman.h>
|
||||
#include <sys/stat.h>
|
||||
#endif
|
||||
|
||||
// 可选:仅删除/解除共享内存命名对象(不依赖之前保存的 addr/fd)
|
||||
static inline int sharedMemoryUnlinkByName(const char* name) {
|
||||
#if defined(WIN32) || defined(_WIN32) || defined(WIN64) || defined(_WIN64)
|
||||
// Windows 上没有 shm_unlink 语义。命名对象在最后一个句柄关闭后消失。
|
||||
// 这里做“尽力而为”:尝试打开后立即关闭,减少一次引用。
|
||||
HANDLE hMap = OpenFileMappingA(FILE_MAP_ALL_ACCESS, FALSE, name);
|
||||
if (hMap) {
|
||||
CloseHandle(hMap);
|
||||
return 0;
|
||||
}
|
||||
// 已经不存在也算成功
|
||||
return 0;
|
||||
#else
|
||||
// POSIX: 移除名字,未来不可再 open;已映射区仍存活直至 munmap
|
||||
if (shm_unlink(name) != 0) {
|
||||
if (errno == ENOENT) return 0; // 不存在视作成功
|
||||
return errno;
|
||||
}
|
||||
return 0;
|
||||
#endif
|
||||
}
|
||||
|
||||
void UnsetDataIpc(const paddle::Tensor& tmp_input,
|
||||
const std::string& shm_name,
|
||||
bool close_ipc,
|
||||
bool unlink_shm) {
|
||||
// 1) 关闭消费者导入的 IPC 映射(仅当 close_ipc=true 且该指针确为 OpenMemHandle 得来)
|
||||
if (close_ipc) {
|
||||
void* ptr = const_cast<void*>(tmp_input.data());
|
||||
checkCudaErrors(cudaIpcCloseMemHandle(ptr));
|
||||
}
|
||||
|
||||
// 2) 解除共享内存命名对象(仅处理“名字”,不保证解除旧映射)
|
||||
if (unlink_shm) {
|
||||
int rc = sharedMemoryUnlinkByName(shm_name.c_str());
|
||||
if (rc != 0) {
|
||||
PD_THROW("Unlink shared memory failed: name=%s, err=%d",
|
||||
shm_name.c_str(), rc);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
PD_BUILD_STATIC_OP(unset_data_ipc)
|
||||
.Inputs({"tmp_input"})
|
||||
.Attrs({"shm_name: std::string", "close_ipc: bool", "unlink_shm: bool"})
|
||||
.SetKernelFn(PD_KERNEL(UnsetDataIpc));
|
@@ -32,7 +32,8 @@ __global__ void update_inputs_kernel_v1(bool *not_need_stop,
|
||||
const int max_bsz,
|
||||
const int input_ids_stride,
|
||||
const int block_num_per_seq,
|
||||
const int block_size) {
|
||||
const int block_size,
|
||||
bool prefill_one_step_stop) {
|
||||
int thread_idx = threadIdx.x;
|
||||
typedef cub::BlockReduce<int64_t, THREADBLOCK_SIZE> BlockReduce;
|
||||
__shared__ typename BlockReduce::TempStorage temp_storage;
|
||||
@@ -54,23 +55,32 @@ __global__ void update_inputs_kernel_v1(bool *not_need_stop,
|
||||
seq_lens_encoder[thread_idx] = 0;
|
||||
} else {
|
||||
if (seq_lens_this_time[thread_idx] + seq_lens_decoder[thread_idx] >= prompt_lens[thread_idx]) {
|
||||
// decoding
|
||||
seq_lens_decoder[thread_idx] += seq_lens_this_time[thread_idx];
|
||||
seq_lens_this_time[thread_idx] = 1;
|
||||
seq_lens_encoder[thread_idx] = 0;
|
||||
int64_t *input_ids_now = input_ids + thread_idx * input_ids_stride;
|
||||
input_ids_now[0] = next_tokens[thread_idx];
|
||||
if (prefill_one_step_stop) {
|
||||
// prefill done, stop
|
||||
stop_flags[thread_idx] = true;
|
||||
seq_lens_this_time[thread_idx] = 0;
|
||||
seq_lens_decoder[thread_idx] = 0;
|
||||
seq_lens_encoder[thread_idx] = 0;
|
||||
stop_flag_now_int = 1;
|
||||
} else{
|
||||
// decoding
|
||||
seq_lens_decoder[thread_idx] += seq_lens_this_time[thread_idx];
|
||||
seq_lens_this_time[thread_idx] = 1;
|
||||
seq_lens_encoder[thread_idx] = 0;
|
||||
int64_t *input_ids_now = input_ids + thread_idx * input_ids_stride;
|
||||
input_ids_now[0] = next_tokens[thread_idx];
|
||||
|
||||
// to judge whether block is not enough
|
||||
int *block_table_now = block_tables + thread_idx * block_num_per_seq;
|
||||
if (seq_lens_this_time[thread_idx] != 0 && block_table_now[seq_lens_decoder[thread_idx] / block_size] == -1) {
|
||||
// should be scheduled by server
|
||||
is_block_step[thread_idx] = true;
|
||||
seq_lens_this_time[thread_idx]= 0;
|
||||
stop_flags[thread_idx] = true;
|
||||
step_seq_lens_decoder[thread_idx] = seq_lens_decoder[thread_idx];
|
||||
seq_lens_decoder[thread_idx] = 0;
|
||||
stop_flag_now_int = 1;
|
||||
// to judge whether block is not enough
|
||||
int *block_table_now = block_tables + thread_idx * block_num_per_seq;
|
||||
if (seq_lens_this_time[thread_idx] != 0 && block_table_now[seq_lens_decoder[thread_idx] / block_size] == -1) {
|
||||
// should be scheduled by server
|
||||
is_block_step[thread_idx] = true;
|
||||
seq_lens_this_time[thread_idx]= 0;
|
||||
stop_flags[thread_idx] = true;
|
||||
step_seq_lens_decoder[thread_idx] = seq_lens_decoder[thread_idx];
|
||||
seq_lens_decoder[thread_idx] = 0;
|
||||
stop_flag_now_int = 1;
|
||||
}
|
||||
}
|
||||
} else
|
||||
{
|
||||
@@ -110,6 +120,12 @@ void UpdateInputesV1(const paddle::Tensor &stop_flags,
|
||||
#else
|
||||
auto cu_stream = input_ids.stream();
|
||||
#endif
|
||||
bool prefill_one_step_stop = false;
|
||||
if (const char *env_p = std::getenv("PREFILL_NODE_ONE_STEP_STOP_V1")) {
|
||||
if (env_p[0] == '1') {
|
||||
prefill_one_step_stop = true;
|
||||
}
|
||||
}
|
||||
const int max_bsz = stop_flags.shape()[0];
|
||||
const int now_bsz = seq_lens_this_time.shape()[0];
|
||||
const int input_ids_stride = input_ids.shape()[1];
|
||||
@@ -133,7 +149,8 @@ void UpdateInputesV1(const paddle::Tensor &stop_flags,
|
||||
max_bsz,
|
||||
input_ids_stride,
|
||||
block_num_per_seq,
|
||||
block_size);
|
||||
block_size,
|
||||
prefill_one_step_stop);
|
||||
auto not_need_stop_cpu =
|
||||
not_need_stop_gpu.copy_to(not_need_stop.place(), false);
|
||||
bool *not_need_stop_data = const_cast<bool *>(not_need_stop.data<bool>());
|
||||
|
376
custom_ops/iluvatar_ops/mixed_fused_attn.cu
Normal file
376
custom_ops/iluvatar_ops/mixed_fused_attn.cu
Normal file
@@ -0,0 +1,376 @@
|
||||
// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#include "helper.h"
|
||||
#include "iluvatar_context.h"
|
||||
|
||||
template <paddle::DataType T>
|
||||
void MixedFusedPagedAttnKernel(const paddle::Tensor& qkv,
|
||||
paddle::Tensor& k_cache,
|
||||
paddle::Tensor& v_cache,
|
||||
const paddle::Tensor& prefill_block_table,
|
||||
const paddle::Tensor& decode_block_table,
|
||||
const paddle::Tensor& cu_seqlens_qkv,
|
||||
const paddle::Tensor& seq_lens,
|
||||
const paddle::optional<paddle::Tensor> &rope_sin,
|
||||
const paddle::optional<paddle::Tensor> &rope_cos,
|
||||
int prefill_num_tokens,
|
||||
int num_heads,
|
||||
int head_dim,
|
||||
int num_kv_heads,
|
||||
int block_size,
|
||||
int max_seq_len,
|
||||
float scale,
|
||||
bool causal,
|
||||
bool q_rope,
|
||||
bool k_rope,
|
||||
bool v_rope,
|
||||
int window_left,
|
||||
int window_right,
|
||||
float softcap,
|
||||
bool enable_cuda_graph,
|
||||
bool use_sqrt_alibi,
|
||||
paddle::Tensor& out) {
|
||||
|
||||
typedef PDTraits<T> traits_;
|
||||
typedef typename traits_::data_t data_t;
|
||||
|
||||
const auto& dtype = qkv.dtype();
|
||||
cuinferDataType_t cuinfer_data_type;
|
||||
cudaDataType_t cu_data_type;
|
||||
if (dtype == paddle::DataType::FLOAT16) {
|
||||
cuinfer_data_type = CUINFER_DATA_HALF;
|
||||
cu_data_type = CUDA_R_16F;
|
||||
} else {
|
||||
cuinfer_data_type = CUINFER_DATA_BFLOAT16;
|
||||
cu_data_type = CUDA_R_16BF;
|
||||
}
|
||||
|
||||
const auto& qkv_dims = qkv.dims();
|
||||
const auto& kv_cache_dims = k_cache.dims();
|
||||
const auto& prefill_block_table_dims = prefill_block_table.dims();
|
||||
const auto& cu_seqlens_qkv_dims = cu_seqlens_qkv.dims();
|
||||
|
||||
int prefill_batch_size = prefill_block_table_dims[0];
|
||||
int num_tokens = qkv_dims[0];
|
||||
int decode_num_tokens = num_tokens - prefill_num_tokens;
|
||||
int num_total_heads = num_heads + 2 * num_kv_heads;
|
||||
int max_num_blocks_per_seq = prefill_block_table_dims[1];
|
||||
int qkv_stride = qkv.strides()[0];
|
||||
int num_blocks = kv_cache_dims[0];
|
||||
|
||||
int kv_block_stride = k_cache.strides()[0];
|
||||
int kv_head_stride = k_cache.strides()[1];
|
||||
int block_table_stride = prefill_block_table.strides()[0];
|
||||
const float *rope_sin_ptr = rope_sin ? rope_sin.get().data<float>() : nullptr;
|
||||
const float *rope_cos_ptr = rope_cos ? rope_cos.get().data<float>() : nullptr;
|
||||
|
||||
cuinferTensorDescriptor_t qkv_desc;
|
||||
CUINFER_CHECK(cuinferCreateTensorDescriptor(&qkv_desc));
|
||||
CUINFER_CHECK(cuinferSetTensorNdDescriptor(
|
||||
qkv_desc,
|
||||
cuinfer_data_type,
|
||||
3,
|
||||
std::vector<int>({prefill_num_tokens, num_total_heads, head_dim}).data(),
|
||||
std::vector<int>({num_total_heads * head_dim, head_dim, 1}).data()));
|
||||
|
||||
cuinferTensorDescriptor_t qkv_seqlens_desc;
|
||||
CUINFER_CHECK(cuinferCreateTensorDescriptor(&qkv_seqlens_desc));
|
||||
CUINFER_CHECK(cuinferSetTensorNdDescriptor(
|
||||
qkv_seqlens_desc,
|
||||
CUINFER_DATA_INT32,
|
||||
1,
|
||||
std::vector<int>({prefill_batch_size + 1}).data(),
|
||||
std::vector<int>({1}).data()));
|
||||
|
||||
cuinferTensorDescriptor_t block_table_desc;
|
||||
CUINFER_CHECK(cuinferCreateTensorDescriptor(&block_table_desc));
|
||||
CUINFER_CHECK(cuinferSetTensorNdDescriptor(
|
||||
block_table_desc,
|
||||
CUINFER_DATA_INT32,
|
||||
2,
|
||||
std::vector<int>({prefill_batch_size, block_table_stride}).data(),
|
||||
std::vector<int>({block_table_stride, 1}).data()));
|
||||
|
||||
cuinferTensorDescriptor_t o_desc;
|
||||
CUINFER_CHECK(cuinferCreateTensorDescriptor(&o_desc));
|
||||
CUINFER_CHECK(cuinferSetTensorNdDescriptor(
|
||||
o_desc,
|
||||
cuinfer_data_type,
|
||||
3,
|
||||
std::vector<int>({prefill_num_tokens, num_heads, head_dim}).data(),
|
||||
std::vector<int>({num_heads * head_dim, head_dim, 1}).data()));
|
||||
|
||||
cuinferTensorDescriptor_t k_cache_desc;
|
||||
CUINFER_CHECK(cuinferCreateTensorDescriptor(&k_cache_desc));
|
||||
CUINFER_CHECK(cuinferSetTensorNdDescriptor(
|
||||
k_cache_desc,
|
||||
cuinfer_data_type,
|
||||
4,
|
||||
std::vector<int>({num_blocks, num_kv_heads, block_size, head_dim}).data(),
|
||||
std::vector<int>({num_kv_heads * block_size * head_dim, block_size * head_dim, head_dim, 1}).data()));
|
||||
|
||||
cuinferTensorDescriptor_t v_cache_desc;
|
||||
CUINFER_CHECK(cuinferCreateTensorDescriptor(&v_cache_desc));
|
||||
CUINFER_CHECK(cuinferSetTensorNdDescriptor(
|
||||
v_cache_desc,
|
||||
cuinfer_data_type,
|
||||
4,
|
||||
std::vector<int>({num_blocks, num_kv_heads, block_size, head_dim}).data(),
|
||||
std::vector<int>({num_kv_heads * block_size * head_dim, block_size * head_dim, head_dim, 1}).data()));
|
||||
|
||||
cuinferTensorDescriptor_t cos_desc;
|
||||
CUINFER_CHECK(cuinferCreateTensorDescriptor(&cos_desc));
|
||||
CUINFER_CHECK(cuinferSetTensorNdDescriptor(
|
||||
cos_desc,
|
||||
CUINFER_DATA_FLOAT,
|
||||
2,
|
||||
std::vector<int>({max_seq_len, head_dim}).data(),
|
||||
std::vector<int>({head_dim, 1}).data()));
|
||||
|
||||
cuinferTensorDescriptor_t sin_desc;
|
||||
CUINFER_CHECK(cuinferCreateTensorDescriptor(&sin_desc));
|
||||
CUINFER_CHECK(cuinferSetTensorNdDescriptor(
|
||||
sin_desc,
|
||||
CUINFER_DATA_FLOAT,
|
||||
2,
|
||||
std::vector<int>({max_seq_len, head_dim}).data(),
|
||||
std::vector<int>({head_dim, 1}).data()));
|
||||
|
||||
cuinferHandle_t cuinfer_handle = iluvatar::getContextInstance()->getIxInferHandle();
|
||||
|
||||
size_t prefill_workspace_size = 0;
|
||||
CUINFER_CHECK(cuinferGetFmhaFwdMergedFuseRopeWorkspaceSize(prefill_num_tokens,
|
||||
num_heads,
|
||||
num_kv_heads,
|
||||
head_dim,
|
||||
q_rope,
|
||||
k_rope,
|
||||
v_rope,
|
||||
cuinfer_data_type,
|
||||
cuinfer_data_type,
|
||||
cuinfer_data_type,
|
||||
&prefill_workspace_size));
|
||||
|
||||
auto* allocator = paddle::GetAllocator(qkv.place());
|
||||
|
||||
phi::Allocator::AllocationPtr prefill_tmp_workspace = allocator->Allocate(prefill_workspace_size);
|
||||
void* prefill_workspace_ptr = prefill_tmp_workspace->ptr();
|
||||
|
||||
CUINFER_CHECK(cuinferFmhaFwdMergedFuseRopeFunc(cuinfer_handle,
|
||||
qkv_desc,
|
||||
qkv.data(),
|
||||
qkv_seqlens_desc,
|
||||
cu_seqlens_qkv.data<int32_t>(),
|
||||
block_table_desc,
|
||||
prefill_block_table.data<int32_t>(),
|
||||
o_desc,
|
||||
out.data(),
|
||||
k_cache_desc,
|
||||
k_cache.data(),
|
||||
v_cache_desc,
|
||||
v_cache.data(),
|
||||
prefill_workspace_ptr,
|
||||
prefill_workspace_size,
|
||||
cos_desc,
|
||||
rope_cos_ptr,
|
||||
sin_desc,
|
||||
rope_sin_ptr,
|
||||
prefill_batch_size,
|
||||
num_heads,
|
||||
num_kv_heads,
|
||||
head_dim,
|
||||
causal,
|
||||
scale,
|
||||
q_rope,
|
||||
k_rope,
|
||||
v_rope));
|
||||
|
||||
size_t decode_workspace_size = 0;
|
||||
CUINFER_CHECK(cuInferPageAttentionGetWorkspaceV7(decode_num_tokens,
|
||||
num_heads,
|
||||
num_kv_heads,
|
||||
head_dim,
|
||||
block_size,
|
||||
max_seq_len,
|
||||
&decode_workspace_size));
|
||||
|
||||
phi::Allocator::AllocationPtr decode_tmp_workspace = allocator->Allocate(decode_workspace_size);
|
||||
void* decode_workspace_ptr = decode_tmp_workspace->ptr();
|
||||
|
||||
void* decode_qkv_ptr = (void*)(qkv.data<data_t>() + prefill_num_tokens * qkv_stride);
|
||||
void* decode_out_ptr = (void*)(out.data<data_t>() + prefill_num_tokens * out.strides()[0]);
|
||||
|
||||
PageAttentionWithKVCacheArguments args{
|
||||
static_cast<float>(scale), 1.0, 1.0, static_cast<float>(softcap), window_left, window_right,
|
||||
causal, use_sqrt_alibi, enable_cuda_graph, false, nullptr, decode_qkv_ptr, decode_qkv_ptr,
|
||||
decode_workspace_ptr, true, rope_sin_ptr, rope_cos_ptr};
|
||||
|
||||
CUINFER_CHECK(cuInferPageAttentionV7(cuinfer_handle,
|
||||
decode_out_ptr,
|
||||
cu_data_type,
|
||||
decode_qkv_ptr,
|
||||
cu_data_type,
|
||||
decode_num_tokens,
|
||||
num_heads,
|
||||
num_kv_heads,
|
||||
head_dim,
|
||||
qkv_stride,
|
||||
kv_block_stride,
|
||||
kv_head_stride,
|
||||
k_cache.data(),
|
||||
cu_data_type,
|
||||
v_cache.data(),
|
||||
cu_data_type,
|
||||
block_size,
|
||||
max_num_blocks_per_seq,
|
||||
max_seq_len,
|
||||
decode_block_table.data<int32_t>(),
|
||||
seq_lens.data<int32_t>(),
|
||||
args));
|
||||
|
||||
CUINFER_CHECK(cuinferDestroyTensorDescriptor(qkv_desc));
|
||||
CUINFER_CHECK(cuinferDestroyTensorDescriptor(qkv_seqlens_desc));
|
||||
CUINFER_CHECK(cuinferDestroyTensorDescriptor(block_table_desc));
|
||||
CUINFER_CHECK(cuinferDestroyTensorDescriptor(o_desc));
|
||||
CUINFER_CHECK(cuinferDestroyTensorDescriptor(k_cache_desc));
|
||||
CUINFER_CHECK(cuinferDestroyTensorDescriptor(v_cache_desc));
|
||||
CUINFER_CHECK(cuinferDestroyTensorDescriptor(cos_desc));
|
||||
CUINFER_CHECK(cuinferDestroyTensorDescriptor(sin_desc));
|
||||
}
|
||||
|
||||
std::vector<paddle::Tensor> MixedFusedPagedAttn(const paddle::Tensor& qkv,
|
||||
paddle::Tensor& k_cache,
|
||||
paddle::Tensor& v_cache,
|
||||
const paddle::Tensor& prefill_block_table,
|
||||
const paddle::Tensor& decode_block_table,
|
||||
const paddle::Tensor& cu_seqlens_qkv,
|
||||
const paddle::Tensor& seq_lens,
|
||||
const paddle::optional<paddle::Tensor> &rope_sin,
|
||||
const paddle::optional<paddle::Tensor> &rope_cos,
|
||||
int prefill_num_tokens,
|
||||
int num_heads,
|
||||
int head_dim,
|
||||
int num_kv_heads,
|
||||
int block_size,
|
||||
int max_seq_len,
|
||||
float scale,
|
||||
bool causal,
|
||||
bool q_rope,
|
||||
bool k_rope,
|
||||
bool v_rope,
|
||||
int window_left,
|
||||
int window_right,
|
||||
float softcap,
|
||||
bool enable_cuda_graph,
|
||||
bool use_sqrt_alibi) {
|
||||
const auto dtype = qkv.dtype();
|
||||
auto out = paddle::empty({qkv.shape()[0], num_heads * head_dim}, dtype, qkv.place());
|
||||
|
||||
switch (dtype) {
|
||||
case paddle::DataType::BFLOAT16:
|
||||
MixedFusedPagedAttnKernel<paddle::DataType::BFLOAT16>(qkv,
|
||||
k_cache,
|
||||
v_cache,
|
||||
prefill_block_table,
|
||||
decode_block_table,
|
||||
cu_seqlens_qkv,
|
||||
seq_lens,
|
||||
rope_sin,
|
||||
rope_cos,
|
||||
prefill_num_tokens,
|
||||
num_heads,
|
||||
head_dim,
|
||||
num_kv_heads,
|
||||
block_size,
|
||||
max_seq_len,
|
||||
scale,
|
||||
causal,
|
||||
q_rope,
|
||||
k_rope,
|
||||
v_rope,
|
||||
window_left,
|
||||
window_right,
|
||||
softcap,
|
||||
enable_cuda_graph,
|
||||
use_sqrt_alibi,
|
||||
out);
|
||||
break;
|
||||
case paddle::DataType::FLOAT16:
|
||||
MixedFusedPagedAttnKernel<paddle::DataType::FLOAT16>(qkv,
|
||||
k_cache,
|
||||
v_cache,
|
||||
prefill_block_table,
|
||||
decode_block_table,
|
||||
cu_seqlens_qkv,
|
||||
seq_lens,
|
||||
rope_sin,
|
||||
rope_cos,
|
||||
prefill_num_tokens,
|
||||
num_heads,
|
||||
head_dim,
|
||||
num_kv_heads,
|
||||
block_size,
|
||||
max_seq_len,
|
||||
scale,
|
||||
causal,
|
||||
q_rope,
|
||||
k_rope,
|
||||
v_rope,
|
||||
window_left,
|
||||
window_right,
|
||||
softcap,
|
||||
enable_cuda_graph,
|
||||
use_sqrt_alibi,
|
||||
out);
|
||||
break;
|
||||
default:
|
||||
PD_THROW("Unsupported data type for mixed paged attn");
|
||||
}
|
||||
return {out};
|
||||
}
|
||||
|
||||
std::vector<std::vector<int64_t>> MixedFusedPagedAttnInferShape(const std::vector<int64_t>& qkv_shape,
|
||||
int num_heads,
|
||||
int head_dim) {
|
||||
return {{qkv_shape[0], num_heads * head_dim}};
|
||||
}
|
||||
|
||||
std::vector<paddle::DataType> MixedFusedPagedAttnInferDtype(const paddle::DataType& qkv_dtype) {
|
||||
return {qkv_dtype};
|
||||
}
|
||||
|
||||
PD_BUILD_STATIC_OP(mixed_fused_paged_attn)
|
||||
.Inputs({"qkv", "k_cache", "v_cache", "prefill_block_table", "decode_block_table",
|
||||
"cu_seqlens_qkv", "seq_lens", paddle::Optional("rope_sin"), paddle::Optional("rope_cos")})
|
||||
.Outputs({"out"})
|
||||
.Attrs({"prefill_num_tokens:int",
|
||||
"num_heads: int",
|
||||
"head_dim:int",
|
||||
"num_kv_heads:int",
|
||||
"block_size:int",
|
||||
"max_seq_len:int",
|
||||
"scale:float",
|
||||
"causal:bool",
|
||||
"q_rope:bool",
|
||||
"k_rope:bool",
|
||||
"v_rope:bool",
|
||||
"window_left:int",
|
||||
"window_right:int",
|
||||
"softcap:float",
|
||||
"enable_cuda_graph:bool",
|
||||
"use_sqrt_alibi:bool"})
|
||||
.SetKernelFn(PD_KERNEL(MixedFusedPagedAttn))
|
||||
.SetInferShapeFn(PD_INFER_SHAPE(MixedFusedPagedAttnInferShape))
|
||||
.SetInferDtypeFn(PD_INFER_DTYPE(MixedFusedPagedAttnInferDtype));
|
@@ -53,6 +53,7 @@ void MoeDispatchKernel(const paddle::Tensor& input,
|
||||
const paddle::optional<paddle::Tensor>& gating_correction_bias,
|
||||
const int moe_topk,
|
||||
const bool group_moe,
|
||||
const std::string &moe_quant_type,
|
||||
const bool topk_only_mode,
|
||||
const int num_rows,
|
||||
const int hidden_size,
|
||||
@@ -183,6 +184,7 @@ std::vector<paddle::Tensor> MoeExpertDispatch(
|
||||
const paddle::optional<paddle::Tensor>& w4a8_in_scale,
|
||||
const int moe_topk,
|
||||
const bool group_moe,
|
||||
const std::string &moe_quant_type,
|
||||
const bool topk_only_mode) {
|
||||
const auto input_type = input.dtype();
|
||||
auto place = input.place();
|
||||
@@ -220,6 +222,7 @@ std::vector<paddle::Tensor> MoeExpertDispatch(
|
||||
gating_correction_bias,
|
||||
moe_topk,
|
||||
group_moe,
|
||||
moe_quant_type,
|
||||
topk_only_mode,
|
||||
num_rows,
|
||||
hidden_size,
|
||||
@@ -236,6 +239,7 @@ std::vector<paddle::Tensor> MoeExpertDispatch(
|
||||
gating_correction_bias,
|
||||
moe_topk,
|
||||
group_moe,
|
||||
moe_quant_type,
|
||||
topk_only_mode,
|
||||
num_rows,
|
||||
hidden_size,
|
||||
@@ -305,7 +309,7 @@ PD_BUILD_STATIC_OP(moe_expert_dispatch)
|
||||
"top_k_weight",
|
||||
"top_k_indices",
|
||||
"expert_idx_per_token"})
|
||||
.Attrs({"moe_topk:int", "group_moe:bool", "topk_only_mode:bool"})
|
||||
.Attrs({"moe_topk:int", "group_moe:bool", "moe_quant_type:std::string", "topk_only_mode:bool"})
|
||||
.SetKernelFn(PD_KERNEL(MoeExpertDispatch))
|
||||
.SetInferShapeFn(PD_INFER_SHAPE(MoeExpertDispatchInferShape))
|
||||
.SetInferDtypeFn(PD_INFER_DTYPE(MoeExpertDispatchInferDtype));
|
||||
|
@@ -27,6 +27,8 @@ void PagedAttnKernel(const paddle::Tensor& q,
|
||||
const paddle::optional<paddle::Tensor> &v,
|
||||
const paddle::optional<paddle::Tensor> &rope_sin,
|
||||
const paddle::optional<paddle::Tensor> &rope_cos,
|
||||
int num_heads,
|
||||
int head_dim,
|
||||
int num_kv_heads,
|
||||
float scale,
|
||||
int block_size,
|
||||
@@ -86,32 +88,36 @@ void PagedAttnKernel(const paddle::Tensor& q,
|
||||
common::errors::InvalidArgument(
|
||||
"paged_attention expects seq_lens is contiguous"));
|
||||
// check dim and shape
|
||||
// k_cache: [num_blocks, kv_num_heads, block_size, head_size]
|
||||
// v_cache: [num_blocks, kv_num_heads, block_size, head_size]
|
||||
// k_cache: [num_blocks, kv_num_heads, block_size, head_dim]
|
||||
// v_cache: [num_blocks, kv_num_heads, block_size, head_dim]
|
||||
// block_table: [num_seqs, max_num_blocks_per_seq]
|
||||
// seq_lens: [num_seqs]
|
||||
// q and out:
|
||||
// merged_qkv = false: [num_seqs, num_heads, head_size]
|
||||
// merged_qkv = true: [num_seqs, num_heads+2*num_kv_heads, head_size]
|
||||
// if merged_qkv = false:
|
||||
// q:[num_seqs, hidden_size]
|
||||
// out:[num_seqs, hidden_size]
|
||||
// if merged_qkv = true:
|
||||
// q: [num_seqs, (num_heads+2*num_kv_heads)*head_dim]
|
||||
// out: [num_seqs, hidden_size]
|
||||
|
||||
const auto& q_dims = q.dims();
|
||||
PADDLE_ENFORCE_EQ(q_dims.size(),
|
||||
3,
|
||||
2,
|
||||
common::errors::InvalidArgument(
|
||||
"paged_attn receive query dims is "
|
||||
"[num_seqs, num_heads, head_size]"));
|
||||
"[num_seqs, (num_heads+2*num_kv_heads)*head_dim]"));
|
||||
PADDLE_ENFORCE_EQ(out.dims().size(),
|
||||
3,
|
||||
2,
|
||||
common::errors::InvalidArgument(
|
||||
"paged_attn receive out dims is "
|
||||
"[num_seqs, num_heads, head_size]"));
|
||||
"[num_seqs, hidden_size]"));
|
||||
|
||||
const auto& kv_cache_dims = k_cache.dims();
|
||||
PADDLE_ENFORCE_EQ(kv_cache_dims.size(),
|
||||
4,
|
||||
common::errors::InvalidArgument(
|
||||
"paged_attn receive kv cache dims is "
|
||||
"[num_blocks, kv_num_heads, block_size, head_size]"));
|
||||
"[num_blocks, kv_num_heads, block_size, head_dim]"));
|
||||
|
||||
const auto& block_table_dims = block_table.dims();
|
||||
PADDLE_ENFORCE_EQ(block_table_dims.size(),
|
||||
@@ -127,8 +133,6 @@ void PagedAttnKernel(const paddle::Tensor& q,
|
||||
"paged_attn receive seq_lens dims is [num_seqs]"));
|
||||
|
||||
int num_seqs = q_dims[0];
|
||||
int num_heads = merged_qkv ? q_dims[1] - 2 * num_kv_heads : q_dims[1];
|
||||
int head_size = q_dims[2];
|
||||
int max_num_blocks_per_seq = block_table_dims[1];
|
||||
int q_stride = q.strides()[0];
|
||||
int num_blocks = kv_cache_dims[0];
|
||||
@@ -142,9 +146,9 @@ void PagedAttnKernel(const paddle::Tensor& q,
|
||||
common::errors::InvalidArgument(
|
||||
"kv_cache_dims[2] must be equal to block_size"));
|
||||
PADDLE_ENFORCE_EQ(kv_cache_dims[3],
|
||||
head_size,
|
||||
head_dim,
|
||||
common::errors::InvalidArgument(
|
||||
"kv_cache_dims[3] must be equal to head_size"));
|
||||
"kv_cache_dims[3] must be equal to head_dim"));
|
||||
PADDLE_ENFORCE_EQ(block_table_dims[0],
|
||||
num_seqs,
|
||||
common::errors::InvalidArgument(
|
||||
@@ -162,14 +166,13 @@ void PagedAttnKernel(const paddle::Tensor& q,
|
||||
const float *rope_sin_ptr = merged_qkv ? rope_sin.get().data<float>() : nullptr;
|
||||
const float *rope_cos_ptr = merged_qkv ? rope_cos.get().data<float>() : nullptr;
|
||||
|
||||
auto dev_ctx = static_cast<const phi::CustomContext*>(paddle::experimental::DeviceContextPool::Instance().Get(q.place()));
|
||||
cuinferHandle_t cuinfer_handle = iluvatar::getContextInstance()->getIxInferHandle();
|
||||
|
||||
size_t workspace_size = 0;
|
||||
CUINFER_CHECK(cuInferPageAttentionGetWorkspaceV7(num_seqs,
|
||||
num_heads,
|
||||
num_kv_heads,
|
||||
head_size,
|
||||
head_dim,
|
||||
block_size,
|
||||
max_context_len,
|
||||
&workspace_size));
|
||||
@@ -189,7 +192,7 @@ void PagedAttnKernel(const paddle::Tensor& q,
|
||||
num_seqs,
|
||||
num_heads,
|
||||
num_kv_heads,
|
||||
head_size,
|
||||
head_dim,
|
||||
q_stride,
|
||||
kv_block_stride,
|
||||
kv_head_stride,
|
||||
@@ -215,6 +218,8 @@ std::vector<paddle::Tensor> PagedAttn(const paddle::Tensor& q,
|
||||
const paddle::optional<paddle::Tensor> &v,
|
||||
const paddle::optional<paddle::Tensor> &rope_sin,
|
||||
const paddle::optional<paddle::Tensor> &rope_cos,
|
||||
int num_heads,
|
||||
int head_dim,
|
||||
int num_kv_heads,
|
||||
float scale,
|
||||
int block_size,
|
||||
@@ -228,11 +233,7 @@ std::vector<paddle::Tensor> PagedAttn(const paddle::Tensor& q,
|
||||
bool merged_qkv) {
|
||||
|
||||
const auto dtype = q.dtype();
|
||||
auto out_shape = q.shape();
|
||||
if (merged_qkv) {
|
||||
out_shape[1] -= 2 * num_kv_heads;
|
||||
}
|
||||
auto out = paddle::empty(out_shape, dtype, q.place());
|
||||
auto out = paddle::empty({q.shape()[0], num_heads * head_dim}, dtype, q.place());
|
||||
|
||||
switch (dtype) {
|
||||
case paddle::DataType::BFLOAT16:
|
||||
@@ -246,6 +247,8 @@ std::vector<paddle::Tensor> PagedAttn(const paddle::Tensor& q,
|
||||
v,
|
||||
rope_sin,
|
||||
rope_cos,
|
||||
num_heads,
|
||||
head_dim,
|
||||
num_kv_heads,
|
||||
scale,
|
||||
block_size,
|
||||
@@ -270,6 +273,8 @@ std::vector<paddle::Tensor> PagedAttn(const paddle::Tensor& q,
|
||||
v,
|
||||
rope_sin,
|
||||
rope_cos,
|
||||
num_heads,
|
||||
head_dim,
|
||||
num_kv_heads,
|
||||
scale,
|
||||
block_size,
|
||||
@@ -299,6 +304,8 @@ std::vector<std::vector<int64_t>> PagedAttnInferShape(const std::vector<int64_t>
|
||||
const std::vector<int64_t>& v_shape,
|
||||
const std::vector<int64_t>& rope_sin_shape,
|
||||
const std::vector<int64_t>& rope_cos_shape,
|
||||
int num_heads,
|
||||
int head_dim,
|
||||
int num_kv_heads,
|
||||
float scale,
|
||||
int block_size,
|
||||
@@ -311,36 +318,13 @@ std::vector<std::vector<int64_t>> PagedAttnInferShape(const std::vector<int64_t>
|
||||
bool use_sqrt_alibi,
|
||||
bool merged_qkv) {
|
||||
if (merged_qkv) {
|
||||
int64_t num_tokens = q_shape[0];
|
||||
int64_t num_heads = q_shape[1] - 2 * num_kv_heads;
|
||||
int64_t head_dim = q_shape[2];
|
||||
return {{num_tokens, num_heads, head_dim}};
|
||||
return {{q_shape[0], num_heads * head_dim}};
|
||||
} else {
|
||||
return {q_shape};
|
||||
}
|
||||
}
|
||||
|
||||
std::vector<paddle::DataType> PagedAttnInferDtype(const paddle::DataType& q_dtype,
|
||||
const paddle::DataType& k_cache_dtype,
|
||||
const paddle::DataType& v_cache_dtype,
|
||||
const paddle::DataType& block_table_dtype,
|
||||
const paddle::DataType& seq_lens_dtype,
|
||||
const paddle::DataType& alibi_slopes_dtype,
|
||||
const paddle::DataType& k_dtype,
|
||||
const paddle::DataType& v_dtype,
|
||||
const paddle::DataType& rope_sin_dtype,
|
||||
const paddle::DataType& rope_cos_dtype,
|
||||
int num_kv_heads,
|
||||
float scale,
|
||||
int block_size,
|
||||
int max_context_len,
|
||||
bool causal,
|
||||
int window_left,
|
||||
int window_right,
|
||||
float softcap,
|
||||
bool enable_cuda_graph,
|
||||
bool use_sqrt_alibi,
|
||||
bool merged_qkv) {
|
||||
std::vector<paddle::DataType> PagedAttnInferDtype(const paddle::DataType& q_dtype) {
|
||||
return {q_dtype};
|
||||
}
|
||||
|
||||
@@ -351,7 +335,9 @@ PD_BUILD_STATIC_OP(paged_attn)
|
||||
paddle::Optional("v"), paddle::Optional("rope_sin"),
|
||||
paddle::Optional("rope_cos")})
|
||||
.Outputs({"out"})
|
||||
.Attrs({"num_kv_heads:int",
|
||||
.Attrs({"num_heads:int",
|
||||
"head_dim:int",
|
||||
"num_kv_heads:int",
|
||||
"scale:float",
|
||||
"block_size:int",
|
||||
"max_context_len:int",
|
||||
|
378
custom_ops/iluvatar_ops/prefill_fused_attn.cu
Normal file
378
custom_ops/iluvatar_ops/prefill_fused_attn.cu
Normal file
@@ -0,0 +1,378 @@
|
||||
// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#include "helper.h"
|
||||
#include "iluvatar_context.h"
|
||||
|
||||
template <paddle::DataType T>
|
||||
void PrefillFusedPagedAttnKernel(const paddle::Tensor& qkv,
|
||||
paddle::Tensor& k_cache,
|
||||
paddle::Tensor& v_cache,
|
||||
const paddle::Tensor& block_table,
|
||||
const paddle::Tensor& cu_seqlens_qkv,
|
||||
const paddle::optional<paddle::Tensor> &rope_sin,
|
||||
const paddle::optional<paddle::Tensor> &rope_cos,
|
||||
int num_heads,
|
||||
int head_dim,
|
||||
int num_kv_heads,
|
||||
int block_size,
|
||||
int max_seq_len,
|
||||
float scale,
|
||||
bool causal,
|
||||
bool q_rope,
|
||||
bool k_rope,
|
||||
bool v_rope,
|
||||
paddle::Tensor& out) {
|
||||
|
||||
// check dtype and contiguous
|
||||
const auto& dtype = qkv.dtype();
|
||||
cuinferDataType_t data_type;
|
||||
if (dtype == paddle::DataType::FLOAT16) {
|
||||
data_type = CUINFER_DATA_HALF;
|
||||
|
||||
} else if (dtype == paddle::DataType::BFLOAT16) {
|
||||
data_type = CUINFER_DATA_BFLOAT16;
|
||||
} else {
|
||||
common::errors::InvalidArgument("paged_attention support half and bfloat16 now");
|
||||
}
|
||||
|
||||
PADDLE_ENFORCE_EQ(k_cache.dtype(),
|
||||
dtype,
|
||||
common::errors::InvalidArgument(
|
||||
"k_cache dtype must be the same as query dtype"));
|
||||
PADDLE_ENFORCE_EQ(k_cache.is_contiguous(),
|
||||
true,
|
||||
common::errors::InvalidArgument(
|
||||
"paged_attention expects k_cache is contiguous"));
|
||||
PADDLE_ENFORCE_EQ(block_table.dtype(),
|
||||
paddle::DataType::INT32,
|
||||
common::errors::InvalidArgument(
|
||||
"block_table dtype must be int32"));
|
||||
PADDLE_ENFORCE_EQ(block_table.is_contiguous(),
|
||||
true,
|
||||
common::errors::InvalidArgument(
|
||||
"paged_attention expects block_table is contiguous"));
|
||||
PADDLE_ENFORCE_EQ(cu_seqlens_qkv.dtype(),
|
||||
paddle::DataType::INT32,
|
||||
common::errors::InvalidArgument(
|
||||
"cu_seqlens_qkv dtype must be int32"));
|
||||
PADDLE_ENFORCE_EQ(cu_seqlens_qkv.is_contiguous(),
|
||||
true,
|
||||
common::errors::InvalidArgument(
|
||||
"paged_attention expects cu_seqlens_qkv is contiguous"));
|
||||
// check dim and shape
|
||||
// k_cache: [num_blocks, kv_num_heads, block_size, head_dim]
|
||||
// v_cache: [num_blocks, kv_num_heads, block_size, head_dim]
|
||||
// block_table: [batch_size, max_num_blocks_per_seq]
|
||||
// seq_lens: [batch_size]
|
||||
// qkv: [num_tokens, (num_heads+2*num_kv_heads)*head_dim]
|
||||
// out: [num_tokens, hidden_size]
|
||||
|
||||
const auto& qkv_dims = qkv.dims();
|
||||
PADDLE_ENFORCE_EQ(qkv_dims.size(),
|
||||
2,
|
||||
common::errors::InvalidArgument(
|
||||
"paged_attn receive query dims is "
|
||||
"[num_tokens, (num_heads+2*num_kv_heads)*head_dim]"));
|
||||
PADDLE_ENFORCE_EQ(out.dims().size(),
|
||||
2,
|
||||
common::errors::InvalidArgument(
|
||||
"paged_attn receive out dims is "
|
||||
"[num_tokens, hidden_size]"));
|
||||
|
||||
const auto& kv_cache_dims = k_cache.dims();
|
||||
PADDLE_ENFORCE_EQ(kv_cache_dims.size(),
|
||||
4,
|
||||
common::errors::InvalidArgument(
|
||||
"paged_attn receive kv cache dims is "
|
||||
"[num_blocks, kv_num_heads, block_size, head_dim]"));
|
||||
|
||||
const auto& block_table_dims = block_table.dims();
|
||||
PADDLE_ENFORCE_EQ(block_table_dims.size(),
|
||||
2,
|
||||
common::errors::InvalidArgument(
|
||||
"paged_attn receive block_table dims is "
|
||||
"[batch_size, max_num_blocks_per_seq]"));
|
||||
|
||||
const auto& cu_seqlens_qkv_dims = cu_seqlens_qkv.dims();
|
||||
PADDLE_ENFORCE_EQ(cu_seqlens_qkv_dims.size(),
|
||||
1,
|
||||
common::errors::InvalidArgument(
|
||||
"paged_attn receive cu_seqlens_qkv dims is [batch_size]"));
|
||||
|
||||
int batch_size = block_table_dims[0];
|
||||
int num_tokens = qkv_dims[0];
|
||||
int num_total_heads = num_heads + 2 * num_kv_heads;
|
||||
int qkv_stride = qkv.strides()[0];
|
||||
int num_blocks = kv_cache_dims[0];
|
||||
|
||||
PADDLE_ENFORCE_EQ(kv_cache_dims[1],
|
||||
num_kv_heads,
|
||||
common::errors::InvalidArgument(
|
||||
"kv_cache_dims[1] must be equal to num_kv_head"));
|
||||
PADDLE_ENFORCE_EQ(kv_cache_dims[2],
|
||||
block_size,
|
||||
common::errors::InvalidArgument(
|
||||
"kv_cache_dims[2] must be equal to block_size"));
|
||||
PADDLE_ENFORCE_EQ(kv_cache_dims[3],
|
||||
head_dim,
|
||||
common::errors::InvalidArgument(
|
||||
"kv_cache_dims[3] must be equal to head_dim"));
|
||||
PADDLE_ENFORCE_EQ(cu_seqlens_qkv_dims[0],
|
||||
batch_size + 1,
|
||||
common::errors::InvalidArgument(
|
||||
"cu_seqlens_qkv_dims[0] must be equal to batch_size + 1"));
|
||||
|
||||
int block_table_stride = block_table.strides()[0];
|
||||
const float *rope_sin_ptr = rope_sin ? rope_sin.get().data<float>() : nullptr;
|
||||
const float *rope_cos_ptr = rope_cos ? rope_cos.get().data<float>() : nullptr;
|
||||
|
||||
cuinferHandle_t cuinfer_handle = iluvatar::getContextInstance()->getIxInferHandle();
|
||||
|
||||
size_t workspace_size = 0;
|
||||
CUINFER_CHECK(cuinferGetFmhaFwdMergedFuseRopeWorkspaceSize(num_tokens,
|
||||
num_heads,
|
||||
num_kv_heads,
|
||||
head_dim,
|
||||
q_rope,
|
||||
k_rope,
|
||||
v_rope,
|
||||
data_type,
|
||||
data_type,
|
||||
data_type,
|
||||
&workspace_size));
|
||||
auto* allocator = paddle::GetAllocator(qkv.place());
|
||||
phi::Allocator::AllocationPtr tmp_workspace = allocator->Allocate(workspace_size);
|
||||
void* workspace_ptr = tmp_workspace->ptr();
|
||||
|
||||
cuinferTensorDescriptor_t qkv_desc;
|
||||
CUINFER_CHECK(cuinferCreateTensorDescriptor(&qkv_desc));
|
||||
CUINFER_CHECK(cuinferSetTensorNdDescriptor(
|
||||
qkv_desc,
|
||||
data_type,
|
||||
3,
|
||||
std::vector<int>({num_tokens, num_total_heads, head_dim}).data(),
|
||||
std::vector<int>({num_total_heads * head_dim, head_dim, 1}).data()));
|
||||
|
||||
cuinferTensorDescriptor_t qkv_seqlens_desc;
|
||||
CUINFER_CHECK(cuinferCreateTensorDescriptor(&qkv_seqlens_desc));
|
||||
CUINFER_CHECK(cuinferSetTensorNdDescriptor(
|
||||
qkv_seqlens_desc,
|
||||
CUINFER_DATA_INT32,
|
||||
1,
|
||||
std::vector<int>({batch_size + 1}).data(),
|
||||
std::vector<int>({1}).data()));
|
||||
|
||||
cuinferTensorDescriptor_t block_table_desc;
|
||||
CUINFER_CHECK(cuinferCreateTensorDescriptor(&block_table_desc));
|
||||
CUINFER_CHECK(cuinferSetTensorNdDescriptor(
|
||||
block_table_desc,
|
||||
CUINFER_DATA_INT32,
|
||||
2,
|
||||
std::vector<int>({batch_size, block_table_stride}).data(),
|
||||
std::vector<int>({block_table_stride, 1}).data()));
|
||||
|
||||
cuinferTensorDescriptor_t o_desc;
|
||||
CUINFER_CHECK(cuinferCreateTensorDescriptor(&o_desc));
|
||||
CUINFER_CHECK(cuinferSetTensorNdDescriptor(
|
||||
o_desc,
|
||||
data_type,
|
||||
3,
|
||||
std::vector<int>({num_tokens, num_heads, head_dim}).data(),
|
||||
std::vector<int>({num_heads * head_dim, head_dim, 1}).data()));
|
||||
|
||||
cuinferTensorDescriptor_t k_cache_desc;
|
||||
CUINFER_CHECK(cuinferCreateTensorDescriptor(&k_cache_desc));
|
||||
CUINFER_CHECK(cuinferSetTensorNdDescriptor(
|
||||
k_cache_desc,
|
||||
data_type,
|
||||
4,
|
||||
std::vector<int>({num_blocks, num_kv_heads, block_size, head_dim}).data(),
|
||||
std::vector<int>({num_kv_heads * block_size * head_dim, block_size * head_dim, head_dim, 1}).data()));
|
||||
|
||||
cuinferTensorDescriptor_t v_cache_desc;
|
||||
CUINFER_CHECK(cuinferCreateTensorDescriptor(&v_cache_desc));
|
||||
CUINFER_CHECK(cuinferSetTensorNdDescriptor(
|
||||
v_cache_desc,
|
||||
data_type,
|
||||
4,
|
||||
std::vector<int>({num_blocks, num_kv_heads, block_size, head_dim}).data(),
|
||||
std::vector<int>({num_kv_heads * block_size * head_dim, block_size * head_dim, head_dim, 1}).data()));
|
||||
|
||||
cuinferTensorDescriptor_t cos_desc;
|
||||
CUINFER_CHECK(cuinferCreateTensorDescriptor(&cos_desc));
|
||||
CUINFER_CHECK(cuinferSetTensorNdDescriptor(
|
||||
cos_desc,
|
||||
CUINFER_DATA_FLOAT,
|
||||
2,
|
||||
std::vector<int>({max_seq_len, head_dim}).data(),
|
||||
std::vector<int>({head_dim, 1}).data()));
|
||||
|
||||
cuinferTensorDescriptor_t sin_desc;
|
||||
CUINFER_CHECK(cuinferCreateTensorDescriptor(&sin_desc));
|
||||
CUINFER_CHECK(cuinferSetTensorNdDescriptor(
|
||||
sin_desc,
|
||||
CUINFER_DATA_FLOAT,
|
||||
2,
|
||||
std::vector<int>({max_seq_len, head_dim}).data(),
|
||||
std::vector<int>({head_dim, 1}).data()));
|
||||
|
||||
CUINFER_CHECK(cuinferFmhaFwdMergedFuseRopeFunc(cuinfer_handle,
|
||||
qkv_desc,
|
||||
qkv.data(),
|
||||
qkv_seqlens_desc,
|
||||
cu_seqlens_qkv.data<int32_t>(),
|
||||
block_table_desc,
|
||||
block_table.data<int32_t>(),
|
||||
o_desc,
|
||||
out.data(),
|
||||
k_cache_desc,
|
||||
k_cache.data(),
|
||||
v_cache_desc,
|
||||
v_cache.data(),
|
||||
workspace_ptr,
|
||||
workspace_size,
|
||||
cos_desc,
|
||||
rope_cos_ptr,
|
||||
sin_desc,
|
||||
rope_sin_ptr,
|
||||
batch_size,
|
||||
num_heads,
|
||||
num_kv_heads,
|
||||
head_dim,
|
||||
causal,
|
||||
scale,
|
||||
q_rope,
|
||||
k_rope,
|
||||
v_rope));
|
||||
|
||||
CUINFER_CHECK(cuinferDestroyTensorDescriptor(qkv_desc));
|
||||
CUINFER_CHECK(cuinferDestroyTensorDescriptor(qkv_seqlens_desc));
|
||||
CUINFER_CHECK(cuinferDestroyTensorDescriptor(block_table_desc));
|
||||
CUINFER_CHECK(cuinferDestroyTensorDescriptor(o_desc));
|
||||
CUINFER_CHECK(cuinferDestroyTensorDescriptor(k_cache_desc));
|
||||
CUINFER_CHECK(cuinferDestroyTensorDescriptor(v_cache_desc));
|
||||
CUINFER_CHECK(cuinferDestroyTensorDescriptor(cos_desc));
|
||||
CUINFER_CHECK(cuinferDestroyTensorDescriptor(sin_desc));
|
||||
}
|
||||
|
||||
std::vector<paddle::Tensor> PrefillFusedPagedAttn(const paddle::Tensor& qkv,
|
||||
paddle::Tensor& k_cache,
|
||||
paddle::Tensor& v_cache,
|
||||
const paddle::Tensor& block_table,
|
||||
const paddle::Tensor& cu_seqlens_qkv,
|
||||
const paddle::optional<paddle::Tensor> &rope_sin,
|
||||
const paddle::optional<paddle::Tensor> &rope_cos,
|
||||
int num_heads,
|
||||
int head_dim,
|
||||
int num_kv_heads,
|
||||
int block_size,
|
||||
int max_seq_len,
|
||||
float scale,
|
||||
bool causal,
|
||||
bool q_rope,
|
||||
bool k_rope,
|
||||
bool v_rope) {
|
||||
|
||||
const auto dtype = qkv.dtype();
|
||||
auto out = paddle::empty({qkv.shape()[0], num_heads * head_dim}, dtype, qkv.place());
|
||||
|
||||
switch (dtype) {
|
||||
case paddle::DataType::BFLOAT16:
|
||||
PrefillFusedPagedAttnKernel<paddle::DataType::BFLOAT16>(qkv,
|
||||
k_cache,
|
||||
v_cache,
|
||||
block_table,
|
||||
cu_seqlens_qkv,
|
||||
rope_sin,
|
||||
rope_cos,
|
||||
num_heads,
|
||||
head_dim,
|
||||
num_kv_heads,
|
||||
block_size,
|
||||
max_seq_len,
|
||||
scale,
|
||||
causal,
|
||||
q_rope,
|
||||
k_rope,
|
||||
v_rope,
|
||||
out);
|
||||
break;
|
||||
case paddle::DataType::FLOAT16:
|
||||
PrefillFusedPagedAttnKernel<paddle::DataType::FLOAT16>(qkv,
|
||||
k_cache,
|
||||
v_cache,
|
||||
block_table,
|
||||
cu_seqlens_qkv,
|
||||
rope_sin,
|
||||
rope_cos,
|
||||
num_heads,
|
||||
head_dim,
|
||||
num_kv_heads,
|
||||
block_size,
|
||||
max_seq_len,
|
||||
scale,
|
||||
causal,
|
||||
q_rope,
|
||||
k_rope,
|
||||
v_rope,
|
||||
out);
|
||||
break;
|
||||
default:
|
||||
PD_THROW("Unsupported data type for Paged attn");
|
||||
}
|
||||
return {out};
|
||||
}
|
||||
|
||||
std::vector<std::vector<int64_t>> PrefillFusedPagedAttnInferShape(const std::vector<int64_t>& qkv_shape,
|
||||
const std::vector<int64_t>& k_cache_shape,
|
||||
const std::vector<int64_t>& v_cache_shape,
|
||||
const std::vector<int64_t>& block_table_shape,
|
||||
const std::vector<int64_t>& cu_seqlens_qkv_shape,
|
||||
const std::vector<int64_t>& rope_sin_shape,
|
||||
const std::vector<int64_t>& rope_cos_shape,
|
||||
int num_heads,
|
||||
int head_dim,
|
||||
int num_kv_heads,
|
||||
int block_size,
|
||||
int max_seq_len,
|
||||
float scale,
|
||||
bool causal,
|
||||
bool q_rope,
|
||||
bool k_rope,
|
||||
bool v_rope) {
|
||||
return {{qkv_shape[0], num_heads * head_dim}};
|
||||
}
|
||||
|
||||
std::vector<paddle::DataType> PrefillFusedPagedAttnInferDtype(const paddle::DataType& qkv_dtype) {
|
||||
return {qkv_dtype};
|
||||
}
|
||||
|
||||
PD_BUILD_STATIC_OP(prefill_fused_paged_attn)
|
||||
.Inputs({"qkv", "k_cache", "v_cache", "block_table", "cu_seqlens_qkv",
|
||||
paddle::Optional("rope_sin"), paddle::Optional("rope_cos")})
|
||||
.Outputs({"out"})
|
||||
.Attrs({"num_heads:int",
|
||||
"head_dim:int",
|
||||
"num_kv_heads:int",
|
||||
"block_size:int",
|
||||
"max_seq_len:int",
|
||||
"scale:float",
|
||||
"causal:bool",
|
||||
"q_rope:bool",
|
||||
"k_rope:bool",
|
||||
"v_rope:bool"})
|
||||
.SetKernelFn(PD_KERNEL(PrefillFusedPagedAttn))
|
||||
.SetInferShapeFn(PD_INFER_SHAPE(PrefillFusedPagedAttnInferShape))
|
||||
.SetInferDtypeFn(PD_INFER_DTYPE(PrefillFusedPagedAttnInferDtype));
|
181
custom_ops/metax_ops/fused_moe.cu
Normal file
181
custom_ops/metax_ops/fused_moe.cu
Normal file
@@ -0,0 +1,181 @@
|
||||
// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "helper.h"
|
||||
#include "mc_fused_moe_helper.h"
|
||||
#include "fused_moe_op.h"
|
||||
|
||||
__global__ void compute_total_rows_before_expert_kernel(
|
||||
int* sorted_experts,
|
||||
const int64_t sorted_experts_len,
|
||||
const int64_t num_experts,
|
||||
int32_t* total_rows_before_expert) {
|
||||
const int expert = blockIdx.x * blockDim.x + threadIdx.x;
|
||||
if (expert >= num_experts) return;
|
||||
|
||||
total_rows_before_expert[expert] =
|
||||
find_total_elts_leq_target(sorted_experts, sorted_experts_len, expert);
|
||||
}
|
||||
|
||||
void compute_total_rows_before_expert(int* sorted_indices,
|
||||
const int64_t total_indices,
|
||||
const int64_t num_experts,
|
||||
int32_t* total_rows_before_expert,
|
||||
cudaStream_t stream) {
|
||||
const int threads = std::min(int64_t(1024), num_experts);
|
||||
const int blocks = (num_experts + threads - 1) / threads;
|
||||
|
||||
compute_total_rows_before_expert_kernel<<<blocks, threads, 0, stream>>>(
|
||||
sorted_indices, total_indices, num_experts, total_rows_before_expert);
|
||||
}
|
||||
|
||||
template <paddle::DataType T, typename ElementA, typename ElementB, typename ElementC>
|
||||
void FusedMoeKernel(const paddle::Tensor& input,
|
||||
const paddle::Tensor& gate_weight,
|
||||
const paddle::Tensor& ffn1_weight,
|
||||
const paddle::optional<paddle::Tensor>& ffn1_scale,
|
||||
const paddle::optional<paddle::Tensor>& ffn1_bias,
|
||||
const paddle::Tensor& ffn2_weight,
|
||||
const paddle::optional<paddle::Tensor>& ffn2_scale,
|
||||
const paddle::optional<paddle::Tensor>& ffn2_bias,
|
||||
const std::string& quant_method,
|
||||
const int moe_topk,
|
||||
const bool group_moe,
|
||||
const bool norm_topk_prob,
|
||||
paddle::Tensor* output) {
|
||||
typedef PDTraits<T> traits_;
|
||||
typedef typename traits_::DataType DataType_;
|
||||
typedef typename traits_::data_t data_t;
|
||||
|
||||
auto* output_data = output->data<data_t>();
|
||||
|
||||
auto moe_compute = McMoeHelper<data_t, ElementA, ElementB, ElementC>(quant_method);
|
||||
|
||||
moe_compute.computeFFN(
|
||||
&input,
|
||||
&gate_weight,
|
||||
&ffn1_weight,
|
||||
ffn1_scale ? ffn1_scale.get_ptr() : nullptr,
|
||||
ffn1_bias ? ffn1_bias.get_ptr() : nullptr,
|
||||
&ffn2_weight,
|
||||
ffn2_scale ? ffn2_scale.get_ptr() : nullptr,
|
||||
ffn2_bias ? ffn2_bias.get_ptr() : nullptr,
|
||||
nullptr,
|
||||
moe_topk,
|
||||
group_moe,
|
||||
norm_topk_prob,
|
||||
1.0, // ComputeFFN
|
||||
"ffn",
|
||||
output);
|
||||
}
|
||||
|
||||
|
||||
std::vector<paddle::Tensor> FusedExpertMoe(
|
||||
const paddle::Tensor& input,
|
||||
const paddle::Tensor& gate_weight,
|
||||
const paddle::Tensor& ffn1_weight,
|
||||
const paddle::Tensor& ffn2_weight,
|
||||
const paddle::optional<paddle::Tensor>& ffn1_bias,
|
||||
const paddle::optional<paddle::Tensor>& ffn1_scale,
|
||||
const paddle::optional<paddle::Tensor>& ffn2_bias,
|
||||
const paddle::optional<paddle::Tensor>& ffn2_scale,
|
||||
const std::string& quant_method,
|
||||
const int moe_topk,
|
||||
const bool norm_topk_prob,
|
||||
const bool group_moe) {
|
||||
const auto input_type = input.dtype();
|
||||
auto output = paddle::empty_like(input);
|
||||
|
||||
switch (input_type) {
|
||||
case paddle::DataType::BFLOAT16:
|
||||
FusedMoeKernel<paddle::DataType::BFLOAT16, maca_bfloat16, int8_t, maca_bfloat16>(input,
|
||||
gate_weight,
|
||||
ffn1_weight,
|
||||
ffn1_scale,
|
||||
ffn1_bias,
|
||||
ffn2_weight,
|
||||
ffn2_scale,
|
||||
ffn2_bias,
|
||||
quant_method,
|
||||
moe_topk,
|
||||
group_moe,
|
||||
norm_topk_prob,
|
||||
&output);
|
||||
break;
|
||||
// case paddle::DataType::FLOAT16:
|
||||
// FusedMoeKernel<paddle::DataType::FLOAT16>(input,
|
||||
// gate_weight,
|
||||
// ffn1_weight,
|
||||
// ffn1_scale,
|
||||
// ffn1_bias,
|
||||
// ffn2_weight,
|
||||
// ffn2_scale,
|
||||
// ffn2_bias,
|
||||
// quant_method,
|
||||
// moe_topk,
|
||||
// group_moe,
|
||||
// norm_topk_prob,
|
||||
// &output);
|
||||
// break;
|
||||
default:
|
||||
PD_THROW("Only support bf16 for FusedMoeKernel");
|
||||
}
|
||||
return {output};
|
||||
}
|
||||
|
||||
std::vector<std::vector<int64_t>> FusedExpertMoeInferShape(
|
||||
const std::vector<int64_t>& input_shape,
|
||||
const std::vector<int64_t>& gate_weight_shape,
|
||||
const std::vector<int64_t>& ffn1_weight_shape,
|
||||
const std::vector<int64_t>& ffn2_weight_shape,
|
||||
const paddle::optional<std::vector<int64_t>>& ffn1_bias_shape,
|
||||
const paddle::optional<std::vector<int64_t>>& ffn1_scale_shape,
|
||||
const paddle::optional<std::vector<int64_t>>& ffn2_bias_shape,
|
||||
const paddle::optional<std::vector<int64_t>>& ffn2_scale_shape) {
|
||||
return {input_shape};
|
||||
}
|
||||
|
||||
std::vector<paddle::DataType> FusedExpertMoeInferDtype(
|
||||
const paddle::DataType& input_dtype,
|
||||
const paddle::DataType& gate_weight_dtype,
|
||||
const paddle::DataType& ffn1_weight_dtype,
|
||||
const paddle::DataType& ffn2_weight_dtype,
|
||||
const paddle::optional<paddle::DataType>& ffn1_bias_dtype,
|
||||
const paddle::optional<paddle::DataType>& ffn1_scale_dtype,
|
||||
const paddle::optional<paddle::DataType>& ffn2_bias_dtype,
|
||||
const paddle::optional<paddle::DataType>& ffn2_scale_dtype) {
|
||||
return {input_dtype};
|
||||
}
|
||||
|
||||
|
||||
PD_BUILD_OP(fused_expert_moe)
|
||||
.Inputs({"input",
|
||||
"gate_weight",
|
||||
"ffn1_weight",
|
||||
"ffn2_weight",
|
||||
paddle::Optional("ffn1_bias"),
|
||||
paddle::Optional("ffn1_scale"),
|
||||
paddle::Optional("ffn2_bias"),
|
||||
paddle::Optional("ffn2_scale")})
|
||||
.Outputs({"output"})
|
||||
.Attrs({"quant_method:std::string",
|
||||
"moe_topk:int",
|
||||
"norm_topk_prob:bool",
|
||||
"group_moe:bool"})
|
||||
.SetKernelFn(PD_KERNEL(FusedExpertMoe))
|
||||
.SetInferShapeFn(PD_INFER_SHAPE(FusedExpertMoeInferShape))
|
||||
.SetInferDtypeFn(PD_INFER_DTYPE(FusedExpertMoeInferDtype));
|
53
custom_ops/metax_ops/fused_moe_helper.h
Normal file
53
custom_ops/metax_ops/fused_moe_helper.h
Normal file
@@ -0,0 +1,53 @@
|
||||
// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "cutlass_kernels/moe_gemm/fused_moe_gemm_kernels.h"
|
||||
#include "fused_moe_op.h"
|
||||
|
||||
using namespace phi;
|
||||
|
||||
template <typename T, int VecSize>
|
||||
__global__ void moe_token_type_ids_kernel(T *gating_output,
|
||||
const int *moe_token_type_ids_out,
|
||||
const int num_rows,
|
||||
const int num_experts,
|
||||
const int k) {
|
||||
const int moe_token_index = blockIdx.x * blockDim.x + threadIdx.x;
|
||||
|
||||
if (moe_token_index >= num_rows) {
|
||||
return;
|
||||
}
|
||||
|
||||
gating_output[moe_token_index * 2] =
|
||||
gating_output[moe_token_index * 2] +
|
||||
(moe_token_type_ids_out[moe_token_index]) * -1e10;
|
||||
gating_output[moe_token_index * 2 + 1] =
|
||||
gating_output[moe_token_index * 2 + 1] +
|
||||
(1 - moe_token_type_ids_out[moe_token_index]) * -1e10;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
void moe_token_type_ids_kernelLauncher(T *gating_output,
|
||||
const int *moe_token_type_ids_out,
|
||||
const int num_rows,
|
||||
const int num_experts,
|
||||
const int k,
|
||||
cudaStream_t stream) {
|
||||
const int blocks = num_rows * k / 512 + 1;
|
||||
const int threads = 512;
|
||||
moe_token_type_ids_kernel<T, 1><<<blocks, 512, 0, stream>>>(
|
||||
gating_output, moe_token_type_ids_out, num_rows, num_experts, k);
|
||||
}
|
123
custom_ops/metax_ops/fused_moe_imp_op.h
Normal file
123
custom_ops/metax_ops/fused_moe_imp_op.h
Normal file
@@ -0,0 +1,123 @@
|
||||
/*
|
||||
* SPDX-FileCopyrightText: Copyright (c) 1993-2023 NVIDIA CORPORATION &
|
||||
* AFFILIATES. All rights reserved. SPDX-License-Identifier: Apache-2.0
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#pragma once
|
||||
#include <string>
|
||||
#include <sstream>
|
||||
#include "cub/cub.cuh"
|
||||
|
||||
static const float HALF_FLT_MAX = 65504.F;
|
||||
static const float HALF_FLT_MIN = -65504.F;
|
||||
static inline size_t AlignTo16(const size_t& input) {
|
||||
static constexpr int ALIGNMENT = 16;
|
||||
return ALIGNMENT * ((input + ALIGNMENT - 1) / ALIGNMENT);
|
||||
}
|
||||
|
||||
class CubKeyValueSorter {
|
||||
public:
|
||||
CubKeyValueSorter() : num_experts_(0), num_bits_(sizeof(int) * 8) {}
|
||||
|
||||
explicit CubKeyValueSorter(const int num_experts)
|
||||
: num_experts_(num_experts),
|
||||
num_bits_(static_cast<int>(log2(num_experts)) + 1) {}
|
||||
|
||||
void update_num_experts(const int num_experts) {
|
||||
num_experts_ = num_experts;
|
||||
num_bits_ = static_cast<int>(log2(num_experts)) + 1;
|
||||
}
|
||||
|
||||
size_t getWorkspaceSize(const size_t num_key_value_pairs,
|
||||
bool descending = false) {
|
||||
num_key_value_pairs_ = num_key_value_pairs;
|
||||
size_t required_storage = 0;
|
||||
int* null_int = nullptr;
|
||||
if (descending) {
|
||||
cub::DeviceRadixSort::SortPairsDescending(NULL,
|
||||
required_storage,
|
||||
null_int,
|
||||
null_int,
|
||||
null_int,
|
||||
null_int,
|
||||
num_key_value_pairs,
|
||||
0,
|
||||
32);
|
||||
} else {
|
||||
cub::DeviceRadixSort::SortPairs(NULL,
|
||||
required_storage,
|
||||
null_int,
|
||||
null_int,
|
||||
null_int,
|
||||
null_int,
|
||||
num_key_value_pairs,
|
||||
0,
|
||||
num_bits_);
|
||||
}
|
||||
return required_storage;
|
||||
}
|
||||
|
||||
template <typename KeyT>
|
||||
void run(void* workspace,
|
||||
const size_t workspace_size,
|
||||
const KeyT* keys_in,
|
||||
KeyT* keys_out,
|
||||
const int* values_in,
|
||||
int* values_out,
|
||||
const size_t num_key_value_pairs,
|
||||
bool descending,
|
||||
cudaStream_t stream) {
|
||||
size_t expected_ws_size = getWorkspaceSize(num_key_value_pairs);
|
||||
size_t actual_ws_size = workspace_size;
|
||||
|
||||
if (expected_ws_size > workspace_size) {
|
||||
std::stringstream err_ss;
|
||||
err_ss << "[Error][CubKeyValueSorter::run]\n";
|
||||
err_ss << "Error. The allocated workspace is too small to run this "
|
||||
"problem.\n";
|
||||
err_ss << "Expected workspace size of at least " << expected_ws_size
|
||||
<< " but got problem size " << workspace_size << "\n";
|
||||
throw std::runtime_error(err_ss.str());
|
||||
}
|
||||
if (descending) {
|
||||
cub::DeviceRadixSort::SortPairsDescending(workspace,
|
||||
actual_ws_size,
|
||||
keys_in,
|
||||
keys_out,
|
||||
values_in,
|
||||
values_out,
|
||||
num_key_value_pairs,
|
||||
0,
|
||||
32,
|
||||
stream);
|
||||
} else {
|
||||
cub::DeviceRadixSort::SortPairs(workspace,
|
||||
actual_ws_size,
|
||||
keys_in,
|
||||
keys_out,
|
||||
values_in,
|
||||
values_out,
|
||||
num_key_value_pairs,
|
||||
0,
|
||||
num_bits_,
|
||||
stream);
|
||||
}
|
||||
}
|
||||
|
||||
private:
|
||||
size_t num_key_value_pairs_;
|
||||
int num_experts_;
|
||||
int num_bits_;
|
||||
};
|
990
custom_ops/metax_ops/fused_moe_op.h
Normal file
990
custom_ops/metax_ops/fused_moe_op.h
Normal file
@@ -0,0 +1,990 @@
|
||||
/*
|
||||
* SPDX-FileCopyrightText: Copyright (c) 1993-2023 NVIDIA CORPORATION &
|
||||
* AFFILIATES. All rights reserved. SPDX-License-Identifier: Apache-2.0
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <cuda.h>
|
||||
#include <cuda_fp16.h>
|
||||
#include "fused_moe_imp_op.h"
|
||||
#include "fused_moe_helper.h"
|
||||
#include "mctlass/numeric_conversion.h" // BUILD_MARK
|
||||
// Ignore mctlass warnings about type punning
|
||||
#pragma GCC diagnostic push
|
||||
#pragma GCC diagnostic ignored "-Wstrict-aliasing"
|
||||
#pragma GCC diagnostic ignored "-Wunused-function"
|
||||
|
||||
// #include "paddle/phi/backends/gpu/gpu_info.h"
|
||||
#pragma GCC diagnostic pop
|
||||
|
||||
#include "helper.h"
|
||||
|
||||
#define WARP_SIZE 32
|
||||
|
||||
struct GpuLaunchConfig {
|
||||
dim3 block_per_grid;
|
||||
dim3 thread_per_block;
|
||||
};
|
||||
|
||||
inline GpuLaunchConfig Get1DBlocksAnd2DGridsMoe(const int64_t cols) {
|
||||
int blocks_x = cols;
|
||||
int blocks_y = 1;
|
||||
int blocks_z = 1;
|
||||
if (blocks_x > 1024) {
|
||||
blocks_y = 256;
|
||||
blocks_x = (blocks_x + blocks_y - 1) / blocks_y;
|
||||
}
|
||||
|
||||
GpuLaunchConfig config;
|
||||
config.block_per_grid.x = blocks_x;
|
||||
config.block_per_grid.y = blocks_y;
|
||||
config.block_per_grid.z = blocks_z;
|
||||
return config;
|
||||
}
|
||||
|
||||
// ====================== Softmax things ===============================
|
||||
// We have our own implementation of softmax here so we can support transposing
|
||||
// the output in the softmax kernel when we extend this module to support
|
||||
// expert-choice routing.
|
||||
template <typename T, int TPB>
|
||||
__launch_bounds__(TPB) __global__
|
||||
void group_moe_softmax(const T* input,
|
||||
T* output,
|
||||
T* softmax_max_prob,
|
||||
const int64_t num_cols,
|
||||
const int64_t softmax_num_rows) {
|
||||
using BlockReduce = cub::BlockReduce<float, TPB>;
|
||||
__shared__ typename BlockReduce::TempStorage tmpStorage;
|
||||
|
||||
__shared__ float normalizing_factor;
|
||||
__shared__ float float_max;
|
||||
__shared__ float max_out;
|
||||
|
||||
int globalIdx = blockIdx.x + blockIdx.y * gridDim.x;
|
||||
if (globalIdx >= softmax_num_rows) {
|
||||
return;
|
||||
}
|
||||
const int64_t thread_row_offset = globalIdx * num_cols;
|
||||
|
||||
cub::Sum sum;
|
||||
float threadData(-FLT_MAX);
|
||||
|
||||
|
||||
for (int ii = threadIdx.x; ii < num_cols; ii += TPB) {
|
||||
const int idx = thread_row_offset + ii;
|
||||
threadData = max(static_cast<float>(input[idx]), threadData);
|
||||
}
|
||||
|
||||
const float maxElem = BlockReduce(tmpStorage).Reduce(threadData, cub::Max());
|
||||
if (threadIdx.x == 0) {
|
||||
float_max = maxElem;
|
||||
}
|
||||
__syncthreads();
|
||||
|
||||
threadData = 0;
|
||||
|
||||
for (int ii = threadIdx.x; ii < num_cols; ii += TPB) {
|
||||
const int idx = thread_row_offset + ii;
|
||||
threadData += exp((static_cast<float>(input[idx]) - float_max));
|
||||
}
|
||||
|
||||
const auto Z = BlockReduce(tmpStorage).Reduce(threadData, sum);
|
||||
|
||||
if (threadIdx.x == 0) {
|
||||
normalizing_factor = 1.f / Z;
|
||||
}
|
||||
__syncthreads();
|
||||
|
||||
threadData = 0;
|
||||
|
||||
for (int ii = threadIdx.x; ii < num_cols; ii += TPB) {
|
||||
const int idx = thread_row_offset + ii;
|
||||
const float val =
|
||||
exp((static_cast<float>(input[idx]) - float_max)) * normalizing_factor;
|
||||
output[idx] = T(val);
|
||||
threadData = max(static_cast<float>(T(val)), threadData);
|
||||
}
|
||||
|
||||
const float maxOut = BlockReduce(tmpStorage).Reduce(threadData, cub::Max());
|
||||
if (threadIdx.x == 0) {
|
||||
// group max probs
|
||||
max_out = 1.f / maxOut;
|
||||
softmax_max_prob[globalIdx] = T(max_out);
|
||||
}
|
||||
__syncthreads();
|
||||
|
||||
for (int ii = threadIdx.x; ii < num_cols; ii += TPB) {
|
||||
const int idx = thread_row_offset + ii;
|
||||
// group softmax normalization
|
||||
output[idx] = output[idx] * static_cast<T>(max_out);
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T, int TPB>
|
||||
__launch_bounds__(TPB) __global__ void moe_top_k(const T* inputs_after_softmax,
|
||||
T* output,
|
||||
int* indices,
|
||||
int* source_rows,
|
||||
T* softmax_max_prob,
|
||||
const int64_t num_experts,
|
||||
const int64_t k,
|
||||
const int64_t num_rows) {
|
||||
using cub_kvp = cub::KeyValuePair<int, T>;
|
||||
using BlockReduce = cub::BlockReduce<cub_kvp, TPB>;
|
||||
__shared__ typename BlockReduce::TempStorage tmpStorage;
|
||||
|
||||
cub_kvp thread_kvp;
|
||||
cub::ArgMax arg_max;
|
||||
|
||||
const int block_row = blockIdx.x + blockIdx.y * gridDim.x;
|
||||
if (block_row >= num_rows) {
|
||||
return;
|
||||
}
|
||||
|
||||
const bool should_process_row = true;
|
||||
const int thread_read_offset = block_row * num_experts;
|
||||
|
||||
for (int k_idx = 0; k_idx < k; ++k_idx) {
|
||||
thread_kvp.key = 0;
|
||||
thread_kvp.value = T(-1.f); // This is OK because inputs are probabilities
|
||||
|
||||
cub_kvp inp_kvp;
|
||||
for (int expert = threadIdx.x; expert < num_experts; expert += TPB) {
|
||||
const int idx = thread_read_offset + expert;
|
||||
inp_kvp.key = expert;
|
||||
inp_kvp.value = inputs_after_softmax[idx];
|
||||
|
||||
for (int prior_k = 0; prior_k < k_idx; ++prior_k) {
|
||||
const int prior_winning_expert = indices[k * block_row + prior_k];
|
||||
|
||||
if (prior_winning_expert == expert) {
|
||||
inp_kvp = thread_kvp;
|
||||
}
|
||||
}
|
||||
|
||||
thread_kvp = arg_max(inp_kvp, thread_kvp);
|
||||
}
|
||||
|
||||
const cub_kvp result_kvp =
|
||||
BlockReduce(tmpStorage).Reduce(thread_kvp, arg_max);
|
||||
if (threadIdx.x == 0) {
|
||||
const int idx = k * block_row + k_idx;
|
||||
// restore normalized probes
|
||||
output[idx] = result_kvp.value / T(softmax_max_prob[idx]);
|
||||
indices[idx] = should_process_row ? result_kvp.key : num_experts;
|
||||
source_rows[idx] = k_idx * num_rows + block_row;
|
||||
}
|
||||
__syncthreads();
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T, int TPB>
|
||||
__launch_bounds__(TPB) __global__ void moe_softmax(const T* input,
|
||||
T* output,
|
||||
const int64_t num_cols,
|
||||
const int64_t num_rows) {
|
||||
using BlockReduce = cub::BlockReduce<float, TPB>;
|
||||
__shared__ typename BlockReduce::TempStorage tmpStorage;
|
||||
|
||||
__shared__ float normalizing_factor;
|
||||
__shared__ float float_max;
|
||||
|
||||
int globalIdx = blockIdx.x + blockIdx.y * gridDim.x;
|
||||
if (globalIdx >= num_rows) {
|
||||
return;
|
||||
}
|
||||
const int64_t thread_row_offset = globalIdx * num_cols;
|
||||
|
||||
cub::Sum sum;
|
||||
float threadData(-FLT_MAX);
|
||||
|
||||
for (int ii = threadIdx.x; ii < num_cols; ii += TPB) {
|
||||
const int idx = thread_row_offset + ii;
|
||||
threadData = max(static_cast<float>(input[idx]), threadData);
|
||||
}
|
||||
|
||||
const float maxElem = BlockReduce(tmpStorage).Reduce(threadData, cub::Max());
|
||||
if (threadIdx.x == 0) {
|
||||
float_max = maxElem;
|
||||
}
|
||||
__syncthreads();
|
||||
|
||||
threadData = 0;
|
||||
|
||||
for (int ii = threadIdx.x; ii < num_cols; ii += TPB) {
|
||||
const int idx = thread_row_offset + ii;
|
||||
threadData += exp((static_cast<float>(input[idx]) - float_max));
|
||||
}
|
||||
|
||||
const auto Z = BlockReduce(tmpStorage).Reduce(threadData, sum);
|
||||
|
||||
if (threadIdx.x == 0) {
|
||||
normalizing_factor = 1.f / Z;
|
||||
}
|
||||
__syncthreads();
|
||||
|
||||
for (int ii = threadIdx.x; ii < num_cols; ii += TPB) {
|
||||
const int idx = thread_row_offset + ii;
|
||||
const float val =
|
||||
exp((static_cast<float>(input[idx]) - float_max)) * normalizing_factor;
|
||||
output[idx] = T(val);
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T, int TPB>
|
||||
__launch_bounds__(TPB) __global__ void moe_top_k(const T* inputs_after_softmax,
|
||||
T* output,
|
||||
int* indices,
|
||||
int* source_rows,
|
||||
const int64_t num_experts,
|
||||
const int64_t k,
|
||||
const int64_t num_rows) {
|
||||
using cub_kvp = cub::KeyValuePair<int, T>;
|
||||
using BlockReduce = cub::BlockReduce<cub_kvp, TPB>;
|
||||
__shared__ typename BlockReduce::TempStorage tmpStorage;
|
||||
|
||||
cub_kvp thread_kvp;
|
||||
cub::ArgMax arg_max;
|
||||
|
||||
const int block_row = blockIdx.x + blockIdx.y * gridDim.x;
|
||||
if (block_row >= num_rows) {
|
||||
return;
|
||||
}
|
||||
|
||||
const bool should_process_row = true;
|
||||
const int thread_read_offset = block_row * num_experts;
|
||||
|
||||
for (int k_idx = 0; k_idx < k; ++k_idx) {
|
||||
thread_kvp.key = 0;
|
||||
thread_kvp.value = T(-1.f); // This is OK because inputs are probabilities
|
||||
|
||||
cub_kvp inp_kvp;
|
||||
for (int expert = threadIdx.x; expert < num_experts; expert += TPB) {
|
||||
const int idx = thread_read_offset + expert;
|
||||
inp_kvp.key = expert;
|
||||
inp_kvp.value = inputs_after_softmax[idx];
|
||||
|
||||
for (int prior_k = 0; prior_k < k_idx; ++prior_k) {
|
||||
const int prior_winning_expert = indices[k * block_row + prior_k];
|
||||
|
||||
if (prior_winning_expert == expert) {
|
||||
inp_kvp = thread_kvp;
|
||||
}
|
||||
}
|
||||
|
||||
thread_kvp = arg_max(inp_kvp, thread_kvp);
|
||||
}
|
||||
|
||||
const cub_kvp result_kvp =
|
||||
BlockReduce(tmpStorage).Reduce(thread_kvp, arg_max);
|
||||
if (threadIdx.x == 0) {
|
||||
const int idx = k * block_row + k_idx;
|
||||
output[idx] = result_kvp.value;
|
||||
indices[idx] = should_process_row ? result_kvp.key : num_experts;
|
||||
source_rows[idx] = k_idx * num_rows + block_row;
|
||||
}
|
||||
__syncthreads();
|
||||
}
|
||||
}
|
||||
|
||||
// ====================== TopK softmax things ===============================
|
||||
|
||||
/*
|
||||
A Top-K gating softmax written to exploit when the number of experts in the
|
||||
MoE layers are a small power of 2. This allows us to cleanly share the rows
|
||||
among the threads in a single warp and eliminate communication between warps
|
||||
(so no need to use shared mem).
|
||||
|
||||
It fuses the softmax, max and argmax into a single kernel.
|
||||
|
||||
Limitations:
|
||||
1) This implementation is intended for when the number of experts is a small
|
||||
power of 2. 2) This implementation assumes k is small, but will work for any
|
||||
k.
|
||||
*/
|
||||
|
||||
template <typename T,
|
||||
int VPT,
|
||||
int NUM_EXPERTS,
|
||||
int WARPS_PER_CTA,
|
||||
int BYTES_PER_LDG>
|
||||
__launch_bounds__(WARPS_PER_CTA* WARP_SIZE) __global__
|
||||
void topk_gating_softmax(const T* input,
|
||||
T* output,
|
||||
const int64_t num_rows,
|
||||
int* indices,
|
||||
int* source_rows,
|
||||
const int64_t k) {
|
||||
// We begin by enforcing compile time assertions and setting up compile time
|
||||
// constants.
|
||||
static_assert(VPT == (VPT & -VPT), "VPT must be power of 2");
|
||||
static_assert(NUM_EXPERTS == (NUM_EXPERTS & -NUM_EXPERTS),
|
||||
"NUM_EXPERTS must be power of 2");
|
||||
static_assert(BYTES_PER_LDG == (BYTES_PER_LDG & -BYTES_PER_LDG),
|
||||
"BYTES_PER_LDG must be power of 2");
|
||||
static_assert(BYTES_PER_LDG <= 16, "BYTES_PER_LDG must be leq 16");
|
||||
|
||||
// Number of bytes each thread pulls in per load
|
||||
static constexpr int ELTS_PER_LDG = BYTES_PER_LDG / sizeof(T);
|
||||
static constexpr int ELTS_PER_ROW = NUM_EXPERTS;
|
||||
static constexpr int THREADS_PER_ROW = ELTS_PER_ROW / VPT;
|
||||
static constexpr int LDG_PER_THREAD = VPT / ELTS_PER_LDG;
|
||||
|
||||
// Restrictions based on previous section.
|
||||
static_assert(
|
||||
VPT % ELTS_PER_LDG == 0,
|
||||
"The elements per thread must be a multiple of the elements per ldg");
|
||||
static_assert(WARP_SIZE % THREADS_PER_ROW == 0,
|
||||
"The threads per row must cleanly divide the threads per warp");
|
||||
static_assert(THREADS_PER_ROW == (THREADS_PER_ROW & -THREADS_PER_ROW),
|
||||
"THREADS_PER_ROW must be power of 2");
|
||||
static_assert(THREADS_PER_ROW <= WARP_SIZE,
|
||||
"THREADS_PER_ROW can be at most warp size");
|
||||
|
||||
// We have NUM_EXPERTS elements per row. We specialize for small #experts
|
||||
static constexpr int ELTS_PER_WARP = WARP_SIZE * VPT;
|
||||
static constexpr int ROWS_PER_WARP = ELTS_PER_WARP / ELTS_PER_ROW;
|
||||
static constexpr int ROWS_PER_CTA = WARPS_PER_CTA * ROWS_PER_WARP;
|
||||
|
||||
// Restrictions for previous section.
|
||||
static_assert(ELTS_PER_WARP % ELTS_PER_ROW == 0,
|
||||
"The elts per row must cleanly divide the total elt per warp");
|
||||
|
||||
// ===================== From this point, we finally start computing run-time
|
||||
// variables. ========================
|
||||
|
||||
// Compute CTA and warp rows. We pack multiple rows into a single warp, and a
|
||||
// block contains WARPS_PER_CTA warps. This, each block processes a chunk of
|
||||
// rows. We start by computing the start row for each block.
|
||||
const int cta_base_row = blockIdx.x * ROWS_PER_CTA;
|
||||
|
||||
// Now, using the base row per thread block, we compute the base row per warp.
|
||||
const int warp_base_row = cta_base_row + threadIdx.y * ROWS_PER_WARP;
|
||||
|
||||
// The threads in a warp are split into sub-groups that will work on a row.
|
||||
// We compute row offset for each thread sub-group
|
||||
const int thread_row_in_warp = threadIdx.x / THREADS_PER_ROW;
|
||||
const int thread_row = warp_base_row + thread_row_in_warp;
|
||||
|
||||
// Threads with indices out of bounds should early exit here.
|
||||
if (thread_row >= num_rows) return;
|
||||
const bool should_process_row = true;
|
||||
|
||||
// We finally start setting up the read pointers for each thread. First, each
|
||||
// thread jumps to the start of the row it will read.
|
||||
const T* thread_row_ptr = input + thread_row * ELTS_PER_ROW;
|
||||
|
||||
// Now, we compute the group each thread belong to in order to determine the
|
||||
// first column to start loads.
|
||||
const int thread_group_idx = threadIdx.x % THREADS_PER_ROW;
|
||||
const int first_elt_read_by_thread = thread_group_idx * ELTS_PER_LDG;
|
||||
const T* thread_read_ptr = thread_row_ptr + first_elt_read_by_thread;
|
||||
|
||||
// Determine the pointer type to use to read in the data depending on the
|
||||
// BYTES_PER_LDG template param. In theory, this can support all powers of 2
|
||||
// up to 16.
|
||||
using AccessType = mctlass::AlignedArray<T, ELTS_PER_LDG>;
|
||||
|
||||
// Finally, we pull in the data from global mem
|
||||
mctlass::Array<T, VPT> row_chunk_input;
|
||||
AccessType* row_chunk_vec_ptr =
|
||||
reinterpret_cast<AccessType*>(&row_chunk_input);
|
||||
const AccessType* vec_thread_read_ptr =
|
||||
reinterpret_cast<const AccessType*>(thread_read_ptr);
|
||||
#pragma unroll
|
||||
for (int ii = 0; ii < LDG_PER_THREAD; ++ii) {
|
||||
row_chunk_vec_ptr[ii] = vec_thread_read_ptr[ii * THREADS_PER_ROW];
|
||||
}
|
||||
|
||||
using ComputeType = float;
|
||||
using Converter = mctlass::NumericArrayConverter<ComputeType, T, VPT>;
|
||||
Converter compute_type_converter;
|
||||
mctlass::Array<ComputeType, VPT> row_chunk =
|
||||
compute_type_converter(row_chunk_input);
|
||||
|
||||
// First, we perform a max reduce within the thread. We can do the max in fp16
|
||||
// safely (I think) and just convert to float afterwards for the exp + sum
|
||||
// reduction.
|
||||
ComputeType thread_max = row_chunk[0];
|
||||
#pragma unroll
|
||||
for (int ii = 1; ii < VPT; ++ii) {
|
||||
thread_max = max(thread_max, row_chunk[ii]);
|
||||
}
|
||||
|
||||
// Now, we find the max within the thread group and distribute among the
|
||||
// threads. We use a butterfly reduce.
|
||||
#pragma unroll
|
||||
for (int mask = THREADS_PER_ROW / 2; mask > 0; mask /= 2) {
|
||||
thread_max =
|
||||
max(thread_max,
|
||||
__shfl_xor_sync(0xFFFFFFFF, thread_max, mask, THREADS_PER_ROW));
|
||||
}
|
||||
|
||||
// From this point, thread max in all the threads have the max within the row.
|
||||
// Now, we subtract the max from each element in the thread and take the exp.
|
||||
// We also compute the thread local sum.
|
||||
float row_sum = 0;
|
||||
#pragma unroll
|
||||
for (int ii = 0; ii < VPT; ++ii) {
|
||||
row_chunk[ii] = expf(row_chunk[ii] - thread_max);
|
||||
row_sum += row_chunk[ii];
|
||||
}
|
||||
|
||||
// Now, we perform the sum reduce within each thread group. Similar to the max
|
||||
// reduce, we use a bufferfly pattern.
|
||||
#pragma unroll
|
||||
for (int mask = THREADS_PER_ROW / 2; mask > 0; mask /= 2) {
|
||||
row_sum += __shfl_xor_sync(0xFFFFFFFF, row_sum, mask, THREADS_PER_ROW);
|
||||
}
|
||||
|
||||
// From this point, all threads have the max and the sum for their rows in the
|
||||
// thread_max and thread_sum variables respectively. Finally, we can scale the
|
||||
// rows for the softmax. Technically, for top-k gating we don't need to
|
||||
// compute the entire softmax row. We can likely look at the maxes and only
|
||||
// compute for the top-k values in the row. However, this kernel will likely
|
||||
// not be a bottle neck and it seems better to closer match torch and find the
|
||||
// argmax after computing the softmax.
|
||||
const float reciprocal_row_sum = 1.f / row_sum;
|
||||
|
||||
#pragma unroll
|
||||
for (int ii = 0; ii < VPT; ++ii) {
|
||||
row_chunk[ii] = row_chunk[ii] * reciprocal_row_sum;
|
||||
}
|
||||
|
||||
// Now, softmax_res contains the softmax of the row chunk. Now, I want to find
|
||||
// the topk elements in each row, along with the max index.
|
||||
int start_col = first_elt_read_by_thread;
|
||||
static constexpr int COLS_PER_GROUP_LDG = ELTS_PER_LDG * THREADS_PER_ROW;
|
||||
|
||||
for (int k_idx = 0; k_idx < k; ++k_idx) {
|
||||
// First, each thread does the local argmax
|
||||
float max_val = row_chunk[0];
|
||||
int expert = start_col;
|
||||
#pragma unroll
|
||||
for (int ldg = 0, col = start_col; ldg < LDG_PER_THREAD;
|
||||
++ldg, col += COLS_PER_GROUP_LDG) {
|
||||
#pragma unroll
|
||||
for (int ii = 0; ii < ELTS_PER_LDG; ++ii) {
|
||||
float val = row_chunk[ldg * ELTS_PER_LDG + ii];
|
||||
|
||||
// No check on the experts here since columns with the smallest index
|
||||
// are processed first and only updated if > (not >=)
|
||||
if (val > max_val) {
|
||||
max_val = val;
|
||||
expert = col + ii;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Now, we perform the argmax reduce. We use the butterfly pattern so threads
|
||||
// reach consensus about the max. This will be useful for K > 1 so that the
|
||||
// threads can agree on "who" had the max value. That thread can then blank out
|
||||
// their max with -inf and the warp can run more iterations...
|
||||
#pragma unroll
|
||||
for (int mask = THREADS_PER_ROW / 2; mask > 0; mask /= 2) {
|
||||
float other_max =
|
||||
__shfl_xor_sync(0xFFFFFFFF, max_val, mask, THREADS_PER_ROW);
|
||||
int other_expert =
|
||||
__shfl_xor_sync(0xFFFFFFFF, expert, mask, THREADS_PER_ROW);
|
||||
|
||||
// We want lower indices to "win" in every thread so we break ties this
|
||||
// way
|
||||
if (other_max > max_val ||
|
||||
(other_max == max_val && other_expert < expert)) {
|
||||
max_val = other_max;
|
||||
expert = other_expert;
|
||||
}
|
||||
}
|
||||
|
||||
// Write the max for this k iteration to global memory.
|
||||
if (thread_group_idx == 0) {
|
||||
// The lead thread from each sub-group will write out the final results to
|
||||
// global memory. (This will be a single) thread per row of the
|
||||
// input/output matrices.
|
||||
const int idx = k * thread_row + k_idx;
|
||||
output[idx] = T(max_val);
|
||||
indices[idx] = should_process_row ? expert : NUM_EXPERTS;
|
||||
source_rows[idx] = k_idx * num_rows + thread_row;
|
||||
}
|
||||
|
||||
// Finally, we clear the value in the thread with the current max if there
|
||||
// is another iteration to run.
|
||||
if (k_idx + 1 < k) {
|
||||
const int ldg_group_for_expert = expert / COLS_PER_GROUP_LDG;
|
||||
const int thread_to_clear_in_group =
|
||||
(expert / ELTS_PER_LDG) % THREADS_PER_ROW;
|
||||
|
||||
// Only the thread in the group which produced the max will reset the
|
||||
// "winning" value to -inf.
|
||||
if (thread_group_idx == thread_to_clear_in_group) {
|
||||
const int offset_for_expert = expert % ELTS_PER_LDG;
|
||||
// Safe to set to any negative value since row_chunk values must be
|
||||
// between 0 and 1.
|
||||
row_chunk[ldg_group_for_expert * ELTS_PER_LDG + offset_for_expert] =
|
||||
ComputeType(-10000.f);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
namespace detail {
|
||||
// Constructs some constants needed to partition the work across threads at
|
||||
// compile time.
|
||||
template <typename T, int EXPERTS, int BYTES_PER_LDG>
|
||||
struct TopkConstants {
|
||||
static constexpr int ELTS_PER_LDG = BYTES_PER_LDG / sizeof(T);
|
||||
static_assert(EXPERTS / (ELTS_PER_LDG * WARP_SIZE) == 0 ||
|
||||
EXPERTS % (ELTS_PER_LDG * WARP_SIZE) == 0,
|
||||
"");
|
||||
static constexpr int VECs_PER_THREAD =
|
||||
std::max(1, EXPERTS / (ELTS_PER_LDG * WARP_SIZE));
|
||||
static constexpr int VPT = VECs_PER_THREAD * ELTS_PER_LDG;
|
||||
static constexpr int THREADS_PER_ROW = EXPERTS / VPT;
|
||||
static constexpr int ROWS_PER_WARP = WARP_SIZE / THREADS_PER_ROW;
|
||||
};
|
||||
} // namespace detail
|
||||
|
||||
template <typename T, int EXPERTS, int WARPS_PER_TB>
|
||||
void topk_gating_softmax_launcher_helper(const T* input,
|
||||
T* output,
|
||||
int* indices,
|
||||
int* source_row,
|
||||
const int64_t num_rows,
|
||||
const int64_t num_experts,
|
||||
const int64_t k,
|
||||
cudaStream_t stream) {
|
||||
static constexpr uint64_t MAX_BYTES_PER_LDG = 16;
|
||||
static constexpr int BYTES_PER_LDG =
|
||||
std::min(MAX_BYTES_PER_LDG, sizeof(T) * EXPERTS);
|
||||
using Constants = detail::TopkConstants<T, EXPERTS, BYTES_PER_LDG>;
|
||||
static constexpr int VPT = Constants::VPT;
|
||||
static constexpr int ROWS_PER_WARP = Constants::ROWS_PER_WARP;
|
||||
const int num_warps = (num_rows + ROWS_PER_WARP - 1) / ROWS_PER_WARP;
|
||||
const int num_blocks = (num_warps + WARPS_PER_TB - 1) / WARPS_PER_TB;
|
||||
|
||||
dim3 block_dim(WARP_SIZE, WARPS_PER_TB);
|
||||
topk_gating_softmax<T, VPT, EXPERTS, WARPS_PER_TB, BYTES_PER_LDG>
|
||||
<<<num_blocks, block_dim, 0, stream>>>(
|
||||
input, output, num_rows, indices, source_row, k);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
void topk_gating_softmax_kernelLauncher(const T* input,
|
||||
T* output,
|
||||
T* softmax,
|
||||
int* indices,
|
||||
int* source_row,
|
||||
T* softmax_max_prob,
|
||||
const int64_t num_rows,
|
||||
const int64_t num_experts,
|
||||
const int64_t k,
|
||||
const bool group_moe,
|
||||
cudaStream_t stream,
|
||||
const bool topk_only_mode = false) {
|
||||
if (topk_only_mode) {
|
||||
static constexpr int TPB = 256;
|
||||
const auto config_topk = Get1DBlocksAnd2DGridsMoe(num_rows);
|
||||
moe_top_k<T, TPB><<<config_topk.block_per_grid, TPB, 0, stream>>>(
|
||||
input, output, indices, source_row, num_experts, k, num_rows);
|
||||
return;
|
||||
}
|
||||
static constexpr int WARPS_PER_TB = 4;
|
||||
|
||||
#define LAUNCH_TOPK_GATING_SOFTMAX_HELPER(N) \
|
||||
case N: { \
|
||||
topk_gating_softmax_launcher_helper<T, N, WARPS_PER_TB>( \
|
||||
input, output, indices, source_row, num_rows, num_experts, k, stream); \
|
||||
break; \
|
||||
}
|
||||
switch (num_experts) {
|
||||
LAUNCH_TOPK_GATING_SOFTMAX_HELPER(2)
|
||||
LAUNCH_TOPK_GATING_SOFTMAX_HELPER(4)
|
||||
LAUNCH_TOPK_GATING_SOFTMAX_HELPER(8)
|
||||
LAUNCH_TOPK_GATING_SOFTMAX_HELPER(16)
|
||||
LAUNCH_TOPK_GATING_SOFTMAX_HELPER(32)
|
||||
LAUNCH_TOPK_GATING_SOFTMAX_HELPER(64)
|
||||
LAUNCH_TOPK_GATING_SOFTMAX_HELPER(128)
|
||||
LAUNCH_TOPK_GATING_SOFTMAX_HELPER(256)
|
||||
|
||||
default: {
|
||||
static constexpr int TPB = 256;
|
||||
if (group_moe) {
|
||||
const int group_experts = num_experts / k;
|
||||
const int softmax_num_rows = num_rows * k;
|
||||
const auto config_softmax = Get1DBlocksAnd2DGridsMoe(softmax_num_rows);
|
||||
group_moe_softmax<T, TPB>
|
||||
<<<config_softmax.block_per_grid, TPB, 0, stream>>>(
|
||||
input,
|
||||
softmax,
|
||||
softmax_max_prob,
|
||||
group_experts,
|
||||
softmax_num_rows);
|
||||
const auto config_topk = Get1DBlocksAnd2DGridsMoe(num_rows);
|
||||
moe_top_k<T, TPB>
|
||||
<<<config_topk.block_per_grid, TPB, 0, stream>>>(softmax,
|
||||
output,
|
||||
indices,
|
||||
source_row,
|
||||
softmax_max_prob,
|
||||
num_experts,
|
||||
k,
|
||||
num_rows);
|
||||
} else {
|
||||
const auto config_topk = Get1DBlocksAnd2DGridsMoe(num_rows);
|
||||
moe_softmax<T, TPB><<<config_topk.block_per_grid, TPB, 0, stream>>>(
|
||||
input, softmax, num_experts, num_rows);
|
||||
moe_top_k<T, TPB>
|
||||
<<<config_topk.block_per_grid, TPB, 0, stream>>>(softmax,
|
||||
output,
|
||||
indices,
|
||||
source_row,
|
||||
num_experts,
|
||||
k,
|
||||
num_rows);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// ========================== Permutation things
|
||||
// =======================================
|
||||
|
||||
// Duplicated and permutes rows for MoE. In addition, reverse the permutation
|
||||
// map to help with finalizing routing.
|
||||
|
||||
// "expanded_x_row" simply means that the number of values is num_rows x k. It
|
||||
// is "expanded" since we will have to duplicate some rows in the input matrix
|
||||
// to match the dimensions. Duplicates will always get routed to separate
|
||||
// experts in the end.
|
||||
|
||||
// Note that the expanded_dest_row_to_expanded_source_row map referred to here
|
||||
// has indices in the range (0, k*rows_in_input - 1). However, it is set up so
|
||||
// that index 0, rows_in_input, 2*rows_in_input ... (k-1)*rows_in_input all map
|
||||
// to row 0 in the original matrix. Thus, to know where to read in the source
|
||||
// matrix, we simply take the modulus of the expanded index.
|
||||
|
||||
template <typename T, int VecSize>
|
||||
__global__ void initialize_moe_routing_kernel(
|
||||
const T* unpermuted_input,
|
||||
T* permuted_output,
|
||||
const int* expanded_dest_row_to_expanded_source_row,
|
||||
int* expanded_source_row_to_expanded_dest_row,
|
||||
const int64_t num_rows,
|
||||
const int64_t active_rows,
|
||||
const int64_t cols,
|
||||
const int64_t num_rows_k) {
|
||||
using LoadT = AlignedVector<T, VecSize>;
|
||||
LoadT src_vec;
|
||||
|
||||
// Reverse permutation map.
|
||||
// I do this so that later, we can use the source -> dest map to do the k-way
|
||||
// reduction and unpermuting. I need the reverse map for that reduction to
|
||||
// allow each threadblock to do 1 k-way reduce without atomics later in MoE. 1
|
||||
// thread block will be responsible for all k summations.
|
||||
const int expanded_dest_row = blockIdx.x + blockIdx.y * gridDim.x;
|
||||
if (expanded_dest_row >= num_rows_k) return;
|
||||
const int expanded_source_row =
|
||||
expanded_dest_row_to_expanded_source_row[expanded_dest_row];
|
||||
if (threadIdx.x == 0) {
|
||||
expanded_source_row_to_expanded_dest_row[expanded_source_row] =
|
||||
expanded_dest_row;
|
||||
}
|
||||
|
||||
if ((blockIdx.x + blockIdx.y * gridDim.x) < active_rows) {
|
||||
// Duplicate and permute rows
|
||||
const int source_row = expanded_source_row % num_rows;
|
||||
|
||||
const T* source_row_ptr = unpermuted_input + source_row * cols;
|
||||
T* dest_row_ptr = permuted_output + expanded_dest_row * cols;
|
||||
|
||||
for (int tid = threadIdx.x * VecSize; tid < cols;
|
||||
tid += blockDim.x * VecSize) {
|
||||
// dest_row_ptr[tid] = source_row_ptr[tid];
|
||||
Load<T, VecSize>(&source_row_ptr[tid], &src_vec);
|
||||
Store<T, VecSize>(src_vec, &dest_row_ptr[tid]);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
void initialize_moe_routing_kernelLauncher(
|
||||
const T* unpermuted_input,
|
||||
T* permuted_output,
|
||||
const int* expanded_dest_row_to_expanded_source_row,
|
||||
int* expanded_source_row_to_expanded_dest_row,
|
||||
const int64_t num_rows,
|
||||
const int64_t active_rows,
|
||||
const int64_t cols,
|
||||
const int64_t k,
|
||||
cudaStream_t stream) {
|
||||
const int threads = std::min(cols, int64_t(1024));
|
||||
constexpr int max_pack_size = 16 / sizeof(T);
|
||||
const auto config_initialize = Get1DBlocksAnd2DGridsMoe(num_rows * k);
|
||||
if (cols % max_pack_size == 0) {
|
||||
initialize_moe_routing_kernel<T, max_pack_size>
|
||||
<<<config_initialize.block_per_grid, threads, 0, stream>>>(
|
||||
unpermuted_input,
|
||||
permuted_output,
|
||||
expanded_dest_row_to_expanded_source_row,
|
||||
expanded_source_row_to_expanded_dest_row,
|
||||
num_rows,
|
||||
k * active_rows,
|
||||
cols,
|
||||
num_rows * k);
|
||||
} else {
|
||||
initialize_moe_routing_kernel<T, 1>
|
||||
<<<config_initialize.block_per_grid, threads, 0, stream>>>(
|
||||
unpermuted_input,
|
||||
permuted_output,
|
||||
expanded_dest_row_to_expanded_source_row,
|
||||
expanded_source_row_to_expanded_dest_row,
|
||||
num_rows,
|
||||
k * active_rows,
|
||||
cols,
|
||||
num_rows * k);
|
||||
}
|
||||
}
|
||||
|
||||
// ============================== Infer GEMM sizes
|
||||
// =================================
|
||||
__device__ inline int find_total_elts_leq_target(int* sorted_indices,
|
||||
const int64_t arr_length,
|
||||
const int64_t target) {
|
||||
int64_t low = 0, high = arr_length - 1, target_location = -1;
|
||||
while (low <= high) {
|
||||
int64_t mid = (low + high) / 2;
|
||||
|
||||
if (sorted_indices[mid] > target) {
|
||||
high = mid - 1;
|
||||
} else {
|
||||
low = mid + 1;
|
||||
target_location = mid;
|
||||
}
|
||||
}
|
||||
return target_location + 1;
|
||||
}
|
||||
|
||||
void compute_total_rows_before_expert(int* sorted_indices,
|
||||
const int64_t total_indices,
|
||||
const int64_t num_experts,
|
||||
int32_t* total_rows_before_expert,
|
||||
cudaStream_t stream);
|
||||
|
||||
// Final kernel to unpermute and scale
|
||||
// This kernel unpermutes the original data, does the k-way reduction and
|
||||
// performs the final skip connection.
|
||||
template <typename T, int RESIDUAL_NUM>
|
||||
__global__ void finalize_moe_routing_kernel(
|
||||
const T* expanded_permuted_rows,
|
||||
T* reduced_unpermuted_output,
|
||||
const T* bias,
|
||||
const float* scales,
|
||||
const int* expanded_source_row_to_expanded_dest_row,
|
||||
const int* expert_for_source_row,
|
||||
const int64_t cols,
|
||||
const int64_t k,
|
||||
const int64_t compute_bias,
|
||||
const bool norm_topk_prob,
|
||||
const float routed_scaling_factor,
|
||||
const int64_t num_rows) {
|
||||
const int original_row = blockIdx.x + blockIdx.y * gridDim.x;
|
||||
// const int original_row = blockIdx.x;
|
||||
// const int num_rows = gridDim.x;
|
||||
if (original_row >= num_rows) return;
|
||||
T* reduced_row_ptr = reduced_unpermuted_output + original_row * cols;
|
||||
|
||||
for (int tid = threadIdx.x; tid < cols; tid += blockDim.x) {
|
||||
T thread_output{0.f};
|
||||
float row_rescale{0.f};
|
||||
for (int k_idx = 0; k_idx < k; ++k_idx) {
|
||||
const int expanded_original_row = original_row + k_idx * num_rows;
|
||||
const int expanded_permuted_row =
|
||||
expanded_source_row_to_expanded_dest_row[expanded_original_row];
|
||||
|
||||
const int64_t k_offset = original_row * k + k_idx;
|
||||
const float row_scale = scales[k_offset];
|
||||
row_rescale = row_rescale + row_scale;
|
||||
|
||||
const T* expanded_permuted_rows_row_ptr =
|
||||
expanded_permuted_rows + expanded_permuted_row * cols;
|
||||
|
||||
const int expert_idx = expert_for_source_row[k_offset];
|
||||
const T* bias_ptr = bias ? bias + expert_idx * cols : nullptr;
|
||||
const T bias_value = bias_ptr ? bias_ptr[tid] : T{0.f};
|
||||
|
||||
thread_output =
|
||||
static_cast<float>(thread_output) +
|
||||
row_scale * static_cast<float>(
|
||||
expanded_permuted_rows_row_ptr[tid] +
|
||||
bias_value *
|
||||
static_cast<T>(static_cast<float>(compute_bias)));
|
||||
}
|
||||
|
||||
thread_output = static_cast<float>(thread_output) /
|
||||
(norm_topk_prob ? row_rescale : 1.0f) *
|
||||
routed_scaling_factor;
|
||||
reduced_row_ptr[tid] = thread_output;
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
void finalize_moe_routing_kernelLauncher(
|
||||
const T* expanded_permuted_rows,
|
||||
T* reduced_unpermuted_output,
|
||||
const T* bias,
|
||||
const float* scales,
|
||||
const int* expanded_source_row_to_expanded_dest_row,
|
||||
const int* expert_for_source_row,
|
||||
const int64_t num_rows,
|
||||
const int64_t cols,
|
||||
const int64_t k,
|
||||
const int64_t compute_bias,
|
||||
const bool norm_topk_prob,
|
||||
const float routed_scaling_factor,
|
||||
cudaStream_t stream) {
|
||||
const int threads = std::min(cols, int64_t(1024));
|
||||
const auto config_final = Get1DBlocksAnd2DGridsMoe(num_rows);
|
||||
|
||||
finalize_moe_routing_kernel<T, 1>
|
||||
<<<config_final.block_per_grid, threads, 0, stream>>>(
|
||||
expanded_permuted_rows,
|
||||
reduced_unpermuted_output,
|
||||
bias,
|
||||
scales,
|
||||
expanded_source_row_to_expanded_dest_row,
|
||||
expert_for_source_row,
|
||||
cols,
|
||||
k,
|
||||
compute_bias,
|
||||
norm_topk_prob,
|
||||
routed_scaling_factor,
|
||||
num_rows);
|
||||
}
|
||||
|
||||
// ========================= TopK Softmax specializations
|
||||
// ===========================
|
||||
template void topk_gating_softmax_kernelLauncher(const float*,
|
||||
float*,
|
||||
float*,
|
||||
int*,
|
||||
int*,
|
||||
float*,
|
||||
const int64_t,
|
||||
const int64_t,
|
||||
const int64_t,
|
||||
const bool,
|
||||
cudaStream_t,
|
||||
const bool);
|
||||
template void topk_gating_softmax_kernelLauncher(const half*,
|
||||
half*,
|
||||
half*,
|
||||
int*,
|
||||
int*,
|
||||
half*,
|
||||
const int64_t,
|
||||
const int64_t,
|
||||
const int64_t,
|
||||
const bool,
|
||||
cudaStream_t,
|
||||
const bool);
|
||||
#ifdef PADDLE_CUDA_BF16
|
||||
template void topk_gating_softmax_kernelLauncher(const __nv_bfloat16*,
|
||||
__nv_bfloat16*,
|
||||
__nv_bfloat16*,
|
||||
int*,
|
||||
int*,
|
||||
__nv_bfloat16*,
|
||||
const int64_t,
|
||||
const int64_t,
|
||||
const int64_t,
|
||||
const bool,
|
||||
cudaStream_t,
|
||||
const bool);
|
||||
#endif
|
||||
// ===================== Specializations for init routing
|
||||
// =========================
|
||||
template void initialize_moe_routing_kernelLauncher(const float*,
|
||||
float*,
|
||||
const int*,
|
||||
int*,
|
||||
const int64_t,
|
||||
const int64_t,
|
||||
const int64_t,
|
||||
const int64_t,
|
||||
cudaStream_t);
|
||||
template void initialize_moe_routing_kernelLauncher(const half*,
|
||||
half*,
|
||||
const int*,
|
||||
int*,
|
||||
const int64_t,
|
||||
const int64_t,
|
||||
const int64_t,
|
||||
const int64_t,
|
||||
cudaStream_t);
|
||||
#ifdef PADDLE_CUDA_BF16
|
||||
template void initialize_moe_routing_kernelLauncher(const __nv_bfloat16*,
|
||||
__nv_bfloat16*,
|
||||
const int*,
|
||||
int*,
|
||||
const int64_t,
|
||||
const int64_t,
|
||||
const int64_t,
|
||||
const int64_t,
|
||||
cudaStream_t);
|
||||
#endif
|
||||
// ==================== Specializations for final routing
|
||||
// ===================================
|
||||
template void finalize_moe_routing_kernelLauncher(const float*,
|
||||
float*,
|
||||
const float*,
|
||||
const float*,
|
||||
const int*,
|
||||
const int*,
|
||||
const int64_t,
|
||||
const int64_t,
|
||||
const int64_t,
|
||||
const int64_t,
|
||||
const bool,
|
||||
const float,
|
||||
cudaStream_t);
|
||||
template void finalize_moe_routing_kernelLauncher(const half*,
|
||||
half*,
|
||||
const half*,
|
||||
const float*,
|
||||
const int*,
|
||||
const int*,
|
||||
const int64_t,
|
||||
const int64_t,
|
||||
const int64_t,
|
||||
const int64_t,
|
||||
const bool,
|
||||
const float,
|
||||
cudaStream_t);
|
||||
#ifdef PADDLE_CUDA_BF16
|
||||
template void finalize_moe_routing_kernelLauncher(const __nv_bfloat16*,
|
||||
__nv_bfloat16*,
|
||||
const __nv_bfloat16*,
|
||||
const float*,
|
||||
const int*,
|
||||
const int*,
|
||||
const int64_t,
|
||||
const int64_t,
|
||||
const int64_t,
|
||||
const int64_t,
|
||||
const bool,
|
||||
const float,
|
||||
cudaStream_t);
|
||||
#endif
|
417
custom_ops/metax_ops/mc_fused_moe_helper.h
Normal file
417
custom_ops/metax_ops/mc_fused_moe_helper.h
Normal file
@@ -0,0 +1,417 @@
|
||||
// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
|
||||
#include "mctlass/numeric_conversion.h"
|
||||
#include "mctlassEx/mctlassEx.h"
|
||||
#include "fused_moe_helper.h"
|
||||
|
||||
|
||||
template <typename ElementA, typename ElementB, typename ElementC>
|
||||
void mc_grouped_gemm_basic_kernel(
|
||||
const ElementA* ptrA,
|
||||
mctlassExOrder_t majorA,
|
||||
const ElementB* ptrB,
|
||||
mctlassExOrder_t majorB,
|
||||
const ElementA* ptrScale,
|
||||
const ElementA* ptrBias,
|
||||
ElementC* ptrC,
|
||||
mctlassExOrder_t majorC,
|
||||
const int *ptrSegInd,
|
||||
int numExperts,
|
||||
int m, // expanded_active_expert_rows
|
||||
int n, // inter_dim
|
||||
int k, // hidden_size
|
||||
mcStream_t stream) {
|
||||
mctlassExHandle_t handle;
|
||||
mctlassExHandleCreate(&handle);
|
||||
|
||||
int* ptrMNumTilesInd;
|
||||
mcMallocAsync((void**)&ptrMNumTilesInd, sizeof(int) * numExperts, stream);
|
||||
|
||||
mctlassExMatrixLayout_t matLayoutA;
|
||||
mctlassExMatrixLayout_t matLayoutB;
|
||||
mctlassExMatrixLayout_t matLayoutC;
|
||||
|
||||
// mat A: (m, k)
|
||||
mctlassExMatrixLayoutCreate(&matLayoutA, mctlassExDataType::MCTLASS_EX_BF16, m, k, k);
|
||||
mctlassExMatrixLayoutSetAttribute(matLayoutA, mctlassExMatrixLayoutAttribute_t::MCTLASS_EX_MATRIX_LAYOUT_ORDER,
|
||||
&majorA, sizeof(mctlassExOrder_t));
|
||||
// mat B: (num_experts, n, k)
|
||||
mctlassExMatrixLayoutCreate(&matLayoutB, mctlassExDataType::MCTLASS_EX_INT8, k, n, k);
|
||||
mctlassExMatrixLayoutSetAttribute(matLayoutB, mctlassExMatrixLayoutAttribute_t::MCTLASS_EX_MATRIX_LAYOUT_ORDER,
|
||||
&majorB, sizeof(mctlassExOrder_t));
|
||||
mctlassExMatrixLayoutSetAttribute(matLayoutB, mctlassExMatrixLayoutAttribute_t::MCTLASS_EX_MATRIX_LAYOUT_BATCH_COUNT,
|
||||
&numExperts, sizeof(int));
|
||||
// mat C: (m, n)
|
||||
mctlassExMatrixLayoutCreate(&matLayoutC, mctlassExDataType::MCTLASS_EX_BF16, m, n, n);
|
||||
mctlassExMatrixLayoutSetAttribute(matLayoutC, mctlassExMatrixLayoutAttribute_t::MCTLASS_EX_MATRIX_LAYOUT_ORDER,
|
||||
&majorC, sizeof(mctlassExOrder_t));
|
||||
// bias: (num_experts, n)
|
||||
// scale: (num, n)
|
||||
|
||||
mctlassExDesc_t mctlass_desc;
|
||||
mctlassExCreateDesc(&mctlass_desc);
|
||||
mctlassExDataType input_type = mctlassExDataType::MCTLASS_EX_BF16;
|
||||
mctlassExDataType scale_type = mctlassExDataType::MCTLASS_EX_INT8;
|
||||
mctlassExDataType compute_type = mctlassExDataType::MCTLASS_EX_FP32;
|
||||
mctlassExEpilogueType epilogue_type = mctlassExEpilogueType::MCTLASS_EX_GEMM_DEFAULT;
|
||||
if (ptrBias) {
|
||||
epilogue_type = mctlassExEpilogueType::MCTLASS_EX_GEMM_BIAS_PERGROUP;
|
||||
}
|
||||
// set scale
|
||||
mctlassExDescSetAttribute(mctlass_desc, mctlassExDescAttributes_t::MCTLASS_EX_GEMM_DESC_B_SCALE_POINTER,
|
||||
&ptrScale, sizeof(ptrScale));
|
||||
mctlassExDescSetAttribute(mctlass_desc, mctlassExDescAttributes_t::MCTLASS_EX_GEMM_DESC_B_SCALE_TYPE,
|
||||
&scale_type, sizeof(mctlassExDataType));
|
||||
// set bias
|
||||
if (ptrBias) {
|
||||
mctlassExDescSetAttribute(mctlass_desc, mctlassExDescAttributes_t::MCTLASS_EX_GEMM_DESC_BIAS_POINTER,
|
||||
&ptrBias, sizeof(ptrBias));
|
||||
}
|
||||
// set coumpute type
|
||||
mctlassExDescSetAttribute(mctlass_desc, mctlassExDescAttributes_t::MCTLASS_EX_GEMM_DESC_COMPUTE_TYPE,
|
||||
&compute_type, sizeof(mctlassExDataType));
|
||||
// set epilogue type
|
||||
mctlassExDescSetAttribute(mctlass_desc, mctlassExDescAttributes_t::MCTLASS_EX_GEMM_DESC_EPILOGUE_TYPE,
|
||||
&epilogue_type, sizeof(mctlassExEpilogueType));
|
||||
|
||||
const mctlassExContiguousGroupedGemmAlgo_t algo = mctlassExContiguousGroupedGemmAlgo_t::MCTLASS_EX_CONTIGUOUS_GROUPED_ALGO_SEGPTR;
|
||||
int blocksizeM = mctlassExContiguousGroupedGemmGetBlocksizeM(handle, mctlass_desc, matLayoutA, matLayoutB, matLayoutC, &algo);
|
||||
mctlassExContiguousGroupedGemmComputeMNumTilesIndptr(handle, mctlass_desc, matLayoutA, matLayoutB, matLayoutC, &algo, ptrSegInd, ptrMNumTilesInd, numExperts, blocksizeM);
|
||||
|
||||
mctlassExContiguousGroupedGemmBasic(handle, mctlass_desc,
|
||||
ptrA, matLayoutA,
|
||||
ptrB, matLayoutB,
|
||||
ptrC, matLayoutC,
|
||||
ptrSegInd, nullptr, ptrMNumTilesInd,
|
||||
&algo, nullptr, 0, stream);
|
||||
|
||||
mctlassExHandleDestroy(handle);
|
||||
mctlassExMatrixLayoutDestroy(matLayoutA);
|
||||
mctlassExMatrixLayoutDestroy(matLayoutB);
|
||||
mctlassExMatrixLayoutDestroy(matLayoutC);
|
||||
mctlassExDestroyDesc(mctlass_desc);
|
||||
mcFreeAsync(ptrMNumTilesInd, stream);
|
||||
}
|
||||
|
||||
template<typename T, typename ElementA, typename ElementB, typename ElementC>
|
||||
class McMoeHelper {
|
||||
public:
|
||||
McMoeHelper(const std::string gemm_method): gemm_method_(gemm_method) {}
|
||||
|
||||
// -------- getWorkspaceSize -------- //
|
||||
template <typename KeyT>
|
||||
size_t getWorkspaceSize(const int64_t num_rows,
|
||||
const int64_t hidden_size,
|
||||
const int64_t inter_size,
|
||||
const int64_t num_experts,
|
||||
const int64_t k) {
|
||||
const size_t buf_size = AlignTo16(k * num_rows * hidden_size);
|
||||
const size_t interbuf_size = AlignTo16(k * num_rows * inter_size);
|
||||
const size_t padded_experts = AlignTo16(num_experts);
|
||||
const size_t num_moe_inputs = AlignTo16(k * num_rows);
|
||||
// softmax output, permuted_rows and permuted_experts have moved to outside
|
||||
// of moe kernel, allocate them in Encoder or Decoder before invoking
|
||||
// FfnLayer forward.
|
||||
size_t total_ws_bytes =
|
||||
5 * num_moe_inputs *
|
||||
sizeof(int); // source_rows_, permuted_rows_, permuted_experts_
|
||||
total_ws_bytes += buf_size * sizeof(KeyT); // permuted_data
|
||||
total_ws_bytes +=
|
||||
padded_experts * sizeof(int32_t); // Hold total_rows_before_expert_
|
||||
|
||||
const size_t bytes_for_fc1_result = interbuf_size * sizeof(KeyT);
|
||||
const size_t sorter_ws_size_bytes =
|
||||
AlignTo16(sorter_.getWorkspaceSize(num_rows));
|
||||
sorter_.update_num_experts(num_experts);
|
||||
|
||||
int64_t bytes_for_intermediate_and_sorting = bytes_for_fc1_result;
|
||||
if (sorter_ws_size_bytes > bytes_for_fc1_result) {
|
||||
int64_t remaining_bytes =
|
||||
AlignTo16(sorter_ws_size_bytes - bytes_for_fc1_result);
|
||||
bytes_for_intermediate_and_sorting += remaining_bytes;
|
||||
}
|
||||
|
||||
total_ws_bytes +=
|
||||
bytes_for_intermediate_and_sorting; // intermediate (fc1) output + cub
|
||||
// sorting workspace
|
||||
|
||||
int64_t num_softmax_outs = 0;
|
||||
const bool is_pow_2 =
|
||||
(num_experts != 0) && ((num_experts & (num_experts - 1)) == 0);
|
||||
if (!is_pow_2 || num_experts > 256) {
|
||||
num_softmax_outs = AlignTo16(num_rows * num_experts);
|
||||
}
|
||||
|
||||
total_ws_bytes += num_softmax_outs * sizeof(float);
|
||||
|
||||
return total_ws_bytes;
|
||||
}
|
||||
|
||||
void computeFFN(const paddle::Tensor *input,
|
||||
const paddle::Tensor *gate_weight,
|
||||
const paddle::Tensor *ffn1_weight,
|
||||
const paddle::Tensor *ffn1_scale,
|
||||
const paddle::Tensor *ffn1_bias,
|
||||
const paddle::Tensor *ffn2_weight,
|
||||
const paddle::Tensor *ffn2_scale,
|
||||
const paddle::Tensor *ffn2_bias,
|
||||
const paddle::Tensor *moe_token_type_ids,
|
||||
const int moe_topk,
|
||||
const bool group_moe,
|
||||
const bool norm_topk_prob,
|
||||
const float routed_scaling_factor,
|
||||
const std::string moe_type,
|
||||
paddle::Tensor *output) {
|
||||
auto *input_activations = input->data<T>();
|
||||
auto *gating_weights = gate_weight->data<float>();
|
||||
const T *fc1_expert_biases = ffn1_bias ? ffn1_bias->data<T>() : nullptr;
|
||||
const T *fc2_expert_biases = ffn2_bias ? ffn2_bias->data<T>() : nullptr;
|
||||
|
||||
auto *output_ = output->data<T>();
|
||||
auto stream = input->stream();
|
||||
auto place = input->place();
|
||||
auto input_type = input->dtype();
|
||||
|
||||
auto input_dims = input->dims();
|
||||
auto ffn1_dims = ffn1_weight->dims();
|
||||
int64_t token_num = 0;
|
||||
if (input_dims.size() == 3) {
|
||||
token_num = input_dims[0] * input_dims[1];
|
||||
} else {
|
||||
token_num = input_dims[0];
|
||||
}
|
||||
const int64_t num_rows = token_num;
|
||||
|
||||
const int64_t hidden_size = ffn1_dims[2];
|
||||
int64_t inter_dim = 0;
|
||||
if (moe_type == "qkv") {
|
||||
inter_dim = ffn1_dims[2] * ffn1_dims[3] * ffn1_dims[4];
|
||||
} else {
|
||||
inter_dim = ffn1_dims[1];
|
||||
}
|
||||
|
||||
// if (gemm_method == "weight_only_int4") {
|
||||
// inter_dim = inter_dim * 2;
|
||||
// }
|
||||
|
||||
const int64_t inter_size = inter_dim;
|
||||
const int64_t num_experts = ffn1_dims[0];
|
||||
const int64_t k = moe_topk;
|
||||
|
||||
|
||||
int64_t bytes =
|
||||
getWorkspaceSize<T>(num_rows, hidden_size, inter_size, num_experts, k);
|
||||
|
||||
// Pointers
|
||||
int *expert_for_source_row;
|
||||
int *source_rows_;
|
||||
int *permuted_rows_;
|
||||
int *permuted_experts_;
|
||||
int *expanded_source_row_to_expanded_dest_row;
|
||||
|
||||
T *permuted_data_;
|
||||
int32_t *total_rows_before_expert_;
|
||||
T *fc1_result_;
|
||||
float *softmax_out_;
|
||||
|
||||
paddle::Tensor ws_ptr_tensor =
|
||||
GetEmptyTensor({bytes}, paddle::DataType::INT8, place);
|
||||
int8_t *ws_ptr = ws_ptr_tensor.data<int8_t>();
|
||||
|
||||
const int64_t buf_size = AlignTo16(k * num_rows * hidden_size);
|
||||
const int64_t interbuf_size = AlignTo16(k * num_rows * inter_size);
|
||||
const int64_t padded_experts = AlignTo16(num_experts);
|
||||
const int64_t num_moe_inputs = AlignTo16(k * num_rows);
|
||||
|
||||
expert_for_source_row = reinterpret_cast<int *>(ws_ptr);
|
||||
source_rows_ = expert_for_source_row + num_moe_inputs;
|
||||
permuted_rows_ = source_rows_ + num_moe_inputs;
|
||||
permuted_experts_ = permuted_rows_ + num_moe_inputs;
|
||||
expanded_source_row_to_expanded_dest_row =
|
||||
permuted_experts_ + num_moe_inputs;
|
||||
permuted_data_ = reinterpret_cast<T *>(
|
||||
expanded_source_row_to_expanded_dest_row + num_moe_inputs);
|
||||
total_rows_before_expert_ =
|
||||
reinterpret_cast<int32_t *>(permuted_data_ + buf_size);
|
||||
fc1_result_ =
|
||||
reinterpret_cast<T *>(total_rows_before_expert_ + padded_experts);
|
||||
|
||||
const bool is_pow_2 =
|
||||
(num_experts != 0) && ((num_experts & (num_experts - 1)) == 0);
|
||||
if (!is_pow_2 || num_experts > 256) {
|
||||
softmax_out_ = reinterpret_cast<float *>(fc1_result_ + interbuf_size);
|
||||
} else {
|
||||
softmax_out_ = nullptr;
|
||||
}
|
||||
|
||||
paddle::Tensor expert_scales_float_tensor =
|
||||
GetEmptyTensor({num_rows, moe_topk}, paddle::DataType::FLOAT32, place);
|
||||
float *expert_scales_float = expert_scales_float_tensor.data<float>();
|
||||
|
||||
float *softmax_max_prob = nullptr;
|
||||
if (group_moe) {
|
||||
paddle::Tensor softmax_max_prob_tensor = GetEmptyTensor(
|
||||
{num_rows, moe_topk}, paddle::DataType::FLOAT32, place);
|
||||
// (TODO: check fill success ?)
|
||||
paddle::experimental::fill(softmax_max_prob_tensor, 0.f);
|
||||
softmax_max_prob = softmax_max_prob_tensor.data<float>();
|
||||
}
|
||||
|
||||
paddle::Tensor fc1_out_tensor =
|
||||
GetEmptyTensor({num_rows * k, inter_size}, input_type, place);
|
||||
T *fc1_out = fc1_out_tensor.data<T>();
|
||||
|
||||
auto input_cast_tensor =
|
||||
paddle::experimental::cast(*input, paddle::DataType::FLOAT32);
|
||||
auto gate_tensor =
|
||||
paddle::experimental::matmul(input_cast_tensor, *gate_weight);
|
||||
float *gating_output = gate_tensor.data<float>();
|
||||
|
||||
if (moe_token_type_ids) {
|
||||
auto *moe_token_type_ids_out = moe_token_type_ids->data<int>();
|
||||
moe_token_type_ids_kernelLauncher<float>(gating_output,
|
||||
moe_token_type_ids_out,
|
||||
num_rows,
|
||||
num_experts,
|
||||
k,
|
||||
stream);
|
||||
}
|
||||
|
||||
topk_gating_softmax_kernelLauncher<float>(gating_output,
|
||||
expert_scales_float,
|
||||
softmax_out_,
|
||||
expert_for_source_row,
|
||||
source_rows_,
|
||||
softmax_max_prob,
|
||||
num_rows,
|
||||
num_experts,
|
||||
k,
|
||||
group_moe,
|
||||
stream);
|
||||
|
||||
const int64_t sorter_ws_size_bytes =
|
||||
AlignTo16(sorter_.getWorkspaceSize(int64_t(k * num_rows)));
|
||||
|
||||
sorter_.run(fc1_result_,
|
||||
sorter_ws_size_bytes,
|
||||
expert_for_source_row,
|
||||
permuted_experts_,
|
||||
source_rows_,
|
||||
permuted_rows_,
|
||||
k * num_rows,
|
||||
false,
|
||||
stream);
|
||||
|
||||
initialize_moe_routing_kernelLauncher(
|
||||
input_activations,
|
||||
permuted_data_,
|
||||
permuted_rows_,
|
||||
expanded_source_row_to_expanded_dest_row,
|
||||
num_rows,
|
||||
num_rows,
|
||||
hidden_size,
|
||||
k,
|
||||
stream);
|
||||
|
||||
const int64_t expanded_active_expert_rows = k * num_rows;
|
||||
|
||||
compute_total_rows_before_expert(permuted_experts_,
|
||||
expanded_active_expert_rows,
|
||||
num_experts,
|
||||
total_rows_before_expert_,
|
||||
stream);
|
||||
|
||||
mctlassExOrder_t row_major = mctlassExOrder_t::MCTLASS_EX_ROWMAJOR_ORDER;
|
||||
mctlassExOrder_t column_major = mctlassExOrder_t::MCTLASS_EX_COLUMNMAJOR_ORDER;
|
||||
|
||||
mc_grouped_gemm_basic_kernel<ElementA, ElementB, ElementC>(
|
||||
reinterpret_cast<const ElementA *>(permuted_data_),
|
||||
row_major,
|
||||
reinterpret_cast<const ElementB *>(ffn1_weight->data<ElementB>()),
|
||||
column_major,
|
||||
reinterpret_cast<const ElementA *>(ffn1_scale->data<T>()),
|
||||
reinterpret_cast<const ElementA *>(fc1_expert_biases),
|
||||
reinterpret_cast<ElementC *>(fc1_out),
|
||||
row_major,
|
||||
total_rows_before_expert_,
|
||||
num_experts,
|
||||
expanded_active_expert_rows,
|
||||
inter_size,
|
||||
hidden_size,
|
||||
stream);
|
||||
|
||||
if (moe_type == "ffn") {
|
||||
auto act_out_tensor =
|
||||
paddle::experimental::swiglu(fc1_out_tensor, nullptr);
|
||||
auto act_out = act_out_tensor.data<T>();
|
||||
|
||||
paddle::Tensor fc2_output_tensor =
|
||||
GetEmptyTensor({k * num_rows, hidden_size}, input_type, place);
|
||||
T *fc2_result = fc2_output_tensor.data<T>();
|
||||
|
||||
mc_grouped_gemm_basic_kernel<ElementA, ElementB, ElementC>(
|
||||
reinterpret_cast<const ElementA *>(act_out),
|
||||
row_major,
|
||||
reinterpret_cast<const ElementB *>(ffn2_weight->data<ElementB>()),
|
||||
column_major,
|
||||
reinterpret_cast<const ElementA *>(ffn2_scale->data<T>()),
|
||||
nullptr,
|
||||
reinterpret_cast<ElementC *>(fc2_result),
|
||||
row_major,
|
||||
total_rows_before_expert_,
|
||||
num_experts,
|
||||
expanded_active_expert_rows,
|
||||
hidden_size,
|
||||
inter_size / 2,
|
||||
stream);
|
||||
|
||||
finalize_moe_routing_kernelLauncher(
|
||||
fc2_result,
|
||||
output_,
|
||||
fc2_expert_biases,
|
||||
reinterpret_cast<float *>(expert_scales_float),
|
||||
expanded_source_row_to_expanded_dest_row,
|
||||
expert_for_source_row,
|
||||
num_rows,
|
||||
hidden_size,
|
||||
k,
|
||||
static_cast<int>(1),
|
||||
norm_topk_prob,
|
||||
routed_scaling_factor,
|
||||
stream);
|
||||
} else {
|
||||
finalize_moe_routing_kernelLauncher(
|
||||
// fc2_result,
|
||||
fc1_out,
|
||||
output_,
|
||||
fc1_expert_biases, // fc2_expert_biases,
|
||||
reinterpret_cast<float *>(expert_scales_float),
|
||||
expanded_source_row_to_expanded_dest_row,
|
||||
expert_for_source_row,
|
||||
num_rows,
|
||||
inter_size,
|
||||
k,
|
||||
static_cast<int>(0),
|
||||
norm_topk_prob,
|
||||
routed_scaling_factor,
|
||||
stream);
|
||||
}
|
||||
}
|
||||
|
||||
private:
|
||||
std::string gemm_method_;
|
||||
CubKeyValueSorter sorter_;
|
||||
};
|
274
custom_ops/metax_ops/moe_dispatch.cu
Normal file
274
custom_ops/metax_ops/moe_dispatch.cu
Normal file
@@ -0,0 +1,274 @@
|
||||
// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
|
||||
#pragma GCC diagnostic push
|
||||
#pragma GCC diagnostic ignored "-Wstrict-aliasing"
|
||||
#pragma GCC diagnostic ignored "-Wunused-function"
|
||||
#pragma once
|
||||
|
||||
#include "fused_moe_helper.h"
|
||||
#include "fused_moe_op.h"
|
||||
#pragma GCC diagnostic pop
|
||||
|
||||
#include "helper.h"
|
||||
|
||||
|
||||
template <paddle::DataType T>
|
||||
void MoeDispatchKernel(const paddle::Tensor& input,
|
||||
const paddle::Tensor& gating_output,
|
||||
const int moe_topk,
|
||||
const bool group_moe,
|
||||
const bool topk_only_mode,
|
||||
const int num_rows,
|
||||
const int hidden_size,
|
||||
const int expert_num,
|
||||
paddle::Tensor* permute_input,
|
||||
paddle::Tensor* tokens_expert_prefix_sum,
|
||||
paddle::Tensor* permute_indices_per_token,
|
||||
paddle::Tensor* top_k_weight,
|
||||
paddle::Tensor* top_k_indices) {
|
||||
typedef PDTraits<T> traits_;
|
||||
typedef typename traits_::DataType DataType_;
|
||||
typedef typename traits_::data_t data_t;
|
||||
|
||||
auto stream = input.stream();
|
||||
auto place = input.place();
|
||||
|
||||
if (group_moe) {
|
||||
// Check if expert_num is divisible by moe_topk, else throw an error
|
||||
PADDLE_ENFORCE_EQ(expert_num % moe_topk,
|
||||
0,
|
||||
common::errors::InvalidArgument(
|
||||
"The number of experts (expert_num) "
|
||||
"must be divisible by moe_topk. "
|
||||
"Got expert_num = %d and moe_topk = %d.",
|
||||
expert_num,
|
||||
moe_topk));
|
||||
}
|
||||
|
||||
const int num_moe_inputs = AlignTo16(num_rows * moe_topk);
|
||||
const int bytes = num_moe_inputs * sizeof(int);
|
||||
|
||||
CubKeyValueSorter sorter_;
|
||||
sorter_.update_num_experts(expert_num);
|
||||
|
||||
const int sorter_ws_size_bytes =
|
||||
AlignTo16(sorter_.getWorkspaceSize(moe_topk * num_rows));
|
||||
const int sort_tmp_in_out_size = num_moe_inputs * 2 * sizeof(int);
|
||||
|
||||
paddle::Tensor ws_ptr_tensor =
|
||||
GetEmptyTensor({bytes + sorter_ws_size_bytes + sort_tmp_in_out_size},
|
||||
paddle::DataType::INT8,
|
||||
place);
|
||||
|
||||
int8_t* ws_ptr = ws_ptr_tensor.data<int8_t>();
|
||||
int* source_rows_ = reinterpret_cast<int*>(ws_ptr);
|
||||
int8_t* sorter_ws_ptr = reinterpret_cast<int8_t*>(ws_ptr + bytes);
|
||||
int* permuted_experts_ =
|
||||
reinterpret_cast<int*>(sorter_ws_ptr + sorter_ws_size_bytes);
|
||||
int* permuted_rows_ = permuted_experts_ + num_moe_inputs;
|
||||
|
||||
int* expert_for_source_row = top_k_indices->data<int>();
|
||||
|
||||
float* softmax_max_prob = nullptr;
|
||||
if (group_moe) {
|
||||
paddle::Tensor softmax_max_prob_tensor =
|
||||
GetEmptyTensor({num_rows, moe_topk}, paddle::DataType::FLOAT32, place);
|
||||
paddle::experimental::fill(softmax_max_prob_tensor, 0.f);
|
||||
softmax_max_prob = softmax_max_prob_tensor.data<float>();
|
||||
}
|
||||
|
||||
float* softmax_out_;
|
||||
|
||||
const bool is_pow_2 =
|
||||
(expert_num != 0) && ((expert_num & (expert_num - 1)) == 0);
|
||||
|
||||
paddle::Tensor softmax_buffer;
|
||||
|
||||
if (!is_pow_2 || expert_num > 256 || group_moe) {
|
||||
softmax_buffer = GetEmptyTensor(
|
||||
{num_rows * expert_num}, paddle::DataType::FLOAT32, place);
|
||||
softmax_out_ = softmax_buffer.data<float>();
|
||||
} else {
|
||||
softmax_out_ = nullptr;
|
||||
}
|
||||
|
||||
topk_gating_softmax_kernelLauncher<float>(gating_output.data<float>(),
|
||||
top_k_weight->data<float>(),
|
||||
softmax_out_,
|
||||
expert_for_source_row,
|
||||
source_rows_,
|
||||
softmax_max_prob,
|
||||
num_rows,
|
||||
expert_num,
|
||||
moe_topk,
|
||||
group_moe,
|
||||
stream,
|
||||
topk_only_mode);
|
||||
|
||||
sorter_.run(reinterpret_cast<void*>(sorter_ws_ptr),
|
||||
sorter_ws_size_bytes,
|
||||
expert_for_source_row,
|
||||
permuted_experts_,
|
||||
source_rows_,
|
||||
permuted_rows_,
|
||||
moe_topk * num_rows,
|
||||
false,
|
||||
stream);
|
||||
|
||||
|
||||
initialize_moe_routing_kernelLauncher(
|
||||
input.data<data_t>(),
|
||||
permute_input->data<data_t>(),
|
||||
permuted_rows_,
|
||||
permute_indices_per_token->data<int32_t>(),
|
||||
num_rows,
|
||||
num_rows,
|
||||
hidden_size,
|
||||
moe_topk,
|
||||
stream);
|
||||
|
||||
|
||||
compute_total_rows_before_expert(
|
||||
permuted_experts_,
|
||||
moe_topk * num_rows,
|
||||
expert_num,
|
||||
tokens_expert_prefix_sum->data<int32_t>(),
|
||||
stream);
|
||||
}
|
||||
|
||||
|
||||
std::vector<paddle::Tensor> MoeExpertDispatch(
|
||||
const paddle::Tensor& input,
|
||||
const paddle::Tensor& gating_output,
|
||||
const int moe_topk,
|
||||
const bool group_moe,
|
||||
const bool topk_only_mode) {
|
||||
const auto input_type = input.dtype();
|
||||
auto place = input.place();
|
||||
int token_rows = 0;
|
||||
auto input_dims = input.dims();
|
||||
auto gating_dims = gating_output.dims();
|
||||
const int expert_num = gating_dims[gating_dims.size() - 1];
|
||||
|
||||
if (input_dims.size() == 3) {
|
||||
token_rows = input_dims[0] * input_dims[1];
|
||||
} else {
|
||||
token_rows = input_dims[0];
|
||||
}
|
||||
const int num_rows = token_rows;
|
||||
const int hidden_size = input.dims()[input_dims.size() - 1];
|
||||
|
||||
auto permute_input =
|
||||
GetEmptyTensor({moe_topk * num_rows, hidden_size}, input_type, place);
|
||||
// correspond to the weighted coefficients of the results from each expert.
|
||||
auto top_k_weight =
|
||||
GetEmptyTensor({num_rows, moe_topk}, paddle::DataType::FLOAT32, place);
|
||||
auto top_k_indices =
|
||||
GetEmptyTensor({num_rows, moe_topk}, paddle::DataType::INT32, place);
|
||||
|
||||
auto tokens_expert_prefix_sum =
|
||||
GetEmptyTensor({expert_num}, paddle::DataType::INT32, place);
|
||||
auto permute_indices_per_token =
|
||||
GetEmptyTensor({moe_topk, num_rows}, paddle::DataType::INT32, place);
|
||||
|
||||
|
||||
switch (input_type) {
|
||||
case paddle::DataType::BFLOAT16:
|
||||
MoeDispatchKernel<paddle::DataType::BFLOAT16>(input,
|
||||
gating_output,
|
||||
moe_topk,
|
||||
group_moe,
|
||||
topk_only_mode,
|
||||
num_rows,
|
||||
hidden_size,
|
||||
expert_num,
|
||||
&permute_input,
|
||||
&tokens_expert_prefix_sum,
|
||||
&permute_indices_per_token,
|
||||
&top_k_weight,
|
||||
&top_k_indices);
|
||||
break;
|
||||
// case paddle::DataType::FLOAT16:
|
||||
// MoeDispatchKernel<paddle::DataType::FLOAT16>(input,
|
||||
// gating_output,
|
||||
// moe_topk,
|
||||
// group_moe,
|
||||
// topk_only_mode,
|
||||
// num_rows,
|
||||
// hidden_size,
|
||||
// expert_num,
|
||||
// &permute_input,
|
||||
// &tokens_expert_prefix_sum,
|
||||
// &permute_indices_per_token,
|
||||
// &top_k_weight,
|
||||
// &top_k_indices);
|
||||
// break;
|
||||
default:
|
||||
PD_THROW("Only support bf16 for MoeDispatchKernel");
|
||||
}
|
||||
return {permute_input,
|
||||
tokens_expert_prefix_sum,
|
||||
permute_indices_per_token,
|
||||
top_k_weight,
|
||||
top_k_indices};
|
||||
}
|
||||
|
||||
|
||||
std::vector<std::vector<int64_t>> MoeExpertDispatchInferShape(
|
||||
const std::vector<int64_t>& input_shape,
|
||||
const std::vector<int64_t>& gating_output_shape,
|
||||
const int moe_topk) {
|
||||
int token_rows = -1;
|
||||
|
||||
if (input_shape.size() == 3) {
|
||||
token_rows = input_shape[0] * input_shape[1];
|
||||
} else {
|
||||
token_rows = input_shape[0];
|
||||
}
|
||||
const int expert_num = gating_output_shape[gating_output_shape.size() - 1];
|
||||
const int num_rows = token_rows;
|
||||
const int hidden_size = input_shape[input_shape.size() - 1];
|
||||
|
||||
return {{moe_topk * num_rows, hidden_size},
|
||||
{expert_num},
|
||||
{moe_topk, num_rows},
|
||||
{num_rows, moe_topk},
|
||||
{num_rows, moe_topk}};
|
||||
}
|
||||
|
||||
std::vector<paddle::DataType> MoeExpertDispatchInferDtype(
|
||||
const paddle::DataType& input_dtype,
|
||||
const paddle::DataType& gating_output_dtype,
|
||||
const int moe_topk) {
|
||||
return {input_dtype,
|
||||
paddle::DataType::INT64,
|
||||
paddle::DataType::INT32,
|
||||
paddle::DataType::FLOAT32,
|
||||
paddle::DataType::INT32};
|
||||
}
|
||||
|
||||
|
||||
PD_BUILD_OP(moe_expert_dispatch)
|
||||
.Inputs({"input", "gating_output"})
|
||||
.Outputs({"permute_input",
|
||||
"tokens_expert_prefix_sum",
|
||||
"permute_indices_per_token",
|
||||
"top_k_weight",
|
||||
"top_k_indices"})
|
||||
.Attrs({"moe_topk:int", "group_moe:bool", "topk_only_mode:bool"})
|
||||
.SetKernelFn(PD_KERNEL(MoeExpertDispatch))
|
||||
.SetInferShapeFn(PD_INFER_SHAPE(MoeExpertDispatchInferShape))
|
||||
.SetInferDtypeFn(PD_INFER_DTYPE(MoeExpertDispatchInferDtype));
|
173
custom_ops/metax_ops/moe_ffn.cu
Normal file
173
custom_ops/metax_ops/moe_ffn.cu
Normal file
@@ -0,0 +1,173 @@
|
||||
// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
|
||||
#pragma once
|
||||
#include "mc_fused_moe_helper.h"
|
||||
#include "helper.h"
|
||||
|
||||
template <paddle::DataType T, typename ElementA, typename ElementB, typename ElementC>
|
||||
void McMoeFFNKernel(const paddle::Tensor& permute_input,
|
||||
const paddle::Tensor& tokens_expert_prefix_sum,
|
||||
const paddle::Tensor& ffn1_weight,
|
||||
const paddle::Tensor& ffn2_weight,
|
||||
const paddle::optional<paddle::Tensor>& ffn1_bias,
|
||||
const paddle::optional<paddle::Tensor>& ffn1_scale,
|
||||
const paddle::optional<paddle::Tensor>& ffn2_scale,
|
||||
const std::string& quant_method,
|
||||
paddle::Tensor ffn_out) {
|
||||
typedef PDTraits<T> traits_;
|
||||
typedef typename traits_::DataType DataType_;
|
||||
typedef typename traits_::data_t data_t;
|
||||
|
||||
auto ffn_out_ptr = ffn_out.data<data_t>();
|
||||
auto permuted_input_ptr = permute_input.data<data_t>();
|
||||
auto place = permute_input.place();
|
||||
auto input_type = permute_input.dtype();
|
||||
auto stream = permute_input.stream();
|
||||
|
||||
const int expanded_active_expert_rows = permute_input.dims()[0]; // permute_input.dims(): m, k
|
||||
const int num_experts = ffn1_weight.dims()[0]; // batchsize
|
||||
const int hidden_size = ffn1_weight.dims()[2]; // n
|
||||
int inter_dim = ffn1_weight.dims()[1]; // k
|
||||
|
||||
const int64_t inter_size = inter_dim; // since weight_only_int_8
|
||||
paddle::Tensor fc1_out_tensor = GetEmptyTensor(
|
||||
{expanded_active_expert_rows, inter_size}, input_type, place);
|
||||
auto fc1_out_ptr = fc1_out_tensor.data<data_t>();
|
||||
|
||||
mctlassExOrder_t row_major = mctlassExOrder_t::MCTLASS_EX_ROWMAJOR_ORDER;
|
||||
mctlassExOrder_t column_major = mctlassExOrder_t::MCTLASS_EX_COLUMNMAJOR_ORDER;
|
||||
|
||||
// ffn1
|
||||
auto fc1_expert_biases =
|
||||
ffn1_bias
|
||||
? const_cast<paddle::Tensor*>(ffn1_bias.get_ptr())->data<data_t>()
|
||||
: nullptr;
|
||||
auto fc1_expert_scales = const_cast<paddle::Tensor*>(ffn1_scale.get_ptr())->data<data_t>();
|
||||
mc_grouped_gemm_basic_kernel<ElementA, ElementB, ElementC>(
|
||||
reinterpret_cast<const ElementA *>(permuted_input_ptr),
|
||||
row_major,
|
||||
reinterpret_cast<const ElementB *>(ffn1_weight.data<ElementB>()),
|
||||
column_major,
|
||||
reinterpret_cast<const ElementA *>(fc1_expert_scales),
|
||||
reinterpret_cast<const ElementA *>(fc1_expert_biases),
|
||||
reinterpret_cast<ElementC *>(fc1_out_ptr),
|
||||
row_major,
|
||||
tokens_expert_prefix_sum.data<int>(),
|
||||
num_experts,
|
||||
expanded_active_expert_rows,
|
||||
inter_dim,
|
||||
hidden_size,
|
||||
stream);
|
||||
|
||||
// swiglu
|
||||
auto act_out_tensor = paddle::experimental::swiglu(fc1_out_tensor, nullptr);
|
||||
auto act_out = act_out_tensor.data<data_t>();
|
||||
|
||||
auto fc2_expert_scales = const_cast<paddle::Tensor*>(ffn2_scale.get_ptr())->data<data_t>();
|
||||
mc_grouped_gemm_basic_kernel<ElementA, ElementB, ElementC>(
|
||||
reinterpret_cast<const ElementA *>(act_out),
|
||||
row_major,
|
||||
reinterpret_cast<const ElementB *>(ffn2_weight.data<ElementB>()),
|
||||
column_major,
|
||||
reinterpret_cast<const ElementA *>(fc2_expert_scales),
|
||||
nullptr,
|
||||
reinterpret_cast<ElementC *>(ffn_out_ptr),
|
||||
row_major,
|
||||
tokens_expert_prefix_sum.data<int>(),
|
||||
num_experts,
|
||||
expanded_active_expert_rows,
|
||||
hidden_size,
|
||||
inter_dim / 2,
|
||||
stream);
|
||||
}
|
||||
|
||||
std::vector<paddle::Tensor> MoeExpertFFN(
|
||||
const paddle::Tensor& permute_input,
|
||||
const paddle::Tensor& tokens_expert_prefix_sum,
|
||||
const paddle::Tensor& ffn1_weight,
|
||||
const paddle::Tensor& ffn2_weight,
|
||||
const paddle::optional<paddle::Tensor>& ffn1_bias,
|
||||
const paddle::optional<paddle::Tensor>& ffn1_scale,
|
||||
const paddle::optional<paddle::Tensor>& ffn2_scale,
|
||||
const std::string& quant_method) {
|
||||
assert(quant_method == "weight_only_int8");
|
||||
const auto input_type = permute_input.dtype();
|
||||
auto ffn_out = paddle::empty_like(permute_input);
|
||||
|
||||
switch (input_type) {
|
||||
case paddle::DataType::BFLOAT16:
|
||||
McMoeFFNKernel<paddle::DataType::BFLOAT16, maca_bfloat16, int8_t, maca_bfloat16>(permute_input,
|
||||
tokens_expert_prefix_sum,
|
||||
ffn1_weight,
|
||||
ffn2_weight,
|
||||
ffn1_bias,
|
||||
ffn1_scale,
|
||||
ffn2_scale,
|
||||
quant_method,
|
||||
ffn_out);
|
||||
break;
|
||||
// case paddle::DataType::FLOAT16:
|
||||
// MoeFFNKernel<paddle::DataType::FLOAT16>(permute_input,
|
||||
// tokens_expert_prefix_sum,
|
||||
// ffn1_weight,
|
||||
// ffn2_weight,
|
||||
// ffn1_bias,
|
||||
// ffn1_scale,
|
||||
// ffn2_scale,
|
||||
// quant_method,
|
||||
// ffn_out);
|
||||
// break;
|
||||
default:
|
||||
PD_THROW("Only support bf16 for MoeExpertFFN");
|
||||
}
|
||||
return {ffn_out};
|
||||
}
|
||||
|
||||
std::vector<std::vector<int64_t>> MoeExpertFFNInferShape(
|
||||
const std::vector<int64_t>& permute_input_shape,
|
||||
const std::vector<int64_t>& tokens_expert_prefix_sum_shape,
|
||||
const std::vector<int64_t>& ffn1_weight_shape,
|
||||
const std::vector<int64_t>& ffn2_weight_shape,
|
||||
const paddle::optional<std::vector<int64_t>>& ffn1_bias_shape,
|
||||
const paddle::optional<std::vector<int64_t>>& ffn1_scale_shape,
|
||||
const paddle::optional<std::vector<int64_t>>& ffn2_scale_shape) {
|
||||
return {permute_input_shape};
|
||||
}
|
||||
|
||||
std::vector<paddle::DataType> MoeExpertFFNInferDtype(
|
||||
const paddle::DataType& permute_input_dtype,
|
||||
const paddle::DataType& tokens_expert_prefix_sum_dtype,
|
||||
const paddle::DataType& ffn1_weight_dtype,
|
||||
const paddle::DataType& ffn2_weight_dtype,
|
||||
const paddle::optional<paddle::DataType>& ffn1_bias_dtype,
|
||||
const paddle::optional<paddle::DataType>& ffn1_scale_dtype,
|
||||
const paddle::optional<paddle::DataType>& ffn2_scale_dtype) {
|
||||
return {permute_input_dtype};
|
||||
}
|
||||
|
||||
PD_BUILD_OP(moe_expert_ffn)
|
||||
.Inputs({"permute_input",
|
||||
"tokens_expert_prefix_sum",
|
||||
"ffn1_weight",
|
||||
"ffn2_weight",
|
||||
paddle::Optional("ffn1_bias"),
|
||||
paddle::Optional("ffn1_scale"),
|
||||
paddle::Optional("ffn2_scale")})
|
||||
.Outputs({"output_tensor"})
|
||||
.Attrs({"quant_method:std::string"})
|
||||
.SetKernelFn(PD_KERNEL(MoeExpertFFN))
|
||||
.SetInferShapeFn(PD_INFER_SHAPE(MoeExpertFFNInferShape))
|
||||
.SetInferDtypeFn(PD_INFER_DTYPE(MoeExpertFFNInferDtype));
|
143
custom_ops/metax_ops/moe_reduce.cu
Normal file
143
custom_ops/metax_ops/moe_reduce.cu
Normal file
@@ -0,0 +1,143 @@
|
||||
// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "helper.h"
|
||||
#include "fused_moe_helper.h"
|
||||
#include "fused_moe_op.h"
|
||||
|
||||
template <paddle::DataType T>
|
||||
void MoeReduceKernel(const paddle::Tensor& ffn_out,
|
||||
const paddle::Tensor& top_k_weight,
|
||||
const paddle::Tensor& permute_indices_per_token,
|
||||
const paddle::Tensor& top_k_indices,
|
||||
const paddle::optional<paddle::Tensor>& ffn2_bias,
|
||||
const bool norm_topk_prob,
|
||||
const float routed_scaling_factor,
|
||||
const int num_rows,
|
||||
const int hidden_size,
|
||||
const int topk,
|
||||
paddle::Tensor* output) {
|
||||
typedef PDTraits<T> traits_;
|
||||
typedef typename traits_::DataType DataType_;
|
||||
typedef typename traits_::data_t data_t;
|
||||
auto stream = ffn_out.stream();
|
||||
|
||||
finalize_moe_routing_kernelLauncher(
|
||||
ffn_out.data<data_t>(),
|
||||
output->data<data_t>(),
|
||||
ffn2_bias ? ffn2_bias->data<data_t>() : nullptr,
|
||||
top_k_weight.data<float>(),
|
||||
permute_indices_per_token.data<int32_t>(),
|
||||
top_k_indices.data<int>(),
|
||||
num_rows,
|
||||
hidden_size,
|
||||
topk,
|
||||
static_cast<int>(1),
|
||||
norm_topk_prob,
|
||||
routed_scaling_factor,
|
||||
stream);
|
||||
}
|
||||
|
||||
|
||||
std::vector<paddle::Tensor> MoeExpertReduce(
|
||||
const paddle::Tensor& ffn_out,
|
||||
const paddle::Tensor& top_k_weight,
|
||||
const paddle::Tensor& permute_indices_per_token,
|
||||
const paddle::Tensor& top_k_indices,
|
||||
const paddle::optional<paddle::Tensor>& ffn2_bias,
|
||||
const bool norm_topk_prob,
|
||||
const float routed_scaling_factor) {
|
||||
const auto input_type = ffn_out.dtype();
|
||||
auto place = ffn_out.place();
|
||||
|
||||
const int topk = top_k_indices.dims()[1];
|
||||
const int num_rows = ffn_out.dims()[0] / topk;
|
||||
const int hidden_size = ffn_out.dims()[1];
|
||||
|
||||
auto output = GetEmptyTensor({num_rows, hidden_size}, input_type, place);
|
||||
|
||||
// Avoids ‘invalid configuration argument’ when we launch the kernel.
|
||||
if (ffn_out.dims()[0] == 0) return {output};
|
||||
|
||||
switch (input_type) {
|
||||
case paddle::DataType::BFLOAT16:
|
||||
MoeReduceKernel<paddle::DataType::BFLOAT16>(ffn_out,
|
||||
top_k_weight,
|
||||
permute_indices_per_token,
|
||||
top_k_indices,
|
||||
ffn2_bias,
|
||||
norm_topk_prob,
|
||||
routed_scaling_factor,
|
||||
num_rows,
|
||||
hidden_size,
|
||||
topk,
|
||||
&output);
|
||||
break;
|
||||
// case paddle::DataType::FLOAT16:
|
||||
// MoeReduceKernel<paddle::DataType::FLOAT16>(ffn_out,
|
||||
// top_k_weight,
|
||||
// permute_indices_per_token,
|
||||
// top_k_indices,
|
||||
// ffn2_bias,
|
||||
// norm_topk_prob,
|
||||
// routed_scaling_factor,
|
||||
// num_rows,
|
||||
// hidden_size,
|
||||
// topk,
|
||||
// &output);
|
||||
// break;
|
||||
default:
|
||||
PD_THROW("Only support bf16 for MoeDispatchKernel");
|
||||
}
|
||||
return {output};
|
||||
}
|
||||
|
||||
|
||||
std::vector<std::vector<int64_t>> MoeExpertReduceInferShape(
|
||||
const std::vector<int64_t>& ffn_out_shape,
|
||||
const std::vector<int64_t>& top_k_weight_shape,
|
||||
const std::vector<int64_t>& permute_indices_per_token_shape,
|
||||
const std::vector<int64_t>& top_k_indices_shape,
|
||||
const paddle::optional<std::vector<int64_t>>& ffn2_bias_shape) {
|
||||
const int topk = top_k_indices_shape[1];
|
||||
std::vector<int64_t> fused_moe_out_shape = {ffn_out_shape[0] / topk,
|
||||
ffn_out_shape[1]};
|
||||
|
||||
return {fused_moe_out_shape};
|
||||
}
|
||||
|
||||
std::vector<paddle::DataType> MoeExpertReduceInferDtype(
|
||||
const paddle::DataType& ffn_out_dtype,
|
||||
const paddle::DataType& top_k_weight_dtype,
|
||||
const paddle::DataType& permute_indices_per_token_dtype,
|
||||
const paddle::DataType& top_k_indices_dtype,
|
||||
const paddle::optional<paddle::DataType>& ffn2_bias_dtype) {
|
||||
return {ffn_out_dtype};
|
||||
}
|
||||
|
||||
|
||||
PD_BUILD_OP(moe_expert_reduce)
|
||||
.Inputs({"ffn_out",
|
||||
"top_k_weight",
|
||||
"permute_indices_per_token",
|
||||
"top_k_indices",
|
||||
paddle::Optional("ffn2_bias")})
|
||||
.Outputs({"output"})
|
||||
.Attrs({"norm_topk_prob:bool", "routed_scaling_factor:float"})
|
||||
.SetKernelFn(PD_KERNEL(MoeExpertReduce))
|
||||
.SetInferShapeFn(PD_INFER_SHAPE(MoeExpertReduceInferShape))
|
||||
.SetInferDtypeFn(PD_INFER_DTYPE(MoeExpertReduceInferDtype));
|
@@ -37,6 +37,52 @@ def load_module_from_path(module_name, path):
|
||||
return module
|
||||
|
||||
|
||||
def update_git_repo():
|
||||
try:
|
||||
print("update third party repo...", flush=True)
|
||||
original_dir = os.getcwd()
|
||||
submodule_dir = os.path.dirname(os.path.abspath(__file__))
|
||||
third_party_path = os.path.join(submodule_dir, "third_party")
|
||||
root_path = Path(third_party_path)
|
||||
|
||||
# check if third_party is empty
|
||||
update_third_party = False
|
||||
for dirpath in root_path.iterdir():
|
||||
if dirpath.is_dir():
|
||||
has_content = any(dirpath.iterdir())
|
||||
if not has_content:
|
||||
update_third_party = True
|
||||
|
||||
if update_third_party:
|
||||
os.chdir(submodule_dir)
|
||||
subprocess.run(
|
||||
"git submodule sync --recursive && git submodule update --init --recursive",
|
||||
shell=True,
|
||||
check=True,
|
||||
text=True,
|
||||
)
|
||||
else:
|
||||
print(
|
||||
"\033[33m[===WARNING===]third_party directory already exists, skip clone and update.\033[0m",
|
||||
flush=True,
|
||||
)
|
||||
|
||||
# apply deep gemm patch
|
||||
deep_gemm_dir = "third_party/DeepGEMM"
|
||||
dst_path = os.path.join(submodule_dir, deep_gemm_dir)
|
||||
patch = "0001-DeepGEMM-95e81b3.patch"
|
||||
patch_source = os.path.join(submodule_dir, patch)
|
||||
patch_destination = os.path.join(dst_path, patch)
|
||||
if not os.path.exists(patch_destination):
|
||||
shutil.copy(patch_source, patch_destination)
|
||||
apply_cmd = ["git", "apply", patch]
|
||||
os.chdir(dst_path)
|
||||
subprocess.run(apply_cmd, check=True)
|
||||
os.chdir(original_dir)
|
||||
except subprocess.CalledProcessError:
|
||||
raise Exception("Git submodule update and apply patch failed. Maybe network connection is poor.")
|
||||
|
||||
|
||||
ROOT_DIR = Path(__file__).parent.parent
|
||||
|
||||
# cannot import envs directly because it depends on fastdeploy,
|
||||
@@ -46,6 +92,8 @@ envs = load_module_from_path("envs", os.path.join(ROOT_DIR, "fastdeploy", "envs.
|
||||
archs = json.loads(envs.FD_BUILDING_ARCS)
|
||||
use_bf16 = envs.FD_CPU_USE_BF16 == "True"
|
||||
|
||||
update_git_repo()
|
||||
|
||||
|
||||
def download_and_extract(url, destination_directory):
|
||||
"""
|
||||
@@ -78,52 +126,6 @@ def download_and_extract(url, destination_directory):
|
||||
print(f"Error extracting file: {e}")
|
||||
|
||||
|
||||
def clone_git_repo(version, repo_url, destination_path):
|
||||
"""
|
||||
Clone git repo to destination path.
|
||||
"""
|
||||
try:
|
||||
subprocess.run(
|
||||
[
|
||||
"git",
|
||||
"clone",
|
||||
"-b",
|
||||
version,
|
||||
"--single-branch",
|
||||
repo_url,
|
||||
destination_path,
|
||||
],
|
||||
check=True,
|
||||
)
|
||||
return True
|
||||
except subprocess.CalledProcessError:
|
||||
return False
|
||||
|
||||
|
||||
def process_git_repo(cur_path, dst_path, commit_id=None, patch=None):
|
||||
"""
|
||||
reset git repo to destination commit and apply patch.
|
||||
"""
|
||||
if commit_id is not None:
|
||||
reset_cmd = ["git", "reset", "--hard", commit_id]
|
||||
if patch is not None:
|
||||
patch_source = os.path.join(cur_path, patch)
|
||||
patch_destination = os.path.join(dst_path, patch)
|
||||
shutil.copy(patch_source, patch_destination)
|
||||
apply_cmd = ["git", "apply", patch]
|
||||
|
||||
try:
|
||||
os.chdir(dst_path)
|
||||
if commit_id is not None:
|
||||
subprocess.run(reset_cmd, check=True)
|
||||
if patch is not None:
|
||||
subprocess.run(apply_cmd, check=True)
|
||||
os.chdir(cur_path)
|
||||
return True
|
||||
except subprocess.CalledProcessError:
|
||||
return False
|
||||
|
||||
|
||||
def get_sm_version(archs):
|
||||
"""
|
||||
Get sm version of paddle.
|
||||
@@ -191,13 +193,6 @@ def find_end_files(directory, end_str):
|
||||
if paddle.is_compiled_with_rocm():
|
||||
# NOTE(@duanyanhui): paddle.is_compiled_with_cuda() returns True when paddle compiled with rocm.
|
||||
# so we need to check if paddle compiled with rocm at first.
|
||||
json_dir = "third_party/nlohmann_json"
|
||||
if not os.path.exists(json_dir) or not os.listdir(json_dir):
|
||||
if not os.path.exists(json_dir):
|
||||
os.makedirs(json_dir)
|
||||
clone_git_repo("v3.11.3", "https://bgithub.xyz/nlohmann/json.git", json_dir)
|
||||
if not os.listdir(json_dir):
|
||||
raise ValueError("Git clone nlohmann_json failed!")
|
||||
sources = [
|
||||
"gpu_ops/save_with_output_msg.cc",
|
||||
"gpu_ops/get_output.cc",
|
||||
@@ -213,6 +208,7 @@ if paddle.is_compiled_with_rocm():
|
||||
"gpu_ops/rebuild_padding.cu",
|
||||
"gpu_ops/step.cu",
|
||||
"gpu_ops/set_data_ipc.cu",
|
||||
"gpu_ops/unset_data_ipc.cu",
|
||||
"gpu_ops/moe/tritonmoe_preprocess.cu",
|
||||
"gpu_ops/step_system_cache.cu",
|
||||
"gpu_ops/get_output_ep.cc",
|
||||
@@ -283,6 +279,7 @@ elif paddle.is_compiled_with_cuda():
|
||||
"gpu_ops/beam_search_softmax.cu",
|
||||
"gpu_ops/rebuild_padding.cu",
|
||||
"gpu_ops/set_data_ipc.cu",
|
||||
"gpu_ops/unset_data_ipc.cu",
|
||||
"gpu_ops/read_data_ipc.cu",
|
||||
"gpu_ops/enforce_generation.cu",
|
||||
"gpu_ops/dequant_int8.cu",
|
||||
@@ -316,28 +313,6 @@ elif paddle.is_compiled_with_cuda():
|
||||
"gpu_ops/ipc_sent_key_value_cache_by_remote_ptr.cu",
|
||||
]
|
||||
|
||||
cutlass_dir = "third_party/cutlass"
|
||||
if not os.path.exists(cutlass_dir) or not os.listdir(cutlass_dir):
|
||||
if not os.path.exists(cutlass_dir):
|
||||
os.makedirs(cutlass_dir)
|
||||
clone_git_repo("v3.8.0", "https://github.com/NVIDIA/cutlass.git", cutlass_dir)
|
||||
if not os.listdir(cutlass_dir):
|
||||
raise ValueError("Git clone cutlass failed!")
|
||||
|
||||
# deep gemm
|
||||
deep_gemm_dir = "third_party/DeepGEMM"
|
||||
if not os.path.exists(deep_gemm_dir) or not os.listdir(deep_gemm_dir):
|
||||
if not os.path.exists(deep_gemm_dir):
|
||||
os.makedirs(deep_gemm_dir)
|
||||
clone_git_repo("main", "https://github.com/deepseek-ai/DeepGEMM.git", deep_gemm_dir)
|
||||
if not os.listdir(deep_gemm_dir):
|
||||
raise ValueError("Git clone DeepGEMM failed!")
|
||||
cur_path = os.path.dirname(os.path.abspath(__file__))
|
||||
dst_path = os.path.join(cur_path, deep_gemm_dir)
|
||||
commit_id = "95e81b3dd6704e279e5f4757c5b94776ac988a8d"
|
||||
patch = "0001-DeepGEMM-95e81b3.patch"
|
||||
process_git_repo(cur_path, dst_path, commit_id, patch)
|
||||
|
||||
dg_third_party_include_dirs = (
|
||||
"third_party/cutlass/include/cute",
|
||||
"third_party/cutlass/include/cutlass",
|
||||
@@ -365,14 +340,6 @@ elif paddle.is_compiled_with_cuda():
|
||||
except Exception as e:
|
||||
raise RuntimeError(f"Failed to copy from {src_dir} to {dst_dir}: {e}")
|
||||
|
||||
json_dir = "third_party/nlohmann_json"
|
||||
if not os.path.exists(json_dir) or not os.listdir(json_dir):
|
||||
if not os.path.exists(json_dir):
|
||||
os.makedirs(json_dir)
|
||||
clone_git_repo("v3.11.3", "https://github.com/nlohmann/json.git", json_dir)
|
||||
if not os.listdir(json_dir):
|
||||
raise ValueError("Git clone nlohmann_json failed!")
|
||||
|
||||
cc_compile_args = []
|
||||
nvcc_compile_args = get_gencode_flags(archs)
|
||||
nvcc_compile_args += ["-DPADDLE_DEV"]
|
||||
@@ -542,7 +509,7 @@ elif paddle.is_compiled_with_cuda():
|
||||
include_package_data=True,
|
||||
)
|
||||
elif paddle.is_compiled_with_xpu():
|
||||
assert False, "In XPU, we should use setup_ops.py in xpu_ops/src, not this."
|
||||
assert False, "For XPU, please use setup_ops.py in the xpu_ops directory to compile custom ops."
|
||||
elif paddle.is_compiled_with_custom_device("iluvatar_gpu"):
|
||||
setup(
|
||||
name="fastdeploy_ops",
|
||||
@@ -571,6 +538,8 @@ elif paddle.is_compiled_with_custom_device("iluvatar_gpu"):
|
||||
"iluvatar_ops/moe_dispatch.cu",
|
||||
"iluvatar_ops/moe_reduce.cu",
|
||||
"iluvatar_ops/paged_attn.cu",
|
||||
"iluvatar_ops/prefill_fused_attn.cu",
|
||||
"iluvatar_ops/mixed_fused_attn.cu",
|
||||
"iluvatar_ops/w8a16_group_gemm.cu",
|
||||
"iluvatar_ops/runtime/iluvatar_context.cc",
|
||||
],
|
||||
@@ -593,13 +562,6 @@ elif paddle.is_compiled_with_custom_device("gcu"):
|
||||
)
|
||||
elif paddle.device.is_compiled_with_custom_device("metax_gpu"):
|
||||
maca_path = os.getenv("MACA_PATH", "/opt/maca")
|
||||
json_dir = "third_party/nlohmann_json"
|
||||
if not os.path.exists(json_dir) or not os.listdir(json_dir):
|
||||
if not os.path.exists(json_dir):
|
||||
os.makedirs(json_dir)
|
||||
clone_git_repo("v3.11.3", "https://gitee.com/learnlov/mirrors_nlohmann_json.git", json_dir)
|
||||
if not os.listdir(json_dir):
|
||||
raise ValueError("Git clone nlohmann_json failed!")
|
||||
sources = [
|
||||
"gpu_ops/update_inputs_v1.cu",
|
||||
"gpu_ops/save_with_output_msg.cc",
|
||||
@@ -635,6 +597,10 @@ elif paddle.device.is_compiled_with_custom_device("metax_gpu"):
|
||||
"gpu_ops/moe/tritonmoe_preprocess.cu",
|
||||
"gpu_ops/moe/moe_topk_select.cu",
|
||||
"gpu_ops/recover_decode_task.cu",
|
||||
"metax_ops/moe_dispatch.cu",
|
||||
"metax_ops/moe_ffn.cu",
|
||||
"metax_ops/moe_reduce.cu",
|
||||
"metax_ops/fused_moe.cu",
|
||||
]
|
||||
|
||||
sources += find_end_files("gpu_ops/speculate_decoding", ".cu")
|
||||
@@ -655,7 +621,7 @@ elif paddle.device.is_compiled_with_custom_device("metax_gpu"):
|
||||
],
|
||||
},
|
||||
library_dirs=[os.path.join(maca_path, "lib")],
|
||||
extra_link_args=["-lruntime_cu"],
|
||||
extra_link_args=["-lruntime_cu", "-lmctlassEx"],
|
||||
include_dirs=[
|
||||
os.path.join(maca_path, "include"),
|
||||
os.path.join(maca_path, "include/mcr"),
|
||||
@@ -663,6 +629,8 @@ elif paddle.device.is_compiled_with_custom_device("metax_gpu"):
|
||||
],
|
||||
),
|
||||
)
|
||||
elif paddle.is_compiled_with_custom_device("intel_hpu"):
|
||||
pass
|
||||
else:
|
||||
use_bf16 = envs.FD_CPU_USE_BF16 == "True"
|
||||
|
||||
|
1
custom_ops/third_party/DeepGEMM
vendored
Submodule
1
custom_ops/third_party/DeepGEMM
vendored
Submodule
Submodule custom_ops/third_party/DeepGEMM added at 95e81b3dd6
1
custom_ops/third_party/cutlass
vendored
Submodule
1
custom_ops/third_party/cutlass
vendored
Submodule
Submodule custom_ops/third_party/cutlass added at afa1772203
1
custom_ops/third_party/nlohmann_json
vendored
Submodule
1
custom_ops/third_party/nlohmann_json
vendored
Submodule
Submodule custom_ops/third_party/nlohmann_json added at 9cca280a4d
@@ -27,7 +27,7 @@ import paddle
|
||||
from paddle.utils.cpp_extension import CppExtension, setup
|
||||
|
||||
current_file = Path(__file__).resolve()
|
||||
base_dir = current_file.parent
|
||||
base_dir = os.path.join(current_file.parent, "src")
|
||||
|
||||
|
||||
def build_plugin(CLANG_PATH, XRE_INC_DIR, XRE_LIB_DIR, XDNN_INC_DIR, XDNN_LIB_DIR):
|
||||
@@ -136,33 +136,8 @@ def xpu_setup_ops():
|
||||
# build plugin
|
||||
build_plugin(CLANG_PATH, XRE_INC_PATH, XRE_LIB_DIR, XDNN_INC_PATH, XDNN_LIB_DIR)
|
||||
|
||||
ops = [
|
||||
# custom ops
|
||||
"./ops/save_with_output_msg.cc",
|
||||
"./ops/stop_generation_multi_ends.cc",
|
||||
"./ops/set_value_by_flags_and_idx.cc",
|
||||
"./ops/get_token_penalty_multi_scores.cc",
|
||||
"./ops/get_padding_offset.cc",
|
||||
"./ops/update_inputs.cc",
|
||||
"./ops/recover_decode_task.cc",
|
||||
"./ops/update_inputs_v1.cc",
|
||||
"./ops/get_output.cc",
|
||||
"./ops/step.cc",
|
||||
"./ops/get_infer_param.cc",
|
||||
"./ops/adjust_batch.cc",
|
||||
"./ops/gather_next_token.cc",
|
||||
"./ops/block_attn.cc",
|
||||
"./ops/moe_layer.cc",
|
||||
"./ops/weight_quantize_xpu.cc",
|
||||
# device manage ops
|
||||
"./ops/device/get_context_gm_max_mem_demand.cc",
|
||||
"./ops/device/get_free_global_memory.cc",
|
||||
"./ops/device/get_total_global_memory.cc",
|
||||
"./ops/device/get_used_global_memory.cc",
|
||||
]
|
||||
ops = [os.path.join(base_dir, op) for op in ops]
|
||||
|
||||
for root, dirs, files in os.walk(base_dir / "ops/mtp_ops"):
|
||||
ops = []
|
||||
for root, dirs, files in os.walk(os.path.join(base_dir, "ops")):
|
||||
for file in files:
|
||||
if file.endswith(".cc"):
|
||||
ops.append(os.path.join(root, file))
|
@@ -41,7 +41,9 @@ std::vector<paddle::Tensor> BlockAttnKernel(
|
||||
const paddle::Tensor &encoder_seq_lod_cpu,
|
||||
const paddle::Tensor &encoder_batch_map_cpu,
|
||||
const paddle::Tensor &decoder_context_len_cpu,
|
||||
const paddle::Tensor &decoder_batch_map_cpu) {
|
||||
const paddle::Tensor &decoder_batch_map_cpu,
|
||||
const std::string &pos_emb_type="NORMAL",
|
||||
bool rope_3d=false) {
|
||||
phi::XPUPlace place(phi::backends::xpu::GetXPUCurrentDeviceId());
|
||||
auto dev_ctx =
|
||||
paddle::experimental::DeviceContextPool::Instance().Get(place);
|
||||
@@ -72,6 +74,14 @@ std::vector<paddle::Tensor> BlockAttnKernel(
|
||||
int enc_batch = enc_batch_tensor.data<int32_t>()[0];
|
||||
int dec_batch = dec_batch_tensor.data<int32_t>()[0];
|
||||
int total_enc_len = total_enc_len_tensor.data<int32_t>()[0];
|
||||
int rope_max_seqlen = 0;
|
||||
int rope_3d_num_seqs = 1;
|
||||
if (rope_3d) {
|
||||
rope_max_seqlen = rotary_embs.dims()[3];
|
||||
rope_3d_num_seqs = rotary_embs.dims()[0];
|
||||
} else {
|
||||
rope_max_seqlen = rotary_embs.dims()[2];
|
||||
}
|
||||
|
||||
auto block_attn_out =
|
||||
paddle::full({token_num, hidden_dim}, -1, qkv.type(), qkv.place());
|
||||
@@ -151,10 +161,10 @@ std::vector<paddle::Tensor> BlockAttnKernel(
|
||||
prefix_lens_vp, // start_tokens
|
||||
param.batch_size, // batch_size
|
||||
1, // emb_batch_size
|
||||
rotary_embs.dims()[2], // max_seqlen
|
||||
rope_max_seqlen, // max_seqlen
|
||||
param.head_num, param.kv_head_num, param.head_dim,
|
||||
param.max_batch_size, block_size, max_block_per_seq, "BLHD",
|
||||
"HLD", "NORMAL",
|
||||
"HLD", pos_emb_type,
|
||||
!p_kcache_perhead_scale.defined()
|
||||
? nullptr
|
||||
: p_kcache_perhead_scale.data<float>() +
|
||||
@@ -246,10 +256,10 @@ std::vector<paddle::Tensor> BlockAttnKernel(
|
||||
vsl.slot_mapping_vp, // real_batch
|
||||
param.batch_size, // batch_size
|
||||
1, // emb_batch_size
|
||||
rotary_embs.dims()[2], // max_seqlen TODO!!double check
|
||||
rope_max_seqlen, // max_seqlen
|
||||
param.head_num, param.kv_head_num, param.head_dim,
|
||||
param.max_batch_size, block_size, max_block_per_seq, "BLHD", "HLD",
|
||||
"NORMAL",
|
||||
pos_emb_type,
|
||||
!p_kcache_perhead_scale.defined()
|
||||
? nullptr
|
||||
: p_kcache_perhead_scale.data<float>() +
|
||||
@@ -260,7 +270,9 @@ std::vector<paddle::Tensor> BlockAttnKernel(
|
||||
param.kv_head_num, // v_cache_scale_inv
|
||||
nullptr, // k_cache_zp
|
||||
nullptr, // v_cache_zp
|
||||
false); // b_c8_pc
|
||||
false, // b_c8_pc
|
||||
rope_3d, // rope_3d
|
||||
rope_3d_num_seqs);
|
||||
XFTBLOCK_CHECK_EQ(ret, api::SUCCESS);
|
||||
|
||||
// attn decode
|
||||
@@ -314,6 +326,7 @@ PD_BUILD_OP(block_attn)
|
||||
"decoder_context_len_cpu",
|
||||
"decoder_batch_map_cpu",
|
||||
})
|
||||
.Attrs({"pos_emb_type:std::string", "rope_3d:bool"})
|
||||
.Outputs({"block_attn_out"})
|
||||
.SetKernelFn(PD_KERNEL(BlockAttnKernel))
|
||||
.SetInferShapeFn(PD_INFER_SHAPE(BlockAttnInferShape))
|
||||
|
225
custom_ops/xpu_ops/src/ops/fused_rms_norm.cc
Normal file
225
custom_ops/xpu_ops/src/ops/fused_rms_norm.cc
Normal file
@@ -0,0 +1,225 @@
|
||||
// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#include <infer_ops.h>
|
||||
#include <functional>
|
||||
#include "paddle/extension.h"
|
||||
#include "paddle/phi/backends/xpu/enforce_xpu.h"
|
||||
#include "utility/debug.h"
|
||||
#include "utility/env.h"
|
||||
|
||||
#ifndef PD_BUILD_STATIC_OP
|
||||
#define PD_BUILD_STATIC_OP(name) PD_BUILD_OP(static_op_##name)
|
||||
#endif
|
||||
|
||||
XPU_DECLARE_BOOL(ENABLE_XVLLM_SDNN_INFER, false);
|
||||
namespace api = baidu::xpu::api;
|
||||
|
||||
template <typename T>
|
||||
std::vector<paddle::Tensor> RmsNormKernel(
|
||||
const paddle::Tensor& x,
|
||||
const paddle::optional<paddle::Tensor>& bias,
|
||||
const paddle::optional<paddle::Tensor>& residual,
|
||||
const paddle::Tensor& norm_weight,
|
||||
const paddle::optional<paddle::Tensor>& norm_bias,
|
||||
const float epsilon,
|
||||
const int begin_norm_axis,
|
||||
const float quant_scale,
|
||||
const int quant_round_type,
|
||||
const float quant_max_bound,
|
||||
const float quant_min_bound) {
|
||||
using XPU_T = typename XPUTypeTrait<T>::Type;
|
||||
phi::XPUPlace place(phi::backends::xpu::GetXPUCurrentDeviceId());
|
||||
auto dev_ctx = paddle::experimental::DeviceContextPool::Instance().Get(place);
|
||||
auto xpu_ctx = static_cast<const phi::XPUContext*>(dev_ctx);
|
||||
|
||||
int ret = -1;
|
||||
auto x_shape = x.shape();
|
||||
PD_CHECK(quant_scale <= 0, "Quantization is not supported");
|
||||
PD_CHECK(begin_norm_axis > 0 && begin_norm_axis <= x_shape.size(),
|
||||
"begin_norm_axis check fail");
|
||||
PD_CHECK(norm_bias.get_ptr() == nullptr,
|
||||
"rms norm kernel don't support norm_bias");
|
||||
|
||||
int64_t m = std::accumulate(x_shape.begin(),
|
||||
x_shape.begin() + begin_norm_axis,
|
||||
static_cast<int64_t>(1),
|
||||
std::multiplies<int64_t>());
|
||||
int64_t n = std::accumulate(x_shape.begin() + begin_norm_axis,
|
||||
x_shape.end(),
|
||||
static_cast<int64_t>(1),
|
||||
std::multiplies<int64_t>());
|
||||
|
||||
PD_CHECK(n == norm_weight.shape()[0],
|
||||
"The product from begin_norm_axis to the last axis of x must be "
|
||||
"equal to the norm_weight's shape[0]");
|
||||
if (bias.get_ptr()) {
|
||||
PD_CHECK(n == bias.get_ptr()->shape()[0],
|
||||
"The product from begin_norm_axis to the last axis of x must be "
|
||||
"equal to the bias's shape[0]");
|
||||
}
|
||||
|
||||
paddle::Tensor out = paddle::empty(x_shape, x.dtype(), x.place());
|
||||
paddle::Tensor residual_out = paddle::empty(x_shape, x.dtype(), x.place());
|
||||
const XPU_T* x_data = reinterpret_cast<const XPU_T*>(x.data<T>());
|
||||
const XPU_T* norm_weight_data =
|
||||
reinterpret_cast<const XPU_T*>(norm_weight.data<T>());
|
||||
const XPU_T* bias_data =
|
||||
bias.get_ptr() ? reinterpret_cast<const XPU_T*>(bias.get_ptr()->data<T>())
|
||||
: nullptr;
|
||||
const XPU_T* residual_data =
|
||||
residual.get_ptr()
|
||||
? reinterpret_cast<const XPU_T*>(residual.get_ptr()->data<T>())
|
||||
: nullptr;
|
||||
XPU_T* out_data = reinterpret_cast<XPU_T*>(const_cast<T*>(out.data<T>()));
|
||||
XPU_T* residual_out_data = nullptr;
|
||||
if (residual_data) {
|
||||
residual_out_data =
|
||||
reinterpret_cast<XPU_T*>(const_cast<T*>(residual_out.data<T>()));
|
||||
}
|
||||
|
||||
XPU_T* add_out_data = const_cast<XPU_T*>(x_data);
|
||||
if (bias_data) {
|
||||
ret = api::broadcast_add(
|
||||
xpu_ctx->x_context(), x_data, bias_data, out_data, {m, n}, {n});
|
||||
PD_CHECK(ret == 0, "broadcast_add");
|
||||
add_out_data = out_data;
|
||||
}
|
||||
|
||||
bool use_sdnn = FLAGS_ENABLE_XVLLM_SDNN_INFER;
|
||||
if (residual_data) {
|
||||
ret = infer_ops::add_rms_layer_norm<XPU_T, XPU_T>(xpu_ctx->x_context(),
|
||||
add_out_data,
|
||||
residual_data,
|
||||
out_data,
|
||||
m,
|
||||
n,
|
||||
epsilon,
|
||||
norm_weight_data,
|
||||
nullptr,
|
||||
nullptr,
|
||||
residual_out_data,
|
||||
nullptr,
|
||||
use_sdnn);
|
||||
PD_CHECK(ret == 0, "add_rms_layer_norm");
|
||||
} else {
|
||||
ret = api::rms_layer_norm<XPU_T, XPU_T>(xpu_ctx->x_context(),
|
||||
add_out_data,
|
||||
out_data,
|
||||
m,
|
||||
n,
|
||||
epsilon,
|
||||
norm_weight_data,
|
||||
nullptr,
|
||||
nullptr,
|
||||
false);
|
||||
PD_CHECK(ret == 0, "rms_layer_norm");
|
||||
}
|
||||
|
||||
return {out, residual_out};
|
||||
}
|
||||
|
||||
std::vector<paddle::Tensor> RmsNorm(
|
||||
const paddle::Tensor& x,
|
||||
const paddle::optional<paddle::Tensor>& bias,
|
||||
const paddle::optional<paddle::Tensor>& residual,
|
||||
const paddle::Tensor& norm_weight,
|
||||
const paddle::optional<paddle::Tensor>& norm_bias,
|
||||
const float epsilon,
|
||||
const int begin_norm_axis,
|
||||
const float quant_scale,
|
||||
const int quant_round_type,
|
||||
const float quant_max_bound,
|
||||
const float quant_min_bound) {
|
||||
const auto x_type = x.dtype();
|
||||
|
||||
#define APPLY_RMS_NORM_KERNEL(TX) \
|
||||
return RmsNormKernel<TX>(x, \
|
||||
bias, \
|
||||
residual, \
|
||||
norm_weight, \
|
||||
norm_bias, \
|
||||
epsilon, \
|
||||
begin_norm_axis, \
|
||||
quant_scale, \
|
||||
quant_round_type, \
|
||||
quant_max_bound, \
|
||||
quant_min_bound);
|
||||
|
||||
if (x_type == paddle::DataType::BFLOAT16) {
|
||||
APPLY_RMS_NORM_KERNEL(paddle::bfloat16);
|
||||
} else if (x_type == paddle::DataType::FLOAT16) {
|
||||
APPLY_RMS_NORM_KERNEL(paddle::float16);
|
||||
} else if (x_type == paddle::DataType::FLOAT32) {
|
||||
APPLY_RMS_NORM_KERNEL(float);
|
||||
} else {
|
||||
PD_THROW("RmsNorm not support x_type=", static_cast<int>(x_type));
|
||||
return {};
|
||||
}
|
||||
#undef APPLY_RMS_NORM_KERNEL
|
||||
}
|
||||
|
||||
std::vector<std::vector<int64_t>> RmsNormInferShape(
|
||||
const std::vector<int64_t>& x_shape,
|
||||
const paddle::optional<std::vector<int64_t>>& bias_shape,
|
||||
const paddle::optional<std::vector<int64_t>>& residual_shape,
|
||||
const std::vector<int64_t>& norm_weight_shape,
|
||||
const paddle::optional<std::vector<int64_t>>& norm_bias_shape,
|
||||
const float epsilon,
|
||||
const int begin_norm_axis,
|
||||
const float quant_scale,
|
||||
const int quant_round_type,
|
||||
const float quant_max_bound,
|
||||
const float quant_min_bound) {
|
||||
PD_CHECK(begin_norm_axis > 0 && begin_norm_axis <= x_shape.size(),
|
||||
"begin_norm_axis check fail");
|
||||
int64_t m = std::accumulate(x_shape.begin(),
|
||||
x_shape.begin() + begin_norm_axis,
|
||||
static_cast<int64_t>(1),
|
||||
std::multiplies<int64_t>());
|
||||
return {x_shape, x_shape, {m}};
|
||||
}
|
||||
|
||||
std::vector<paddle::DataType> RmsNormInferDtype(
|
||||
const paddle::DataType& x_dtype,
|
||||
const paddle::optional<paddle::DataType>& bias_dtype,
|
||||
const paddle::optional<paddle::DataType>& residual_dtype,
|
||||
const paddle::DataType& norm_weight_dtype,
|
||||
const paddle::optional<paddle::DataType>& norm_bias_dtype,
|
||||
const float epsilon,
|
||||
const int begin_norm_axis,
|
||||
const float quant_scale,
|
||||
const int quant_round_type,
|
||||
const float quant_max_bound,
|
||||
const float quant_min_bound) {
|
||||
// out, residual_out
|
||||
return {x_dtype, x_dtype};
|
||||
}
|
||||
|
||||
PD_BUILD_STATIC_OP(fused_rms_norm_xpu)
|
||||
.Inputs({"x",
|
||||
paddle::Optional("bias"),
|
||||
paddle::Optional("residual"),
|
||||
"norm_weight",
|
||||
paddle::Optional("norm_bias")})
|
||||
.Outputs({"out", "residul_out"})
|
||||
.Attrs({"epsilon:float",
|
||||
"begin_norm_axis:int",
|
||||
"quant_scale:float",
|
||||
"quant_round_type:int",
|
||||
"quant_max_bound:float",
|
||||
"quant_min_bound:float"})
|
||||
.SetKernelFn(PD_KERNEL(RmsNorm))
|
||||
.SetInferShapeFn(PD_INFER_SHAPE(RmsNormInferShape))
|
||||
.SetInferDtypeFn(PD_INFER_DTYPE(RmsNormInferDtype));
|
60
custom_ops/xpu_ops/src/ops/get_img_boundaries.cc
Normal file
60
custom_ops/xpu_ops/src/ops/get_img_boundaries.cc
Normal file
@@ -0,0 +1,60 @@
|
||||
// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#include "paddle/extension.h"
|
||||
|
||||
std::vector<paddle::Tensor> GetImgBoundaries(const paddle::Tensor& task_input_ids,
|
||||
const paddle::Tensor& grid_thw,
|
||||
const int64_t image_patch_id) {
|
||||
// All tensor in cpu
|
||||
auto input_ids_ptr = task_input_ids.data<int64_t>();
|
||||
int64_t seq_lens_origin = task_input_ids.numel();
|
||||
auto grid_thw_ptr = grid_thw.data<int64_t>();
|
||||
|
||||
int token_times = 4;
|
||||
int token_idx = 0;
|
||||
int image_idx = 0;
|
||||
std::vector<int> img_boundaries, img_nums;
|
||||
img_boundaries.emplace_back(0);
|
||||
img_nums.emplace_back(0);
|
||||
while (token_idx < seq_lens_origin) {
|
||||
if (input_ids_ptr[token_idx] != image_patch_id) {
|
||||
do {
|
||||
token_idx++;
|
||||
} while (token_idx < seq_lens_origin && input_ids_ptr[token_idx] != image_patch_id);
|
||||
} else {
|
||||
int cur_image_token_len = (grid_thw_ptr[image_idx * 3 + 1] * grid_thw_ptr[image_idx * 3 + 2]) / token_times;
|
||||
image_idx++;
|
||||
token_idx += cur_image_token_len;
|
||||
}
|
||||
img_boundaries.emplace_back(token_idx);
|
||||
img_nums.emplace_back(image_idx);
|
||||
}
|
||||
|
||||
int64_t num_img_boundaries = static_cast<int64_t>(img_boundaries.size());
|
||||
auto out = paddle::full({2, num_img_boundaries}, 0, paddle::DataType::INT64, paddle::CPUPlace());
|
||||
|
||||
for (int i = 0; i < num_img_boundaries; i++) {
|
||||
out.data<int64_t>()[i] = img_boundaries[i];
|
||||
out.data<int64_t>()[num_img_boundaries + i] = img_nums[i];
|
||||
}
|
||||
|
||||
return {out};
|
||||
}
|
||||
|
||||
PD_BUILD_OP(get_img_boundaries)
|
||||
.Inputs({"task_input_ids", "grid_thw"})
|
||||
.Attrs({"image_patch_id: int64_t"})
|
||||
.Outputs({"img_boundaries"})
|
||||
.SetKernelFn(PD_KERNEL(GetImgBoundaries));
|
@@ -18,13 +18,35 @@
|
||||
#include <sys/ipc.h>
|
||||
#include <sys/msg.h>
|
||||
#include <sys/types.h>
|
||||
#include "msg_utils.h"
|
||||
|
||||
#define MAX_BSZ 256
|
||||
// #define GET_OUTPUT_DEBUG
|
||||
struct msgdata {
|
||||
long mtype;
|
||||
int mtext[MAX_BSZ + 2]; // stop_flag, bsz, tokens
|
||||
};
|
||||
void GetOutputKVSignal(const paddle::Tensor& x,
|
||||
int64_t rank_id,
|
||||
bool wait_flag) {
|
||||
int msg_queue_id = 1024 + rank_id;
|
||||
static struct msgdatakv msg_rcv;
|
||||
static key_t key = ftok("/opt/", msg_queue_id);
|
||||
static int msgid = msgget(key, IPC_CREAT | 0666);
|
||||
|
||||
int* out_data = const_cast<int*>(x.data<int>());
|
||||
int ret = -1;
|
||||
if (!wait_flag) {
|
||||
ret = msgrcv(msgid, &msg_rcv, (MAX_BSZ * 3 + 2) * 4, 0, IPC_NOWAIT);
|
||||
} else {
|
||||
ret = msgrcv(msgid, &msg_rcv, (MAX_BSZ * 3 + 2) * 4, 0, 0);
|
||||
}
|
||||
if (ret == -1) {
|
||||
out_data[0] = -1;
|
||||
out_data[1] = -1;
|
||||
return;
|
||||
}
|
||||
int encoder_count = msg_rcv.mtext[0];
|
||||
|
||||
for (int i = 0; i < encoder_count * 3 + 2; i++) {
|
||||
out_data[i] = msg_rcv.mtext[i];
|
||||
}
|
||||
return;
|
||||
}
|
||||
|
||||
void GetOutput(const paddle::Tensor &x, int64_t rank_id, bool wait_flag,
|
||||
int msg_queue_id) {
|
||||
|
119
custom_ops/xpu_ops/src/ops/moe_ep_combine.cc
Normal file
119
custom_ops/xpu_ops/src/ops/moe_ep_combine.cc
Normal file
@@ -0,0 +1,119 @@
|
||||
// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#include <infer_ops.h>
|
||||
#include <xft_api.h>
|
||||
#include "paddle/extension.h"
|
||||
#include "paddle/phi/backends/xpu/enforce_xpu.h"
|
||||
#include "utility/debug.h"
|
||||
|
||||
#ifndef PD_BUILD_STATIC_OP
|
||||
#define PD_BUILD_STATIC_OP(name) PD_BUILD_OP(static_op_##name)
|
||||
#endif
|
||||
|
||||
template <typename T>
|
||||
std::vector<paddle::Tensor> MoeEPCombineKernel(
|
||||
const paddle::Tensor&
|
||||
ffn_out, // expand_token_num * hidden_dim dtype is fp16/bf16
|
||||
const paddle::Tensor& moe_index, // token_num * topk dtype is int
|
||||
const paddle::Tensor&
|
||||
weights, // token_num * topk dtype is same as ffn_out
|
||||
int64_t recv_token_num,
|
||||
int64_t expand_token_num,
|
||||
int64_t hidden_dim,
|
||||
int64_t topk) {
|
||||
using XPU_T = typename XPUTypeTrait<T>::Type;
|
||||
phi::XPUPlace place(phi::backends::xpu::GetXPUCurrentDeviceId());
|
||||
auto dev_ctx = paddle::experimental::DeviceContextPool::Instance().Get(place);
|
||||
auto xpu_ctx = static_cast<const phi::XPUContext*>(dev_ctx);
|
||||
|
||||
auto combined_out = paddle::empty(
|
||||
{recv_token_num, hidden_dim}, ffn_out.dtype(), ffn_out.place());
|
||||
|
||||
const float* dequant_score = nullptr;
|
||||
int ret = infer_ops::moe_ep_ffn_post_fusion(
|
||||
xpu_ctx->x_context(),
|
||||
reinterpret_cast<const XPU_T*>(ffn_out.data<T>()),
|
||||
moe_index.data<int32_t>(),
|
||||
reinterpret_cast<const XPU_T*>(weights.data<T>()),
|
||||
dequant_score,
|
||||
reinterpret_cast<XPU_T*>(combined_out.mutable_data<T>()),
|
||||
recv_token_num,
|
||||
hidden_dim,
|
||||
topk,
|
||||
expand_token_num);
|
||||
PD_CHECK(ret == 0);
|
||||
|
||||
return {combined_out};
|
||||
}
|
||||
|
||||
std::vector<paddle::Tensor> MoeEPCombine(const paddle::Tensor& ffn_out,
|
||||
const paddle::Tensor& moe_index,
|
||||
const paddle::Tensor& weights,
|
||||
const int recv_token_num,
|
||||
const int expand_token_num,
|
||||
const int hidden_dim,
|
||||
const int topk) {
|
||||
#define APPLY_KERNEL(TX) \
|
||||
return MoeEPCombineKernel<TX>(ffn_out, \
|
||||
moe_index, \
|
||||
weights, \
|
||||
recv_token_num, \
|
||||
expand_token_num, \
|
||||
hidden_dim, \
|
||||
topk);
|
||||
|
||||
const auto ffn_out_dtype = ffn_out.dtype();
|
||||
if (ffn_out_dtype == paddle::DataType::FLOAT16) {
|
||||
APPLY_KERNEL(paddle::float16);
|
||||
} else if (ffn_out_dtype == paddle::DataType::BFLOAT16) {
|
||||
APPLY_KERNEL(paddle::bfloat16);
|
||||
} else {
|
||||
PD_THROW("MoeEPCombine not support ffn_out_type==%d",
|
||||
static_cast<int>(ffn_out_dtype));
|
||||
return {};
|
||||
}
|
||||
|
||||
#undef APPLY_KERNEL
|
||||
}
|
||||
|
||||
std::vector<std::vector<int64_t>> MoeEPCombineInferShape(
|
||||
const std::vector<int64_t>& ffn_out_shape,
|
||||
const std::vector<int64_t>& moe_index_shape,
|
||||
const std::vector<int64_t>& weights_shape,
|
||||
const int recv_token_num,
|
||||
const int expand_token_num,
|
||||
const int hidden_dim,
|
||||
const int topk) {
|
||||
std::vector<int64_t> combined_out_shape = {recv_token_num, hidden_dim};
|
||||
return {combined_out_shape};
|
||||
}
|
||||
|
||||
std::vector<paddle::DataType> MoeEPCombineInferDtype(
|
||||
const paddle::DataType& ffn_out_dtype,
|
||||
const paddle::DataType& moe_index_dtype,
|
||||
const paddle::DataType& weights_dtype) {
|
||||
return {ffn_out_dtype};
|
||||
}
|
||||
|
||||
PD_BUILD_STATIC_OP(ep_moe_expert_combine)
|
||||
.Inputs({"ffn_out", "moe_index", "weights"})
|
||||
.Outputs({"combined_out"})
|
||||
.Attrs({"recv_token_num: int",
|
||||
"expand_token_num: int",
|
||||
"hidden_dim: int",
|
||||
"topk: int"})
|
||||
.SetKernelFn(PD_KERNEL(MoeEPCombine))
|
||||
.SetInferShapeFn(PD_INFER_SHAPE(MoeEPCombineInferShape))
|
||||
.SetInferDtypeFn(PD_INFER_DTYPE(MoeEPCombineInferDtype));
|
201
custom_ops/xpu_ops/src/ops/moe_ep_dispatch.cc
Normal file
201
custom_ops/xpu_ops/src/ops/moe_ep_dispatch.cc
Normal file
@@ -0,0 +1,201 @@
|
||||
// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#include <infer_ops.h>
|
||||
#include <infer_ops_eb.h>
|
||||
#include <xft_api.h>
|
||||
#include "paddle/extension.h"
|
||||
#include "paddle/phi/backends/xpu/enforce_xpu.h"
|
||||
#include "utility/debug.h"
|
||||
|
||||
#ifndef PD_BUILD_STATIC_OP
|
||||
#define PD_BUILD_STATIC_OP(name) PD_BUILD_OP(static_op_##name)
|
||||
#endif
|
||||
|
||||
template <typename TX, typename TY>
|
||||
std::vector<paddle::Tensor> EPMoeExpertDispatchKernel(
|
||||
const paddle::Tensor& input,
|
||||
const paddle::Tensor& topk_ids,
|
||||
const paddle::Tensor& topk_weights,
|
||||
const paddle::optional<paddle::Tensor>& input_scales,
|
||||
const std::vector<int>& token_nums_per_expert,
|
||||
const int64_t token_nums_this_rank) {
|
||||
using XPU_TX = typename XPUTypeTrait<TX>::Type;
|
||||
using XPU_TY = typename XPUTypeTrait<TY>::Type;
|
||||
phi::XPUPlace xpu_place(phi::backends::xpu::GetXPUCurrentDeviceId());
|
||||
auto dev_ctx =
|
||||
paddle::experimental::DeviceContextPool::Instance().Get(xpu_place);
|
||||
auto xpu_ctx = static_cast<const phi::XPUContext*>(dev_ctx);
|
||||
|
||||
const auto input_type = input.dtype();
|
||||
auto m = input.dims()[0];
|
||||
auto n = input.dims()[1];
|
||||
const int64_t expert_num = token_nums_per_expert.size();
|
||||
const int topk = topk_ids.dims()[1];
|
||||
auto place = input.place();
|
||||
|
||||
auto block_num = xpu_ctx->x_context()->ncluster();
|
||||
paddle::Tensor permute_input;
|
||||
auto permute_indices_per_token =
|
||||
paddle::empty({m, topk}, paddle::DataType::INT32, place);
|
||||
auto expert_m = paddle::empty({expert_num}, paddle::DataType::INT32, place);
|
||||
auto recv_num_tokens_per_expert_list_cumsum =
|
||||
paddle::empty({expert_num + 1}, paddle::DataType::INT32, place);
|
||||
auto expand_input_scales =
|
||||
paddle::empty({token_nums_this_rank}, paddle::DataType::FLOAT32, place);
|
||||
const int64_t ep_size = 1;
|
||||
const int64_t ep_rank = 0;
|
||||
|
||||
if (std::is_same<TY, int8_t>::value) {
|
||||
permute_input =
|
||||
paddle::empty({token_nums_this_rank, n}, paddle::DataType::INT8, place);
|
||||
auto ret = infer_ops::moe_ffn_pre_sorted_quant_pe<XPU_TX, int>(
|
||||
xpu_ctx->x_context(),
|
||||
reinterpret_cast<const XPU_TX*>(input.data<TX>()),
|
||||
topk_ids.data<int>(),
|
||||
input_scales.get_ptr()->data<float>(),
|
||||
nullptr,
|
||||
reinterpret_cast<int8_t*>(permute_input.data<int8_t>()),
|
||||
const_cast<int*>(permute_indices_per_token.data<int>()),
|
||||
const_cast<int*>(expert_m.data<int>()),
|
||||
const_cast<int*>(recv_num_tokens_per_expert_list_cumsum.data<int>()),
|
||||
expand_input_scales.data<float>(),
|
||||
m,
|
||||
n,
|
||||
expert_num,
|
||||
topk,
|
||||
block_num,
|
||||
token_nums_this_rank);
|
||||
PD_CHECK(ret == 0, "moe_ep_ffn_pre_sorted failed");
|
||||
} else {
|
||||
permute_input = paddle::empty({token_nums_this_rank, n}, input_type, place);
|
||||
auto ret = infer_ops::moe_ep_ffn_pre_sorted<XPU_TX, int>(
|
||||
xpu_ctx->x_context(),
|
||||
reinterpret_cast<const XPU_TX*>(input.data<TX>()),
|
||||
topk_ids.data<int>(),
|
||||
nullptr,
|
||||
reinterpret_cast<XPU_TX*>(permute_input.data<TX>()),
|
||||
const_cast<int*>(permute_indices_per_token.data<int>()),
|
||||
const_cast<int*>(expert_m.data<int>()),
|
||||
const_cast<int*>(recv_num_tokens_per_expert_list_cumsum.data<int>()),
|
||||
m,
|
||||
n,
|
||||
expert_num,
|
||||
topk,
|
||||
block_num,
|
||||
ep_size,
|
||||
ep_rank,
|
||||
token_nums_this_rank);
|
||||
PD_CHECK(ret == 0, "moe_ep_ffn_pre_sorted failed");
|
||||
}
|
||||
return {permute_input,
|
||||
permute_indices_per_token,
|
||||
recv_num_tokens_per_expert_list_cumsum,
|
||||
topk_weights,
|
||||
expand_input_scales};
|
||||
}
|
||||
|
||||
std::vector<paddle::Tensor> EPMoeExpertDispatch(
|
||||
const paddle::Tensor& input,
|
||||
const paddle::Tensor& topk_ids,
|
||||
const paddle::Tensor& topk_weights,
|
||||
const paddle::optional<paddle::Tensor>& input_scales,
|
||||
const std::vector<int>& token_nums_per_expert,
|
||||
const int token_nums_this_rank,
|
||||
const std::string quant_method) {
|
||||
#define APPLY_KERNEL(TX, TY) \
|
||||
return EPMoeExpertDispatchKernel<TX, TY>(input, \
|
||||
topk_ids, \
|
||||
topk_weights, \
|
||||
input_scales, \
|
||||
token_nums_per_expert, \
|
||||
token_nums_this_rank);
|
||||
|
||||
const auto input_dtype = input.dtype();
|
||||
if (input_dtype == paddle::DataType::FLOAT16 && quant_method == "w4a8") {
|
||||
APPLY_KERNEL(paddle::float16, int8_t);
|
||||
} else if (input_dtype == paddle::DataType::FLOAT16 &&
|
||||
quant_method != "w4a8") {
|
||||
APPLY_KERNEL(paddle::float16, paddle::float16);
|
||||
} else if (input_dtype == paddle::DataType::BFLOAT16 &&
|
||||
quant_method == "w4a8") {
|
||||
APPLY_KERNEL(paddle::bfloat16, int8_t);
|
||||
} else if (input_dtype == paddle::DataType::BFLOAT16 &&
|
||||
quant_method != "w4a8") {
|
||||
APPLY_KERNEL(paddle::bfloat16, paddle::bfloat16);
|
||||
} else {
|
||||
PD_THROW("EPMoeExpertDispatch not support input_dtype=",
|
||||
static_cast<int>(input_dtype),
|
||||
"quant_method=",
|
||||
quant_method);
|
||||
return {};
|
||||
}
|
||||
|
||||
#undef APPLY_KERNEL
|
||||
}
|
||||
|
||||
std::vector<std::vector<int64_t>> EPMoeExpertDispatchInferShape(
|
||||
const std::vector<int64_t>& input_shape,
|
||||
const std::vector<int64_t>& topk_ids_shape,
|
||||
const std::vector<int64_t>& topk_weights_shape,
|
||||
const paddle::optional<std::vector<int64_t>>& input_scales_shape,
|
||||
const std::vector<int>& token_nums_per_expert,
|
||||
const int token_nums_this_rank,
|
||||
const std::string quant_method) {
|
||||
const int m = input_shape[0];
|
||||
const int hidden_size = input_shape[input_shape.size() - 1];
|
||||
const int topk = topk_ids_shape[topk_ids_shape.size() - 1];
|
||||
const int expert_num = token_nums_per_expert.size();
|
||||
return {{token_nums_this_rank, hidden_size},
|
||||
{expert_num, m},
|
||||
{expert_num},
|
||||
{token_nums_this_rank},
|
||||
{token_nums_this_rank}};
|
||||
}
|
||||
|
||||
std::vector<paddle::DataType> EPMoeExpertDispatchInferDtype(
|
||||
const paddle::DataType& input_dtype,
|
||||
const paddle::DataType& topk_ids_dtype,
|
||||
const paddle::DataType& topk_weights_dtype,
|
||||
const paddle::optional<paddle::DataType>& input_scales_dtype,
|
||||
const std::vector<int>& token_nums_per_expert,
|
||||
const int token_nums_this_rank,
|
||||
const std::string quant_method) {
|
||||
auto output_dtype = input_dtype;
|
||||
if (quant_method == "w4a8") {
|
||||
output_dtype = paddle::DataType::INT8;
|
||||
}
|
||||
return {
|
||||
output_dtype,
|
||||
paddle::DataType::INT32,
|
||||
paddle::DataType::INT32,
|
||||
topk_weights_dtype,
|
||||
paddle::DataType::FLOAT32,
|
||||
};
|
||||
}
|
||||
|
||||
PD_BUILD_STATIC_OP(ep_moe_expert_dispatch)
|
||||
.Inputs(
|
||||
{"input", "topk_ids", "topk_weights", paddle::Optional("input_scales")})
|
||||
.Outputs({"permute_input",
|
||||
"permute_indices_per_token",
|
||||
"token_nums_per_expert_cumsum",
|
||||
"dst_weights",
|
||||
"expand_input_scales"})
|
||||
.Attrs({"token_nums_per_expert: std::vector<int>",
|
||||
"token_nums_this_rank: int",
|
||||
"quant_method: std::string"})
|
||||
.SetKernelFn(PD_KERNEL(EPMoeExpertDispatch))
|
||||
.SetInferShapeFn(PD_INFER_SHAPE(EPMoeExpertDispatchInferShape))
|
||||
.SetInferDtypeFn(PD_INFER_DTYPE(EPMoeExpertDispatchInferDtype));
|
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user