Mirror of https://github.com/PaddlePaddle/FastDeploy.git (synced 2025-10-05 00:33:03 +08:00)

Compare commits: copilot/ad... ... Jason/expe (95 commits)
Commit SHAs in this comparison (author and date columns were not captured in the mirror):

8614ca56ad
c35a21a99a
c8985727a6
076c30cb0f
f8c6a354a1
b176cba474
dcf633c4d9
213f15ef55
bab779011c
e2b68b33c9
8a506500f3
1aab1c8d06
94b6e7a341
389c5dd3a2
361104508e
0bfffdbc14
f489c9f8ef
be98f6e950
f75697c2d1
1e86418c4a
5027ed7239
25aa2d94aa
b6caf6e622
d381fa8194
d2ab369427
2883746132
2485333f71
10768a4d79
c64ceac34d
447297a7b5
63d24b2210
48f2ab3fb3
749f074e44
f06e3ee1fc
2f473ba966
cce2410fad
d8985a7a21
7d1b2bd732
71a9127e13
8f5397616f
ece070cf6b
d40a1046de
fa2369271d
8903f937f9
1023a67765
d43549953c
c7c1627456
d6bf6de5e6
38e734e183
051e4a881c
b2bb37d7c0
c6e2a37a95
8d77c1cb51
41cd3e24c9
11b18e5ef0
e2c764fd5a
2d975e16b0
8915c8411d
77c1bd0813
473cde779f
335d1c8e8f
173e4df982
199f88ce1e
55ebe855c0
deb7ad205f
e9f72df918
8567ada09e
afcde19277
d40d3a5a4f
b8d0f1c081
8550e19008
a0c03510c0
fb1e0d6a87
fbf0e9d2aa
8c0e7d6fe9
b56b015d85
1432e336d7
9213a58a06
87ef0f5d30
abcd2148c0
05b6591c23
42402c80e9
1968c65849
37cb37b7f2
f975f7de2f
174510180a
5cda326ba2
a6c8f17431
cd09384a14
0f42771a84
d1d063e4af
a86b35ab49
0cdbc950b5
2b0a745d57
1953c7c759
.github/workflows/Codestyle-Check.yml (vendored) — 1 changed line

@@ -5,6 +5,7 @@ on:
branches:
- develop
- 'release/*'
- 'feature/*'

jobs:
pre-commit:
.github/workflows/_accuracy_test.yml (vendored) — 5 changed lines

@@ -80,12 +80,14 @@ jobs:
FD_API_PORT=$((42088 + DEVICE_PORT * 100))
FD_ENGINE_QUEUE_PORT=$((42058 + DEVICE_PORT * 100))
FD_METRICS_PORT=$((42078 + DEVICE_PORT * 100))
FD_CACHE_QUEUE_PORT=$((42098 + DEVICE_PORT * 100))
echo "Test ENV Parameter:"
echo "========================================================="
echo "FLASK_PORT=${FLASK_PORT}"
echo "FD_API_PORT=${FD_API_PORT}"
echo "FD_ENGINE_QUEUE_PORT=${FD_ENGINE_QUEUE_PORT}"
echo "FD_METRICS_PORT=${FD_METRICS_PORT}"
echo "FD_CACHE_QUEUE_PORT=${FD_CACHE_QUEUE_PORT}"
echo "DEVICES=${DEVICES}"
echo "========================================================="

@@ -99,7 +101,7 @@ jobs:
exit 1
fi

PORTS=($FLASK_PORT $FD_API_PORT $FD_ENGINE_QUEUE_PORT $FD_METRICS_PORT)
PORTS=($FLASK_PORT $FD_API_PORT $FD_ENGINE_QUEUE_PORT $FD_METRICS_PORT $FD_CACHE_QUEUE_PORT)
LOG_FILE="./port_cleanup_$(date +%Y%m%d_%H%M%S).log"
echo "==== LOG_FILE is ${LOG_FILE} ===="

@@ -133,6 +135,7 @@ jobs:
-e "FD_API_PORT=${FD_API_PORT}" \
-e "FD_ENGINE_QUEUE_PORT=${FD_ENGINE_QUEUE_PORT}" \
-e "FD_METRICS_PORT=${FD_METRICS_PORT}" \
-e "FD_CACHE_QUEUE_PORT=${FD_CACHE_QUEUE_PORT}" \
-e "FLASK_PORT=${FLASK_PORT}" \
-v "${MODEL_CACHE_DIR}:/MODELDATA" \
-v "${CACHE_DIR}/gitconfig:/etc/gitconfig:ro" \
.github/workflows/_base_test.yml (vendored) — 6 changed lines

@@ -80,12 +80,14 @@ jobs:
FD_API_PORT=$((42088 + DEVICE_PORT * 100))
FD_ENGINE_QUEUE_PORT=$((42058 + DEVICE_PORT * 100))
FD_METRICS_PORT=$((42078 + DEVICE_PORT * 100))
FD_CACHE_QUEUE_PORT=$((42098 + DEVICE_PORT * 100))
echo "Test ENV Parameter:"
echo "========================================================="
echo "FLASK_PORT=${FLASK_PORT}"
echo "FD_API_PORT=${FD_API_PORT}"
echo "FD_ENGINE_QUEUE_PORT=${FD_ENGINE_QUEUE_PORT}"
echo "FD_METRICS_PORT=${FD_METRICS_PORT}"
echo "FD_CACHE_QUEUE_PORT=${FD_CACHE_QUEUE_PORT}"
echo "DEVICES=${DEVICES}"
echo "========================================================="

@@ -99,7 +101,7 @@ jobs:
exit 1
fi

PORTS=($FLASK_PORT $FD_API_PORT $FD_ENGINE_QUEUE_PORT $FD_METRICS_PORT)
PORTS=($FLASK_PORT $FD_API_PORT $FD_ENGINE_QUEUE_PORT $FD_METRICS_PORT $FD_CACHE_QUEUE_PORT)
LOG_FILE="./port_cleanup_$(date +%Y%m%d_%H%M%S).log"
echo "==== LOG_FILE is ${LOG_FILE} ===="

@@ -134,7 +136,7 @@ jobs:
-e "FD_ENGINE_QUEUE_PORT=${FD_ENGINE_QUEUE_PORT}" \
-e "FD_METRICS_PORT=${FD_METRICS_PORT}" \
-e "FLASK_PORT=${FLASK_PORT}" \
-e "FD_FORCE_CHUNKED_PREFILL=1" \
-e "FD_CACHE_QUEUE_PORT=${FD_CACHE_QUEUE_PORT}" \
-v "${MODEL_CACHE_DIR}:/MODELDATA" \
-v "${CACHE_DIR}/gitconfig:/etc/gitconfig:ro" \
-v "${CACHE_DIR}/.cache:/root/.cache" \
.github/workflows/_logprob_test_linux.yml (vendored) — 5 changed lines

@@ -70,12 +70,14 @@ jobs:
FD_API_PORT=$((42088 + DEVICE_PORT * 100))
FD_ENGINE_QUEUE_PORT=$((42058 + DEVICE_PORT * 100))
FD_METRICS_PORT=$((42078 + DEVICE_PORT * 100))
FD_CACHE_QUEUE_PORT=$((42098 + DEVICE_PORT * 100))
echo "Test ENV Parameter:"
echo "========================================================="
echo "FLASK_PORT=${FLASK_PORT}"
echo "FD_API_PORT=${FD_API_PORT}"
echo "FD_ENGINE_QUEUE_PORT=${FD_ENGINE_QUEUE_PORT}"
echo "FD_METRICS_PORT=${FD_METRICS_PORT}"
echo "FD_CACHE_QUEUE_PORT=${FD_CACHE_QUEUE_PORT}"
echo "DEVICES=${DEVICES}"
echo "========================================================="

@@ -89,7 +91,7 @@ jobs:
exit 1
fi

PORTS=($FLASK_PORT $FD_API_PORT $FD_ENGINE_QUEUE_PORT $FD_METRICS_PORT)
PORTS=($FLASK_PORT $FD_API_PORT $FD_ENGINE_QUEUE_PORT $FD_METRICS_PORT $FD_CACHE_QUEUE_PORT)
LOG_FILE="./port_cleanup_$(date +%Y%m%d_%H%M%S).log"
echo "==== LOG_FILE is ${LOG_FILE} ===="

@@ -123,6 +125,7 @@ jobs:
-e "FD_API_PORT=${FD_API_PORT}" \
-e "FD_ENGINE_QUEUE_PORT=${FD_ENGINE_QUEUE_PORT}" \
-e "FD_METRICS_PORT=${FD_METRICS_PORT}" \
-e "FD_CACHE_QUEUE_PORT=${FD_CACHE_QUEUE_PORT}" \
-e "FLASK_PORT=${FLASK_PORT}" \
-v "${MODEL_CACHE_DIR}:/MODELDATA" \
-v "${CACHE_DIR}/gitconfig:/etc/gitconfig:ro" \
.github/workflows/_pre_ce_test.yml (vendored) — 5 changed lines

@@ -81,12 +81,14 @@ jobs:
FD_API_PORT=$((42088 + DEVICE_PORT * 100))
FD_ENGINE_QUEUE_PORT=$((42058 + DEVICE_PORT * 100))
FD_METRICS_PORT=$((42078 + DEVICE_PORT * 100))
FD_CACHE_QUEUE_PORT=$((42098 + DEVICE_PORT * 100))
echo "Test ENV Parameter:"
echo "========================================================="
echo "FLASK_PORT=${FLASK_PORT}"
echo "FD_API_PORT=${FD_API_PORT}"
echo "FD_ENGINE_QUEUE_PORT=${FD_ENGINE_QUEUE_PORT}"
echo "FD_METRICS_PORT=${FD_METRICS_PORT}"
echo "FD_CACHE_QUEUE_PORT=${FD_CACHE_QUEUE_PORT}"
echo "DEVICES=${DEVICES}"
echo "========================================================="

@@ -96,7 +98,7 @@ jobs:
touch "${CACHE_DIR}/gitconfig"
fi

PORTS=($FLASK_PORT $FD_API_PORT $FD_ENGINE_QUEUE_PORT $FD_METRICS_PORT)
PORTS=($FLASK_PORT $FD_API_PORT $FD_ENGINE_QUEUE_PORT $FD_METRICS_PORT $FD_CACHE_QUEUE_PORT)
LOG_FILE="./port_cleanup_$(date +%Y%m%d_%H%M%S).log"
echo "==== LOG_FILE is ${LOG_FILE} ===="

@@ -134,6 +136,7 @@ jobs:
-e "FD_API_PORT=${FD_API_PORT}" \
-e "FD_ENGINE_QUEUE_PORT=${FD_ENGINE_QUEUE_PORT}" \
-e "FD_METRICS_PORT=${FD_METRICS_PORT}" \
-e "FD_CACHE_QUEUE_PORT=${FD_CACHE_QUEUE_PORT}" \
-e "FLASK_PORT=${FLASK_PORT}" \
-e "fd_wheel_url=${fd_wheel_url}" \
--gpus "\"device=${DEVICES}\"" ${docker_image} /bin/bash -c '
.github/workflows/_unit_test_coverage.yml (vendored) — 5 changed lines

@@ -102,12 +102,14 @@ jobs:
FD_API_PORT=$((42088 + DEVICE_PORT * 100))
FD_ENGINE_QUEUE_PORT=$((42058 + DEVICE_PORT * 100))
FD_METRICS_PORT=$((42078 + DEVICE_PORT * 100))
FD_CACHE_QUEUE_PORT=$((42098 + DEVICE_PORT * 100))
echo "Test ENV Parameter:"
echo "========================================================="
echo "FLASK_PORT=${FLASK_PORT}"
echo "FD_API_PORT=${FD_API_PORT}"
echo "FD_ENGINE_QUEUE_PORT=${FD_ENGINE_QUEUE_PORT}"
echo "FD_METRICS_PORT=${FD_METRICS_PORT}"
echo "FD_CACHE_QUEUE_PORT=${FD_CACHE_QUEUE_PORT}"
echo "DEVICES=${DEVICES}"
echo "========================================================="

@@ -117,7 +119,7 @@ jobs:
touch "${CACHE_DIR}/gitconfig"
fi

PORTS=($FLASK_PORT $FD_API_PORT $FD_ENGINE_QUEUE_PORT $FD_METRICS_PORT)
PORTS=($FLASK_PORT $FD_API_PORT $FD_ENGINE_QUEUE_PORT $FD_METRICS_PORT $FD_CACHE_QUEUE_PORT)
LOG_FILE="./port_cleanup_$(date +%Y%m%d_%H%M%S).log"
echo "==== LOG_FILE is ${LOG_FILE} ===="

@@ -156,6 +158,7 @@ jobs:
-e "FD_ENGINE_QUEUE_PORT=${FD_ENGINE_QUEUE_PORT}" \
-e "FD_METRICS_PORT=${FD_METRICS_PORT}" \
-e "FLASK_PORT=${FLASK_PORT}" \
-e "FD_CACHE_QUEUE_PORT=${FD_CACHE_QUEUE_PORT}" \
-e TZ="Asia/Shanghai" \
-e "fd_wheel_url=${fd_wheel_url}" \
-e "BASE_REF=${BASE_REF}" \
.github/workflows/approve.yml (vendored) — 1 changed line

@@ -5,6 +5,7 @@ on:
branches:
- develop
- 'release/*'
- 'feature/*'

env:
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
.github/workflows/ce_job.yml (vendored) — 21 changed lines

@@ -6,6 +6,7 @@ on:
branches:
- develop
- 'release/*'
- 'feature/experimental_feature*'
permissions: read-all

concurrency:

@@ -154,6 +155,7 @@ jobs:
COMPILE_ARCH: "80,90"
WITH_NIGHTLY_BUILD: OFF
FD_VERSION: 0.0.0
PADDLE_WHL_URL: ${{ needs.ce_job_pre_check.outputs.compile_use_paddle_whl_url }}

build_sm8689:
name: BUILD_SM8689

@@ -166,6 +168,7 @@ jobs:
COMPILE_ARCH: "86,89"
WITH_NIGHTLY_BUILD: OFF
FD_VERSION: 0.0.0
PADDLE_WHL_URL: ${{ needs.ce_job_pre_check.outputs.compile_use_paddle_whl_url }}

ce_upload_sm8090:
environment: CodeSync

@@ -175,14 +178,13 @@ jobs:
env:
AK: ${{ secrets.BOS_AK }}
SK: ${{ secrets.BOS_SK }}
FASTDEPLOY_WHEEL_URL: ${{ needs.build.outputs.wheel_path }}
FASTDEPLOY_WHEEL_URL: ${{ needs.build_sm8090.outputs.wheel_path }}
COMPILE_ARCH: "80,90"
steps:
- uses: actions/setup-python@v5
with:
python-version: '3.10'
- name: Wheel Info Show and Upload
if: github.ref_name == 'develop' || github.ref_type == 'tag'
run: |
echo "The wheel is located at: ${{ needs.build_sm8090.outputs.wheel_path }}"
wget -q --no-check-certificate ${{ needs.build_sm8090.outputs.wheel_path }}

@@ -190,7 +192,7 @@ jobs:

commit_id=${{ github.sha }}
branch_name=${{ github.ref_name }}
target_path=paddle-qa/paddle-pipeline/FastDeploy_ActionCE${COMPILE_ARCH//,/_}/${branch_name}/${commit_id}
target_path=paddle-qa/paddle-pipeline/FastDeploy_ActionCE/SM${COMPILE_ARCH//,/_}/${branch_name}/${commit_id}

wget -q --no-proxy --no-check-certificate https://paddle-qa.bj.bcebos.com/CodeSync/develop/PaddlePaddle/PaddleTest/tools/bos_tools.py
push_file=$(realpath bos_tools.py)

@@ -199,11 +201,13 @@ jobs:
python ${push_file} ${filename} ${target_path}
target_path_stripped="${target_path#paddle-qa/}"
WHEEL_PATH=https://paddle-qa.bj.bcebos.com/${target_path_stripped}/${fd_wheel_name}
echo "commit wheel url is ${WHEEL_PATH}"

target_path_latest=paddle-qa/paddle-pipeline/FastDeploy_ActionCE${COMPILE_ARCH//,/_}/${branch_name}/latest
target_path_latest=paddle-qa/paddle-pipeline/FastDeploy_ActionCE/SM${COMPILE_ARCH//,/_}/${branch_name}/latest
python ${push_file} ${filename} ${target_path_latest}
target_path_stripped_latest="${target_path_latest#paddle-qa/}"
WHEEL_PATH_LATEST=https://paddle-qa.bj.bcebos.com/${target_path_stripped_latest}/${fd_wheel_name}
echo "latest wheel url is ${WHEEL_PATH_LATEST}"

ce_upload_sm8689:
environment: CodeSync

@@ -213,14 +217,13 @@ jobs:
env:
AK: ${{ secrets.BOS_AK }}
SK: ${{ secrets.BOS_SK }}
FASTDEPLOY_WHEEL_URL: ${{ needs.build.outputs.wheel_path }}
FASTDEPLOY_WHEEL_URL: ${{ needs.build_sm8689.outputs.wheel_path }}
COMPILE_ARCH: "86,89"
steps:
- uses: actions/setup-python@v5
with:
python-version: '3.10'
- name: Wheel Info Show and Upload
if: github.ref_name == 'develop' || github.ref_type == 'tag'
run: |
echo "The wheel is located at: ${{ needs.build_sm8090.outputs.wheel_path }}"
wget -q --no-check-certificate ${{ needs.build_sm8090.outputs.wheel_path }}

@@ -228,7 +231,7 @@ jobs:

commit_id=${{ github.sha }}
branch_name=${{ github.ref_name }}
target_path=paddle-qa/paddle-pipeline/FastDeploy_ActionCE${COMPILE_ARCH//,/_}/${branch_name}/${commit_id}
target_path=paddle-qa/paddle-pipeline/FastDeploy_ActionCE/SM${COMPILE_ARCH//,/_}/${branch_name}/${commit_id}

wget -q --no-proxy --no-check-certificate https://paddle-qa.bj.bcebos.com/CodeSync/develop/PaddlePaddle/PaddleTest/tools/bos_tools.py
push_file=$(realpath bos_tools.py)

@@ -237,8 +240,10 @@ jobs:
python ${push_file} ${filename} ${target_path}
target_path_stripped="${target_path#paddle-qa/}"
WHEEL_PATH=https://paddle-qa.bj.bcebos.com/${target_path_stripped}/${fd_wheel_name}
echo "commit wheel url is ${WHEEL_PATH}"

target_path_latest=paddle-qa/paddle-pipeline/FastDeploy_ActionCE${COMPILE_ARCH//,/_}/${branch_name}/latest
target_path_latest=paddle-qa/paddle-pipeline/FastDeploy_ActionCE/SM${COMPILE_ARCH//,/_}/${branch_name}/latest
python ${push_file} ${filename} ${target_path_latest}
target_path_stripped_latest="${target_path_latest#paddle-qa/}"
WHEEL_PATH_LATEST=https://paddle-qa.bj.bcebos.com/${target_path_stripped_latest}/${fd_wheel_name}
echo "latest wheel url is ${WHEEL_PATH_LATEST}"
.github/workflows/ci_xpu.yml (vendored) — 3 changed lines

@@ -5,6 +5,7 @@ on:
branches:
- develop
- 'release/*'
- 'feature/*'
workflow_dispatch:

concurrency:

@@ -77,7 +78,7 @@ jobs:
-e "MODEL_PATH=/ssd3/model" \
-e "http_proxy=$(git config --global --get http.proxy)" \
-e "https_proxy=$(git config --global --get https.proxy)" \
-e "no_proxy=bcebos.com" \
-e "no_proxy=bcebos.com,mirrors.tuna.tsinghua.edu.cn,127.0.0.1,localhost" \
-e "FD_API_PORT=${FD_API_PORT}" \
-e "FD_ENGINE_QUEUE_PORT=${FD_ENGINE_QUEUE_PORT}" \
-e "FD_METRICS_PORT=${FD_METRICS_PORT}" \
.github/workflows/pr_build_and_test.yml (vendored) — 2 changed lines

@@ -2,7 +2,7 @@ name: PR Build and Test
on:
pull_request:
types: [opened, synchronize]
branches: [develop, release/**]
branches: [develop, release/**, feature/**]
permissions: read-all

concurrency:
@@ -140,8 +140,8 @@ void AppendAttentionKernel(
key_cache,
value_cache,
attn_mask,
cache_k_dequant_scales,
cache_v_dequant_scales,
cache_quant_type_str == "block_wise_fp8" ? cache_k_quant_scales : cache_k_dequant_scales,
cache_quant_type_str == "block_wise_fp8" ? cache_v_quant_scales : cache_v_dequant_scales,
cache_k_zp,
cache_v_zp,
out_linear_shifts,

@@ -273,11 +273,15 @@ void AppendAttentionKernel(
cache_v_zp,
cache_quant_type_str,
use_neox_rotary_style,
rope_3d,
max_input_length,
exec_stream,
&qkv_out,
const_cast<paddle::Tensor*>(&key_cache),
const_cast<paddle::Tensor*>(&value_cache));
const_cast<paddle::Tensor*>(&value_cache),
q_norm_weight,
k_norm_weight,
rms_norm_eps);
} else {
SpeculateWriteCacheWithRoPEKernel<data_t, data_t>(
meta_data,

@@ -296,11 +300,15 @@ void AppendAttentionKernel(
cache_v_zp,
cache_quant_type_str,
use_neox_rotary_style,
rope_3d,
max_input_length,
exec_stream,
&qkv_out,
const_cast<paddle::Tensor*>(&key_cache),
const_cast<paddle::Tensor*>(&value_cache));
const_cast<paddle::Tensor*>(&value_cache),
q_norm_weight,
k_norm_weight,
rms_norm_eps);
}
} else {
if (qkv_out_scales) {
@@ -32,14 +32,15 @@ template <typename T,
typename OutT = T,
bool ENABLE_PREFILL = true,
bool is_scale_channel_wise = false,
bool IsFP8=false>
bool IsFP8 = false,
bool IsDynamicC8 = false>
__global__ void multi_query_append_attention_c8_kernel(
T *__restrict__ q, // [token_num, (num_heads + 2* kv_num_head) * head_dim]
CacheT *__restrict__ cache_k, // [max_block_num, num_heads, block_size,
// head_dim]
CacheT *__restrict__ cache_v,
const T *__restrict__ cache_k_scale, // [num_kv_heads]
const T *__restrict__ cache_v_scale, // [num_kv_heads]
const T *__restrict__ cache_k_scale, // [num_kv_heads] or [max_block_num, num_heads, block_size]
const T *__restrict__ cache_v_scale, // [num_kv_heads] or [max_block_num, num_heads, block_size]
const T *__restrict__ shift_bias, // [q_num_heads * HEAD_DIM]
const T *__restrict__ smooth_weight, // [q_num_heads * HEAD_DIM]
const int *__restrict__ seq_lens,

@@ -91,28 +92,30 @@ __global__ void multi_query_append_attention_c8_kernel(
return;
}

T cache_k_scale_reg[num_frags_y * 4];
T cache_v_scale_reg[num_frags_y * 2];
if (is_scale_channel_wise) {
int scale_col_base = threadIdx.x % 4 * 2 + kv_head_idx * HEAD_DIM;
const T *cache_k_scale_cur_head = cache_k_scale + scale_col_base;
for (int i = 0; i < num_frags_y; ++i) {
const int scale_idx = i * 16;
cache_k_scale_reg[i * 4] = cache_k_scale_cur_head[scale_idx];
cache_k_scale_reg[i * 4 + 1] = cache_k_scale_cur_head[scale_idx + 1];
cache_k_scale_reg[i * 4 + 2] = cache_k_scale_cur_head[scale_idx + 8];
cache_k_scale_reg[i * 4 + 3] = cache_k_scale_cur_head[scale_idx + 9];
T cache_k_scale_reg[IsDynamicC8 ? num_frags_z * 2 : num_frags_y * 4];
T cache_v_scale_reg[IsDynamicC8 ? num_frags_z * 4 : num_frags_y * 2];
if constexpr (!IsDynamicC8) {
if constexpr (is_scale_channel_wise) {
int scale_col_base = threadIdx.x % 4 * 2 + kv_head_idx * HEAD_DIM;
const T *cache_k_scale_cur_head = cache_k_scale + scale_col_base;
for (int i = 0; i < num_frags_y; ++i) {
const int scale_idx = i * 16;
cache_k_scale_reg[i * 4] = cache_k_scale_cur_head[scale_idx];
cache_k_scale_reg[i * 4 + 1] = cache_k_scale_cur_head[scale_idx + 1];
cache_k_scale_reg[i * 4 + 2] = cache_k_scale_cur_head[scale_idx + 8];
cache_k_scale_reg[i * 4 + 3] = cache_k_scale_cur_head[scale_idx + 9];
}
scale_col_base = threadIdx.x / 4 + kv_head_idx * HEAD_DIM;
const T *cache_v_scale_cur_head = cache_v_scale + scale_col_base;
for (int i = 0; i < num_frags_y; ++i) {
const int scale_idx = i * 16;
cache_v_scale_reg[i * 2] = cache_v_scale_cur_head[scale_idx];
cache_v_scale_reg[i * 2 + 1] = cache_v_scale_cur_head[scale_idx + 8];
}
} else {
cache_k_scale_reg[0] = cache_k_scale[kv_head_idx];
cache_v_scale_reg[0] = cache_v_scale[kv_head_idx];
}
scale_col_base = threadIdx.x / 4 + kv_head_idx * HEAD_DIM;
const T *cache_v_scale_cur_head = cache_v_scale + scale_col_base;
for (int i = 0; i < num_frags_y; ++i) {
const int scale_idx = i * 16;
cache_v_scale_reg[i * 2] = cache_v_scale_cur_head[scale_idx];
cache_v_scale_reg[i * 2 + 1] = cache_v_scale_cur_head[scale_idx + 8];
}
} else {
cache_k_scale_reg[0] = cache_k_scale[kv_head_idx];
cache_v_scale_reg[0] = cache_v_scale[kv_head_idx];
}

const uint32_t q_end =

@@ -201,6 +204,17 @@ __global__ void multi_query_append_attention_c8_kernel(
smem_t k_smem(smem + NUM_WARPS * num_frags_x * 16 * HEAD_DIM * sizeof(T)),
v_smem(smem + NUM_WARPS * num_frags_x * 16 * HEAD_DIM * sizeof(T) +
num_frags_z * 16 * HEAD_DIM * sizeof(CacheT));
T* k_smem_scale_ptr = nullptr;
T* v_smem_scale_ptr = nullptr;
smem_t k_scale_smem;
smem_t v_scale_smem;
if constexpr (IsDynamicC8) {
k_smem_scale_ptr = reinterpret_cast<T*>(smem + NUM_WARPS * num_frags_x * 16 * HEAD_DIM * sizeof(T) +
num_frags_z * 16 * HEAD_DIM * sizeof(CacheT) * 2);
v_smem_scale_ptr = k_smem_scale_ptr + num_frags_z * 16;
k_scale_smem.base = reinterpret_cast<b128_t*>(k_smem_scale_ptr);
v_scale_smem.base = reinterpret_cast<b128_t*>(v_smem_scale_ptr);
}

const uint32_t num_iterations = div_up(
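For orientation, the shared-memory arrangement this hunk sets up for the IsDynamicC8 path can be summarized as below. This is an illustrative sketch only, restating the pointer arithmetic in the hunk above with the same template constants; it is not code from the repository.

// Sketch (not repository code): dynamic shared-memory carve-up when IsDynamicC8
// is enabled, mirroring the offsets computed in the hunk above.
//   [ Q tile             ] NUM_WARPS * num_frags_x * 16 * HEAD_DIM * sizeof(T)
//   [ K tile (8-bit)      ] num_frags_z * 16 * HEAD_DIM * sizeof(CacheT)
//   [ V tile (8-bit)      ] num_frags_z * 16 * HEAD_DIM * sizeof(CacheT)
//   [ K per-token scales  ] num_frags_z * 16 * sizeof(T)
//   [ V per-token scales  ] num_frags_z * 16 * sizeof(T)
const size_t q_bytes  = NUM_WARPS * num_frags_x * 16 * HEAD_DIM * sizeof(T);
const size_t kv_bytes = num_frags_z * 16 * HEAD_DIM * sizeof(CacheT);
T* k_scales = reinterpret_cast<T*>(smem + q_bytes + 2 * kv_bytes);  // scales sit after the K and V tiles
T* v_scales = k_scales + num_frags_z * 16;                          // one scale value per cached row

The two trailing scale regions match the extra "num_frags_z * 16 * sizeof(T) * 2" added to smem_size in the host-side launch code later in this diff.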
@@ -261,6 +275,20 @@ __global__ void multi_query_append_attention_c8_kernel(
kv_idx_base,
chunk_end,
const_k_offset);
if constexpr (IsDynamicC8) {
produce_kv_dynamic_scale_gmem2smem_async<SharedMemFillMode::kFillZero,
BLOCK_SIZE,
num_frags_z,
NUM_WARP_Q>(
k_scale_smem,
block_table_now,
cache_k_scale,
kv_idx_base,
kv_num_heads,
kv_head_idx,
chunk_end
);
}
commit_group();
produce_v_blockwise_c8<SharedMemFillMode::kNoFill,
NUM_WARPS,

@@ -278,14 +306,34 @@ __global__ void multi_query_append_attention_c8_kernel(
kv_idx_base,
chunk_end,
const_v_offset);
if constexpr (IsDynamicC8) {
produce_kv_dynamic_scale_gmem2smem_async<SharedMemFillMode::kFillZero,
BLOCK_SIZE,
num_frags_z,
NUM_WARP_Q>(
v_scale_smem,
block_table_now,
cache_v_scale,
kv_idx_base,
kv_num_heads,
kv_head_idx,
chunk_end
);
}
commit_group();

#pragma unroll 1
for (uint32_t iter = 0; iter < num_iterations; ++iter) {
wait_group<1>();
__syncthreads();
if constexpr (IsDynamicC8) {
produce_k_dynamic_scale_smem2reg<BLOCK_SIZE, num_frags_z, NUM_WARP_Q, T>(
k_smem_scale_ptr,
cache_k_scale_reg
);
}
// s = qk
compute_qk_c8<num_frags_x, num_frags_y, num_frags_z, T, CacheT, is_scale_channel_wise, IsFP8>(
compute_qk_c8<num_frags_x, num_frags_y, num_frags_z, T, CacheT, is_scale_channel_wise, IsFP8, IsDynamicC8>(
&qo_smem,
&q_smem_offset_r,
&k_smem,

@@ -318,6 +366,7 @@ __global__ void multi_query_append_attention_c8_kernel(
s_frag, o_frag, m_frag, d_frag);
__syncthreads();

const int ori_kv_idx_base = kv_idx_base;
kv_idx_base += num_frags_z * 16;
produce_k_blockwise_c8<SharedMemFillMode::kNoFill,
NUM_WARPS,

@@ -335,9 +384,29 @@ __global__ void multi_query_append_attention_c8_kernel(
kv_idx_base,
chunk_end,
const_k_offset);
if constexpr (IsDynamicC8) {
produce_kv_dynamic_scale_gmem2smem_async<SharedMemFillMode::kFillZero,
BLOCK_SIZE,
num_frags_z,
NUM_WARP_Q>(
k_scale_smem,
block_table_now,
cache_k_scale,
kv_idx_base,
kv_num_heads,
kv_head_idx,
chunk_end
);
}
commit_group();
wait_group<1>();
__syncthreads();
if constexpr (IsDynamicC8) {
produce_v_dynamic_scale_smem2reg<BLOCK_SIZE, num_frags_z, NUM_WARP_Q, T>(
v_smem_scale_ptr,
cache_v_scale_reg
);
}

// compute sfm*v
compute_sfm_v_c8<num_frags_x,

@@ -346,7 +415,9 @@ __global__ void multi_query_append_attention_c8_kernel(
BLOCK_SIZE,
T,
CacheT,
is_scale_channel_wise, IsFP8>(
is_scale_channel_wise,
IsFP8,
IsDynamicC8>(
&v_smem, &v_smem_offset_r, s_frag, o_frag, d_frag, cache_v_scale_reg);
__syncthreads();

@@ -366,6 +437,20 @@ __global__ void multi_query_append_attention_c8_kernel(
kv_idx_base,
chunk_end,
const_v_offset);
if constexpr (IsDynamicC8) {
produce_kv_dynamic_scale_gmem2smem_async<SharedMemFillMode::kFillZero,
BLOCK_SIZE,
num_frags_z,
NUM_WARP_Q>(
v_scale_smem,
block_table_now,
cache_v_scale,
kv_idx_base,
kv_num_heads,
kv_head_idx,
chunk_end
);
}
commit_group();

}
@@ -463,14 +548,15 @@ template <typename T,
typename OutT = T,
bool ENABLE_PREFILL = true,
bool is_scale_channel_wise=false,
bool IsFP8=false>
bool IsFP8 = false,
bool IsDynamicC8 = false>
__global__ void multi_query_append_attention_c8_warp1_4_kernel(
T *__restrict__ q, // [token_num, (num_heads + 2* kv_num_head) * head_dim]
CacheT *__restrict__ cache_k, // [max_block_num, num_heads, block_size,
// head_dim]
CacheT *__restrict__ cache_v,
const T *__restrict__ cache_k_scale, // [num_kv_heads, head_dim]
const T *__restrict__ cache_v_scale, // [num_kv_heads, head_dim]
const T *__restrict__ cache_k_scale, // [num_kv_heads] or [max_block_num, num_heads, block_size]
const T *__restrict__ cache_v_scale, // [num_kv_heads] or [max_block_num, num_heads, block_size]
const T *__restrict__ shift_bias, // [q_num_heads * HEAD_DIM]
const T *__restrict__ smooth_weight, // [q_num_heads * HEAD_DIM]
const int *__restrict__ seq_lens,

@@ -522,28 +608,30 @@ __global__ void multi_query_append_attention_c8_warp1_4_kernel(
if (q_len <= 0) {
return;
}
T cache_k_scale_reg[num_frags_y * 4];
T cache_v_scale_reg[num_frags_y * 2];
if (is_scale_channel_wise) {
int scale_col_base = threadIdx.x % 4 * 2 + kv_head_idx * HEAD_DIM;
const T *cache_k_scale_cur_head = cache_k_scale + scale_col_base;
for (int i = 0; i < num_frags_y; ++i) {
const int scale_idx = i * 16;
cache_k_scale_reg[i * 4] = cache_k_scale_cur_head[scale_idx];
cache_k_scale_reg[i * 4 + 1] = cache_k_scale_cur_head[scale_idx + 1];
cache_k_scale_reg[i * 4 + 2] = cache_k_scale_cur_head[scale_idx + 8];
cache_k_scale_reg[i * 4 + 3] = cache_k_scale_cur_head[scale_idx + 9];
T cache_k_scale_reg[IsDynamicC8 ? num_frags_z * 2 : num_frags_y * 4];
T cache_v_scale_reg[IsDynamicC8 ? num_frags_z * 4 : num_frags_y * 2];
if constexpr (!IsDynamicC8) {
if constexpr (is_scale_channel_wise) {
int scale_col_base = threadIdx.x % 4 * 2 + kv_head_idx * HEAD_DIM;
const T *cache_k_scale_cur_head = cache_k_scale + scale_col_base;
for (int i = 0; i < num_frags_y; ++i) {
const int scale_idx = i * 16;
cache_k_scale_reg[i * 4] = cache_k_scale_cur_head[scale_idx];
cache_k_scale_reg[i * 4 + 1] = cache_k_scale_cur_head[scale_idx + 1];
cache_k_scale_reg[i * 4 + 2] = cache_k_scale_cur_head[scale_idx + 8];
cache_k_scale_reg[i * 4 + 3] = cache_k_scale_cur_head[scale_idx + 9];
}
scale_col_base = threadIdx.x / 4 + kv_head_idx * HEAD_DIM;
const T *cache_v_scale_cur_head = cache_v_scale + scale_col_base;
for (int i = 0; i < num_frags_y; ++i) {
const int scale_idx = i * 16;
cache_v_scale_reg[i * 2] = cache_v_scale_cur_head[scale_idx];
cache_v_scale_reg[i * 2 + 1] = cache_v_scale_cur_head[scale_idx + 8];
}
} else {
cache_k_scale_reg[0] = cache_k_scale[kv_head_idx];
cache_v_scale_reg[0] = cache_v_scale[kv_head_idx];
}
scale_col_base = threadIdx.x / 4 + kv_head_idx * HEAD_DIM;
const T *cache_v_scale_cur_head = cache_v_scale + scale_col_base;
for (int i = 0; i < num_frags_y; ++i) {
const int scale_idx = i * 16;
cache_v_scale_reg[i * 2] = cache_v_scale_cur_head[scale_idx];
cache_v_scale_reg[i * 2 + 1] = cache_v_scale_cur_head[scale_idx + 8];
}
} else {
cache_k_scale_reg[0] = cache_k_scale[kv_head_idx];
cache_v_scale_reg[0] = cache_v_scale[kv_head_idx];
}
const uint32_t q_end =
min(q_len, div_up((tile_id + 1) * num_rows_per_block, GROUP_SIZE));

@@ -634,6 +722,17 @@ __global__ void multi_query_append_attention_c8_warp1_4_kernel(
smem_t k_smem(smem + num_frags_x * 16 * HEAD_DIM * sizeof(T)),
v_smem(smem + num_frags_x * 16 * HEAD_DIM * sizeof(T) +
NUM_WARP_KV * num_frags_z * 16 * HEAD_DIM * sizeof(CacheT));
T* k_smem_scale_ptr = nullptr;
T* v_smem_scale_ptr = nullptr;
smem_t k_scale_smem;
smem_t v_scale_smem;
if constexpr (IsDynamicC8) {
k_smem_scale_ptr = reinterpret_cast<T*>(smem + num_frags_x * 16 * HEAD_DIM * sizeof(T) +
NUM_WARP_KV * num_frags_z * 16 * HEAD_DIM * sizeof(CacheT) * 2);
v_smem_scale_ptr = k_smem_scale_ptr + NUM_WARP_KV * num_frags_z * 16;
k_scale_smem.base = reinterpret_cast<b128_t*>(k_smem_scale_ptr);
v_scale_smem.base = reinterpret_cast<b128_t*>(v_smem_scale_ptr);
}

const uint32_t num_iterations = div_up(
CAUSAL

@@ -696,6 +795,20 @@ __global__ void multi_query_append_attention_c8_warp1_4_kernel(
kv_idx_base,
chunk_end,
const_k_offset);
if constexpr (IsDynamicC8) {
produce_kv_dynamic_scale_gmem2smem_async<SharedMemFillMode::kFillZero,
BLOCK_SIZE,
num_frags_z,
NUM_WARP_Q>(
k_scale_smem,
block_table_now,
cache_k_scale,
kv_idx_base,
kv_num_heads,
kv_head_idx,
chunk_end
);
}
commit_group();
produce_v_blockwise_c8<SharedMemFillMode::kNoFill,
NUM_WARPS,

@@ -713,14 +826,34 @@ __global__ void multi_query_append_attention_c8_warp1_4_kernel(
kv_idx_base,
chunk_end,
const_v_offset);
if constexpr (IsDynamicC8) {
produce_kv_dynamic_scale_gmem2smem_async<SharedMemFillMode::kFillZero,
BLOCK_SIZE,
num_frags_z,
NUM_WARP_Q>(
v_scale_smem,
block_table_now,
cache_v_scale,
kv_idx_base,
kv_num_heads,
kv_head_idx,
chunk_end
);
}
commit_group();
#pragma unroll 1
for (uint32_t iter = 0; iter < num_iterations; ++iter) {
wait_group<1>();
__syncthreads();
if constexpr (IsDynamicC8) {
produce_k_dynamic_scale_smem2reg<BLOCK_SIZE, num_frags_z, NUM_WARP_Q, T>(
k_smem_scale_ptr,
cache_k_scale_reg
);
}

// s = qk
compute_qk_c8<num_frags_x, num_frags_y, num_frags_z, T, CacheT, is_scale_channel_wise, IsFP8>(
compute_qk_c8<num_frags_x, num_frags_y, num_frags_z, T, CacheT, is_scale_channel_wise, IsFP8, IsDynamicC8>(
&qo_smem,
&q_smem_offset_r,
&k_smem,

@@ -753,6 +886,7 @@ __global__ void multi_query_append_attention_c8_warp1_4_kernel(
s_frag, o_frag, m_frag, d_frag);
__syncthreads();

const uint32_t ori_kv_idx_base = kv_idx_base;
kv_idx_base += NUM_WARP_KV * num_frags_z * 16;
produce_k_blockwise_c8<SharedMemFillMode::kNoFill,
NUM_WARPS,

@@ -770,9 +904,29 @@ __global__ void multi_query_append_attention_c8_warp1_4_kernel(
kv_idx_base,
chunk_end,
const_k_offset);
if constexpr (IsDynamicC8) {
produce_kv_dynamic_scale_gmem2smem_async<SharedMemFillMode::kFillZero,
BLOCK_SIZE,
num_frags_z,
NUM_WARP_Q>(
k_scale_smem,
block_table_now,
cache_k_scale,
kv_idx_base,
kv_num_heads,
kv_head_idx,
chunk_end
);
}
commit_group();
wait_group<1>();
__syncthreads();
if constexpr (IsDynamicC8) {
produce_v_dynamic_scale_smem2reg<BLOCK_SIZE, num_frags_z, NUM_WARP_Q, T>(
v_smem_scale_ptr,
cache_v_scale_reg
);
}

// compute sfm * v
compute_sfm_v_c8_iter_sq_bvec<num_frags_x,

@@ -781,7 +935,9 @@ __global__ void multi_query_append_attention_c8_warp1_4_kernel(
BLOCK_SIZE,
T,
CacheT,
is_scale_channel_wise, IsFP8>(
is_scale_channel_wise,
IsFP8,
IsDynamicC8>(
&v_smem, &v_smem_offset_r, s_frag, o_frag, d_frag, cache_v_scale_reg);
__syncthreads();

@@ -801,6 +957,20 @@ __global__ void multi_query_append_attention_c8_warp1_4_kernel(
kv_idx_base,
chunk_end,
const_v_offset);
if constexpr (IsDynamicC8) {
produce_kv_dynamic_scale_gmem2smem_async<SharedMemFillMode::kFillZero,
BLOCK_SIZE,
num_frags_z,
NUM_WARP_Q>(
v_scale_smem,
block_table_now,
cache_v_scale,
kv_idx_base,
kv_num_heads,
kv_head_idx,
chunk_end
);
}
commit_group();
}
wait_group<0>();
@@ -895,7 +1065,8 @@ template <typename T,
uint32_t NUM_WARP_Q,
typename OutT = T,
bool ENABLE_PREFILL = true,
bool IsFP8=false>
bool IsFP8 = false,
bool IsDynamicC8 = false>
void MultiQueryAppendC8Attention(
const AppendAttnMetaData &meta_data,
const paddle::Tensor &qkv,

@@ -953,7 +1124,8 @@ void MultiQueryAppendC8Attention(
constexpr uint32_t num_frags_z = BLOCK_SIZE / 16;
constexpr uint32_t smem_size =
num_warps * num_frags_x * 16 * HEAD_DIM * sizeof(T) +
num_frags_z * 16 * HEAD_DIM * sizeof(uint8_t) * 2;
num_frags_z * 16 * HEAD_DIM * sizeof(uint8_t) * 2 +
num_frags_z * 16 * sizeof(T) * 2;
auto split_kv_kernel =
multi_query_append_attention_c8_kernel<NV_TYPE,
uint8_t,

@@ -970,7 +1142,9 @@ void MultiQueryAppendC8Attention(
num_frags_y,
OUT_NV_TYPE,
ENABLE_PREFILL,
false, IsFP8>;
false,
IsFP8,
IsDynamicC8>;
if (is_scale_channel_wise) {
split_kv_kernel =
multi_query_append_attention_c8_kernel<NV_TYPE,

@@ -988,7 +1162,9 @@ void MultiQueryAppendC8Attention(
num_frags_y,
OUT_NV_TYPE,
ENABLE_PREFILL,
true, IsFP8>;
true,
IsFP8,
IsDynamicC8>;
}
if (smem_size >= 48 * 1024) {
cudaFuncSetAttribute(split_kv_kernel,

@@ -1022,7 +1198,9 @@ void MultiQueryAppendC8Attention(
num_frags_y,
OUT_NV_TYPE,
ENABLE_PREFILL,
false, IsFP8>;
false,
IsFP8,
IsDynamicC8>;
if (is_scale_channel_wise) {
nosplit_kv_kernel =
multi_query_append_attention_c8_kernel<NV_TYPE,

@@ -1040,7 +1218,9 @@ void MultiQueryAppendC8Attention(
num_frags_y,
OUT_NV_TYPE,
ENABLE_PREFILL,
true, IsFP8>;
true,
IsFP8,
IsDynamicC8>;
}
if (smem_size >= 48 * 1024) {
cudaFuncSetAttribute(nosplit_kv_kernel,

@@ -1218,7 +1398,8 @@ void MultiQueryAppendC8Attention(
constexpr uint32_t num_frags_z = BLOCK_SIZE / 16 / NUM_WARP_KV * 2;
constexpr uint32_t smem_size =
num_frags_x * 16 * HEAD_DIM * sizeof(T) +
NUM_WARP_KV * num_frags_z * 16 * HEAD_DIM * sizeof(uint8_t) * 2;
NUM_WARP_KV * num_frags_z * 16 * HEAD_DIM * sizeof(uint8_t) * 2 +
NUM_WARP_KV * num_frags_z * 16 * sizeof(T) * 2;
auto split_kv_kernel =
multi_query_append_attention_c8_warp1_4_kernel<NV_TYPE,
uint8_t,

@@ -1235,7 +1416,9 @@ void MultiQueryAppendC8Attention(
num_frags_y,
OUT_NV_TYPE,
ENABLE_PREFILL,
false, IsFP8>;
false,
IsFP8,
IsDynamicC8>;
if (is_scale_channel_wise) {
split_kv_kernel =
multi_query_append_attention_c8_warp1_4_kernel<NV_TYPE,

@@ -1253,7 +1436,9 @@ void MultiQueryAppendC8Attention(
num_frags_y,
OUT_NV_TYPE,
ENABLE_PREFILL,
true, IsFP8>;
true,
IsFP8,
IsDynamicC8>;
}
if (smem_size >= 48 * 1024) {
cudaFuncSetAttribute(split_kv_kernel,

@@ -1295,7 +1480,9 @@ void MultiQueryAppendC8Attention(
num_frags_y,
OUT_NV_TYPE,
ENABLE_PREFILL,
false, IsFP8>;
false,
IsFP8,
IsDynamicC8>;
if (is_scale_channel_wise) {
nosplit_kv_kernel =
multi_query_append_attention_c8_warp1_4_kernel<NV_TYPE,

@@ -1313,7 +1500,9 @@ void MultiQueryAppendC8Attention(
num_frags_y,
OUT_NV_TYPE,
ENABLE_PREFILL,
true, IsFP8>;
true,
IsFP8,
IsDynamicC8>;
}
if (smem_size >= 48 * 1024) {
cudaFuncSetAttribute(nosplit_kv_kernel,
@@ -1546,6 +1735,7 @@ void CascadeAppendAttentionC8Kernel(
const bool causal,
const bool is_decoder,
const bool enable_prefill,
const std::string& cache_quant_type_str,
cudaStream_t& stream,
paddle::Tensor* out) {
const auto token_num = meta_data.token_nums;

@@ -1554,6 +1744,7 @@ void CascadeAppendAttentionC8Kernel(
const auto num_heads = meta_data.q_num_heads;
const auto group_size = meta_data.q_num_heads / meta_data.kv_num_heads;
const auto head_dim = meta_data.head_dims;
bool is_dynamic_cfp8 = cache_quant_type_str == "block_wise_fp8";

DISPATCH_CAUSAL(
causal,

@@ -1572,43 +1763,46 @@ void CascadeAppendAttentionC8Kernel(
BLOCK_SIZE,
{DISPATCH_BLOCKSHAPE_Q(
block_shape_q, BLOCK_SHAPE_Q, NUM_WARP_Q, {
MultiQueryAppendC8Attention<T,
GROUP_SIZE,
HEAD_DIM,
BLOCK_SIZE,
CAUSAL,
BLOCK_SHAPE_Q,
NUM_WARP_Q,
OutT,
ENABLE_PREFILL, IsFP8>(
meta_data,
qkv,
cache_k,
cache_v,
attn_mask,
cache_k_scale.get(),
cache_v_scale.get(),
shift_bias,
smooth_weight,
seq_lens_q,
seq_lens_kv,
seq_lens_encoder,
batch_id_per_token,
cu_seqlens_q,
block_table,
batch_ids,
tile_ids_per_batch,
num_blocks,
max_seq_len,
max_dec_len,
quant_max_bound,
quant_min_bound,
in_scale,
max_partition_size,
encoder_max_partition_size,
speculate_max_draft_token_num,
is_decoder,
stream,
out);
})})})})})})
DISPATCH_DyCfp8(is_dynamic_cfp8, IsDynamicC8, {
MultiQueryAppendC8Attention<T,
GROUP_SIZE,
HEAD_DIM,
BLOCK_SIZE,
CAUSAL,
BLOCK_SHAPE_Q,
NUM_WARP_Q,
OutT,
ENABLE_PREFILL,
IsFP8,
IsDynamicC8>(
meta_data,
qkv,
cache_k,
cache_v,
attn_mask,
cache_k_scale.get(),
cache_v_scale.get(),
shift_bias,
smooth_weight,
seq_lens_q,
seq_lens_kv,
seq_lens_encoder,
batch_id_per_token,
cu_seqlens_q,
block_table,
batch_ids,
tile_ids_per_batch,
num_blocks,
max_seq_len,
max_dec_len,
quant_max_bound,
quant_min_bound,
in_scale,
max_partition_size,
encoder_max_partition_size,
speculate_max_draft_token_num,
is_decoder,
stream,
out);
})})})})})})})
}
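The new DISPATCH_DyCfp8 wrapper lifts the runtime flag is_dynamic_cfp8 into the compile-time IsDynamicC8 template argument consumed by MultiQueryAppendC8Attention. Its definition is not part of this diff; a minimal sketch of the usual boolean-dispatch pattern, assuming a macro of that shape, would look like this (hypothetical, not the repository's definition):

// Hypothetical sketch only — the real DISPATCH_DyCfp8 is not shown in this diff.
// It maps a runtime bool onto a constexpr bool so the enclosed block can
// instantiate the kernel with IsDynamicC8 as a template parameter.
#define DISPATCH_DyCfp8(is_dynamic_cfp8, IsDynamicC8, ...) \
  if (is_dynamic_cfp8) {                                   \
    constexpr bool IsDynamicC8 = true;                     \
    __VA_ARGS__                                            \
  } else {                                                 \
    constexpr bool IsDynamicC8 = false;                    \
    __VA_ARGS__                                            \
  }

This matches how the other DISPATCH_* macros in the call chain are used above: the body is passed as a braced block and instantiated once per value of the flag.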
@@ -384,6 +384,105 @@ __device__ __forceinline__ void produce_v_blockwise_c8(
}
}

template<SharedMemFillMode fill_mode,
uint32_t block_size,
uint32_t num_frags_z,
uint32_t NUM_WARP_Q,
typename T>
__device__ __forceinline__ void produce_kv_dynamic_scale_gmem2smem_async(
smem_t kv_scale_smem,
const int* block_table_now,
const T* cache_kv_scale,
const uint32_t kv_idx,
const uint32_t kv_num_heads,
const uint32_t kv_head_idx,
const uint32_t chunk_end
) {
const uint32_t tx = threadIdx.x, ty = threadIdx.y;
const uint32_t tid = ty * 32 + tx;
if constexpr (NUM_WARP_Q == 4) {
// 4 warps shared block_size
int block_id = __ldg(&block_table_now[kv_idx / block_size]);
if (block_id < 0) block_id = 0;
if (tid < block_size / 8) {
const T* cache_k_scale_now = cache_kv_scale + block_id * kv_num_heads * block_size + kv_head_idx * block_size + tid * 8;
const int kv_idx_this_thread = kv_idx + tid * 8;
kv_scale_smem.load_128b_async<fill_mode>(tid, cache_k_scale_now, kv_idx_this_thread < chunk_end);
}
} else {
// 1 warp 32 tokens
if (tid < block_size / 8 * 2) {
const uint32_t kv_idx_now = kv_idx + block_size * tid / 8;
int block_id = __ldg(&block_table_now[kv_idx_now / block_size]);
if (block_id < 0) block_id = 0;
const int kv_idx_this_thread = kv_idx + tid * 8;
const T* cache_k_scale_now = cache_kv_scale + block_id * kv_num_heads * block_size + kv_head_idx * block_size + tid % 8 * 8;
kv_scale_smem.load_128b_async<fill_mode>(tid, cache_k_scale_now, kv_idx_this_thread < chunk_end);
}
}
}

template<uint32_t block_size,
uint32_t num_frags_z,
uint32_t NUM_WARP_Q,
typename T>
__device__ __forceinline__ void produce_k_dynamic_scale_smem2reg(
T* k_smem_scale,
T* cache_k_reg
) {
const uint32_t tx = threadIdx.x, ty = threadIdx.y;
if constexpr (NUM_WARP_Q == 4) {
// 4 warps shared block_size
const uint32_t row_id = tx / 4;
for (uint32_t fz = 0; fz < num_frags_z; fz++) {
const uint32_t scale_idx = fz * 16 + row_id;
cache_k_reg[fz * 2] = k_smem_scale[scale_idx];
cache_k_reg[fz * 2 + 1] = k_smem_scale[scale_idx + 8];
}
} else {
// 1 warp 32 tokens
const uint32_t row_id = tx / 4;
for (uint32_t fz = 0; fz < num_frags_z; fz++) {
const uint32_t scale_idx = ty * 32 + fz * 16 + row_id;
cache_k_reg[fz * 2] = k_smem_scale[scale_idx];
cache_k_reg[fz * 2 + 1] = k_smem_scale[scale_idx + 8];
}
}
}

template<uint32_t block_size,
uint32_t num_frags_z,
uint32_t NUM_WARP_Q,
typename T>
__device__ __forceinline__ void produce_v_dynamic_scale_smem2reg(
T* v_smem_scale,
T* cache_v_reg
) {
const uint32_t tx = threadIdx.x, ty = threadIdx.y;

if constexpr (NUM_WARP_Q == 4) {
// 4 warps shared block_size
const uint32_t row_id = tx % 4 * 2;
for (uint32_t fz = 0; fz < num_frags_z; fz++) {
const uint32_t scale_idx = fz * 16 + row_id;
cache_v_reg[fz * 4] = v_smem_scale[scale_idx];
cache_v_reg[fz * 4 + 1] = v_smem_scale[scale_idx + 1];
cache_v_reg[fz * 4 + 2] = v_smem_scale[scale_idx + 8];
cache_v_reg[fz * 4 + 3] = v_smem_scale[scale_idx + 9];
}
} else {
// 1 warp 32 tokens
const uint32_t row_id = tx % 4 * 2;
for (uint32_t fz = 0; fz < num_frags_z; fz++) {
const uint32_t scale_idx = ty * 32 + fz * 16 + row_id;
cache_v_reg[fz * 4] = v_smem_scale[scale_idx];
cache_v_reg[fz * 4 + 1] = v_smem_scale[scale_idx + 1];
cache_v_reg[fz * 4 + 2] = v_smem_scale[scale_idx + 8];
cache_v_reg[fz * 4 + 3] = v_smem_scale[scale_idx + 9];
}
}
}

template <SharedMemFillMode fill_mode,
uint32_t num_warps,
uint32_t block_size,
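The loaders added in the hunk above assume the dynamic scale tensors are laid out as [max_block_num, kv_num_heads, block_size], one scale per cached token. As an illustrative restatement (not repository code), the global-memory index that produce_kv_dynamic_scale_gmem2smem_async computes for a single token reduces to the helper below; the function name and host-callable form are assumptions for illustration.

// Illustrative helper only: index of the dynamic per-token scale, following the
// arithmetic in produce_kv_dynamic_scale_gmem2smem_async above.
// Assumed layout: cache_kv_scale[max_block_num][kv_num_heads][block_size].
inline size_t kv_scale_index(int block_id, int kv_head_idx, int token_in_block,
                             int kv_num_heads, int block_size) {
  return static_cast<size_t>(block_id) * kv_num_heads * block_size +
         static_cast<size_t>(kv_head_idx) * block_size +
         static_cast<size_t>(token_in_block);
}

Each thread then pulls eight consecutive scales (a 128-bit chunk) with load_128b_async, predicated on kv_idx_this_thread < chunk_end so out-of-range rows are zero-filled.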
@@ -816,7 +915,8 @@ template <uint32_t num_frags_x,
typename T,
typename CacheT,
bool is_scale_channel_wise = false,
bool IsFP8=false>
bool IsFP8 = false,
bool IsDynamicC8 = false>
__device__ __forceinline__ void compute_qk_c8(smem_t* q_smem,
uint32_t* q_smem_offset_r,
smem_t* k_smem,

@@ -860,20 +960,27 @@ __device__ __forceinline__ void compute_qk_c8(smem_t* q_smem,
convert_c8<T,IsFP8>(b_frag_dq_T, b_frag[fy * 2]);
convert_c8<T,IsFP8>(b_frag_dq_T + 4, b_frag[fy * 2 + 1]);
// scale zp
if constexpr (is_scale_channel_wise) {
const int scale_col = (ky * 2 + fy) * 4;
b_frag_dq_T[0] *= cache_k_scale[scale_col];
b_frag_dq_T[1] *= cache_k_scale[scale_col + 1];
b_frag_dq_T[2] *= cache_k_scale[scale_col + 2];
b_frag_dq_T[3] *= cache_k_scale[scale_col + 3];
b_frag_dq_T[4] *= cache_k_scale[scale_col];
b_frag_dq_T[5] *= cache_k_scale[scale_col + 1];
b_frag_dq_T[6] *= cache_k_scale[scale_col + 2];
b_frag_dq_T[7] *= cache_k_scale[scale_col + 3];
if constexpr (!IsDynamicC8) {
if constexpr (is_scale_channel_wise) {
const int scale_col = (ky * 2 + fy) * 4;
b_frag_dq_T[0] *= cache_k_scale[scale_col];
b_frag_dq_T[1] *= cache_k_scale[scale_col + 1];
b_frag_dq_T[2] *= cache_k_scale[scale_col + 2];
b_frag_dq_T[3] *= cache_k_scale[scale_col + 3];
b_frag_dq_T[4] *= cache_k_scale[scale_col];
b_frag_dq_T[5] *= cache_k_scale[scale_col + 1];
b_frag_dq_T[6] *= cache_k_scale[scale_col + 2];
b_frag_dq_T[7] *= cache_k_scale[scale_col + 3];
} else {
#pragma unroll
for (uint32_t b_i = 0; b_i < 8; ++b_i) {
b_frag_dq_T[b_i] *= cache_k_scale[0];
}
}
} else {
#pragma unroll
for (uint32_t b_i = 0; b_i < 8; ++b_i) {
b_frag_dq_T[b_i] *= cache_k_scale[0];
b_frag_dq_T[b_i] *= cache_k_scale[fz * 2 + b_i / 4];
}
}
#pragma unroll

@@ -1093,7 +1200,9 @@ template <uint32_t num_frags_x,
uint32_t block_size,
typename T,
typename CacheT,
bool is_scale_channel_wise = false, bool IsFP8=false>
bool is_scale_channel_wise = false,
bool IsFP8 = false,
bool IsDynamicC8 = false>
__device__ __forceinline__ void compute_sfm_v_c8(
smem_t* v_smem,
uint32_t* v_smem_offset_r,

@@ -1135,16 +1244,28 @@ __device__ __forceinline__ void compute_sfm_v_c8(
convert_c8<T,IsFP8>(b_frag_dq_T, b_frag[fz * 2]);
convert_c8<T,IsFP8>(b_frag_dq_T + 4, b_frag[fz * 2 + 1]);
// scale zp
if constexpr (is_scale_channel_wise) {
if constexpr (!IsDynamicC8) {
if constexpr (is_scale_channel_wise) {
#pragma unroll
for (uint32_t b_i = 0; b_i < 8; ++b_i) {
b_frag_dq_T[b_i] *= cache_v_scale[b_i / 4 + fy * 2];
for (uint32_t b_i = 0; b_i < 8; ++b_i) {
b_frag_dq_T[b_i] *= cache_v_scale[b_i / 4 + fy * 2];
}
} else {
#pragma unroll
for (uint32_t b_i = 0; b_i < 8; ++b_i) {
b_frag_dq_T[b_i] *= cache_v_scale[0];
}
}
} else {
#pragma unroll
for (uint32_t b_i = 0; b_i < 8; ++b_i) {
b_frag_dq_T[b_i] *= cache_v_scale[0];
}
const int scale_col = (kz * 2 + fz) * 4;
b_frag_dq_T[0] *= cache_v_scale[scale_col];
b_frag_dq_T[1] *= cache_v_scale[scale_col + 1];
b_frag_dq_T[2] *= cache_v_scale[scale_col + 2];
b_frag_dq_T[3] *= cache_v_scale[scale_col + 3];
b_frag_dq_T[4] *= cache_v_scale[scale_col];
b_frag_dq_T[5] *= cache_v_scale[scale_col + 1];
b_frag_dq_T[6] *= cache_v_scale[scale_col + 2];
b_frag_dq_T[7] *= cache_v_scale[scale_col + 3];
}
#pragma unroll
for (uint32_t fx = 0; fx < num_frags_x; ++fx) { // m: num_frags_x * 16

@@ -1171,7 +1292,9 @@ template <uint32_t num_frags_x,
uint32_t block_size,
typename T,
typename CacheT,
bool is_scale_channel_wise = false, bool IsFP8=false>
bool is_scale_channel_wise = false,
bool IsFP8 = false,
bool IsDynamicC8 = false>
__device__ __forceinline__ void compute_sfm_v_c8_iter_sq_bvec(
smem_t* v_smem,
uint32_t* v_smem_offset_r,

@@ -1215,16 +1338,28 @@ __device__ __forceinline__ void compute_sfm_v_c8_iter_sq_bvec(
convert_c8<T,IsFP8>(b_frag_dq_T, b_frag[fz * 2]);
convert_c8<T,IsFP8>(b_frag_dq_T + 4, b_frag[fz * 2 + 1]);
// scale zp
if constexpr (is_scale_channel_wise) {
if constexpr (!IsDynamicC8) {
if constexpr (is_scale_channel_wise) {
#pragma unroll
for (uint32_t b_i = 0; b_i < 8; ++b_i) {
b_frag_dq_T[b_i] *= cache_v_scale[b_i / 4 + fy * 2];
for (uint32_t b_i = 0; b_i < 8; ++b_i) {
b_frag_dq_T[b_i] *= cache_v_scale[b_i / 4 + fy * 2];
}
} else {
#pragma unroll
for (uint32_t b_i = 0; b_i < 8; ++b_i) {
b_frag_dq_T[b_i] *= cache_v_scale[0];
}
}
} else {
#pragma unroll
for (uint32_t b_i = 0; b_i < 8; ++b_i) {
b_frag_dq_T[b_i] *= cache_v_scale[0];
}
const int scale_col = (kz * 2 + fz) * 4;
b_frag_dq_T[0] *= cache_v_scale[scale_col];
b_frag_dq_T[1] *= cache_v_scale[scale_col + 1];
b_frag_dq_T[2] *= cache_v_scale[scale_col + 2];
b_frag_dq_T[3] *= cache_v_scale[scale_col + 3];
b_frag_dq_T[4] *= cache_v_scale[scale_col];
b_frag_dq_T[5] *= cache_v_scale[scale_col + 1];
b_frag_dq_T[6] *= cache_v_scale[scale_col + 2];
b_frag_dq_T[7] *= cache_v_scale[scale_col + 3];
}
#pragma unroll
for (uint32_t fx = 0; fx < num_frags_x; ++fx) { // m: num_frags_x * 16
@@ -103,6 +103,7 @@ void CascadeAppendAttentionC8Kernel(
const bool causal,
const bool is_decoder,
const bool enable_prefill,
const std::string& cache_quant_type_str,
cudaStream_t& stream,
paddle::Tensor* out);

@@ -264,9 +265,10 @@ void CascadeAppendAttentionKernel(
causal,
is_decoder,
enable_prefill,
cache_quant_type_str,
stream,
out);
} else if (cache_quant_type_str == "cache_fp8") {
} else if (cache_quant_type_str == "cache_fp8" or cache_quant_type_str == "block_wise_fp8") {
CascadeAppendAttentionC8Kernel<T, OutT, true>(meta_data,
qkv,
cache_k,

@@ -299,6 +301,7 @@ void CascadeAppendAttentionKernel(
causal,
is_decoder,
enable_prefill,
cache_quant_type_str,
stream,
out);
} else if (cache_quant_type_str == "cache_int4_zp") {
@@ -120,7 +120,6 @@ __global__ void append_decode_cache_T_rope_qk_norm_kernel(
|
||||
float row_variance =
|
||||
max(warp_m2 / head_size, 0.0f);
|
||||
float row_inv_var = Rsqrt(row_variance + rms_norm_eps);
|
||||
|
||||
if (hi < num_heads) { // q
|
||||
Load<float, VecSize>(&q_norm_weight[threadIdx.x * VecSize], &q_norm_vec);
|
||||
#pragma unroll
|
||||
@@ -129,6 +128,7 @@ __global__ void append_decode_cache_T_rope_qk_norm_kernel(
|
||||
}
|
||||
} else { // k
|
||||
Load<float, VecSize>(&k_norm_weight[threadIdx.x * VecSize], &k_norm_vec);
|
||||
#pragma unroll
|
||||
for (int i = 0; i < VecSize; i++) {
|
||||
out_vec[i] = static_cast<T>(tmp_vec[i] * row_inv_var * k_norm_vec[i]);
|
||||
}
|
||||
@@ -629,6 +629,294 @@ __global__ void append_decode_cache_T_neox_rope_kernel(
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T, int VecSize = 4, int RoundType = 0, int HeadDim = 128, bool is_scale_channel_wise=false, bool IsFP8=true>
|
||||
__global__ void append_decode_cache_int8_rope_qk_norm_kernel(
|
||||
const T* __restrict__ quant_qkv, // [bsz, num_heads + 2 * kv_num_heads,
|
||||
// head_size]
|
||||
uint8_t* __restrict__ key_cache, // [num_blocks, kv_num_heads,
|
||||
// block_size, head_size // 2]
|
||||
uint8_t* __restrict__ value_cache, // [num_blocks, kv_num_heads,
|
||||
// block_size, head_size // 2]
|
||||
T* __restrict__ qkv_out,
|
||||
const int* __restrict__ block_tables, // [bsz, max_blocks_per_seq]
|
||||
const int* __restrict__ batch_id_per_token, // [num_tokens]
|
||||
const int* __restrict__ cu_seqlens_q,
|
||||
const int* __restrict__ seq_lens, // [bsz]
|
||||
const int* __restrict__ seq_lens_encoder, // [bsz]
|
||||
const float* __restrict__ cos_emb,
|
||||
const float* __restrict__ sin_emb,
|
||||
T* __restrict__ cache_k_scale,
|
||||
T* __restrict__ cache_v_scale,
|
||||
const float* q_norm_weight,
|
||||
const float* k_norm_weight,
|
||||
const int max_seq_len,
|
||||
const int max_blocks_per_seq,
|
||||
const int num_heads,
|
||||
const int block_size,
|
||||
const float max_bound,
|
||||
const float min_bound,
|
||||
const int kv_num_heads,
|
||||
const bool rope_3d,
|
||||
const float rms_norm_eps) {
|
||||
static_assert(HeadDim == 128, "just support HeadDim be 128 now!");
|
||||
static_assert(VecSize == 4, "just support VecSize be 4 now, 32 * 4!");
|
||||
constexpr int NUM_WARPS = 4;
|
||||
const int tid = threadIdx.x;
|
||||
const int wid = tid / 32;
|
||||
const int lane_id = tid % 32;
|
||||
const int bid = blockIdx.x, head_idx = blockIdx.y * NUM_WARPS + wid;
|
||||
int q_head_idx, k_head_idx, v_idx;
|
||||
const int64_t hidden_size = (num_heads + 2 * kv_num_heads) * HeadDim;
|
||||
constexpr int half_head_size = HeadDim / 2;
|
||||
const int start_token_idx = cu_seqlens_q[bid];
|
||||
if (seq_lens_encoder[bid] > 0) return;
|
||||
const int write_seq_id = seq_lens[bid];
|
||||
if (write_seq_id == 0) return;
|
||||
const int* block_table_now = nullptr;
|
||||
|
||||
block_table_now = block_tables + bid * max_blocks_per_seq;
|
||||
const int block_idx = __ldg(&block_table_now[write_seq_id / block_size]);
|
||||
const int block_offset = write_seq_id % block_size;
|
||||
|
||||
int cache_offset;
|
||||
if (head_idx < num_heads) {
|
||||
cache_offset = 0;
|
||||
} else if (head_idx < num_heads + 2 * kv_num_heads) {
|
||||
cache_offset = block_idx * kv_num_heads * block_size + (head_idx - num_heads) % kv_num_heads * block_size + block_offset;
|
||||
}
|
||||
T *cache_k_scale_now = cache_k_scale + cache_offset;
|
||||
T *cache_v_scale_now = cache_v_scale + cache_offset;
|
||||
|
||||
float thread_m2 = 0.0f;
|
||||
float warp_m2 = 0.0f;
|
||||
|
||||
if (head_idx < num_heads) {
|
||||
// q
|
||||
using LoadT = AlignedVector<T, VecSize>;
|
||||
using LoadBiasT = AlignedVector<T, VecSize>;
|
||||
using LoadOutScaleT = AlignedVector<float, VecSize>;
|
||||
constexpr int HalfVecSize = VecSize / 2;
|
||||
using LoadEmbT = AlignedVector<float, HalfVecSize>;
|
||||
|
||||
LoadT src_vec;
|
||||
LoadBiasT out_vec;
|
||||
LoadEmbT cos_emb_vec;
|
||||
LoadEmbT sin_emb_vec;
|
||||
const T* qkv_now = quant_qkv + start_token_idx * hidden_size;
|
||||
T* qkv_out_now = qkv_out + start_token_idx * hidden_size;
|
||||
#pragma unroll
|
||||
for (uint32_t head_bias = lane_id * VecSize; head_bias < HeadDim;
|
||||
head_bias += 32 * VecSize) {
|
||||
const int bias_idx = head_idx * HeadDim + head_bias;
|
||||
Load<T, VecSize>(&qkv_now[bias_idx], &src_vec);
|
||||
// q rope
|
||||
const uint32_t emb_idx = write_seq_id * half_head_size + head_bias / 2;
|
||||
const uint32_t new_emb_idx = rope_3d ? emb_idx + bid * max_seq_len * HeadDim : emb_idx;
|
||||
Load<float, HalfVecSize>(&cos_emb[new_emb_idx], &cos_emb_vec);
|
||||
Load<float, HalfVecSize>(&sin_emb[new_emb_idx], &sin_emb_vec);
|
||||
#pragma unroll
|
||||
for (int i = 0; i < HalfVecSize; i++) {
|
||||
// dequant + add_bias + rope
|
||||
float input_left = static_cast<float>(src_vec[2 * i]);
|
||||
float input_right = static_cast<float>(src_vec[2 * i + 1]);
const float cos_tmp = cos_emb_vec[i];
|
||||
const float sin_tmp = sin_emb_vec[i];
|
||||
float tmp1 = input_left * cos_tmp - input_right * sin_tmp;
|
||||
float tmp2 = input_right * cos_tmp + input_left * sin_tmp;
|
||||
thread_m2 += tmp1 * tmp1 + tmp2 * tmp2;
|
||||
out_vec[2 * i] =
|
||||
static_cast<T>(tmp1);
|
||||
out_vec[2 * i + 1] =
|
||||
static_cast<T>(tmp2);
|
||||
}
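      // Optional RMS norm on the rotated q values: thread_m2 holds this lane's partial
      // sum of squares, and the Welford warp reduction below combines it across the 32
      // lanes that cooperatively own this head before rescaling with q_norm_weight.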
|
||||
// qk norm
|
||||
if (q_norm_weight) {
|
||||
WelfordWarpAllReduce<float, 32>(thread_m2, &warp_m2);
|
||||
float row_variance =
|
||||
max(warp_m2 / HeadDim, 0.0f);
|
||||
float row_inv_var = Rsqrt(row_variance + rms_norm_eps);
|
||||
LoadOutScaleT q_norm_vec;
|
||||
Load<float, VecSize>(&q_norm_weight[lane_id * VecSize], &q_norm_vec);
|
||||
#pragma unroll
|
||||
for (int i = 0; i < VecSize; i++) {
|
||||
out_vec[i] = static_cast<T>(static_cast<float>(out_vec[i]) * row_inv_var * q_norm_vec[i]);
|
||||
}
|
||||
}
|
||||
Store<T, VecSize>(out_vec, &qkv_out_now[bias_idx]);
|
||||
}
|
||||
} else if (head_idx < num_heads + 2 * kv_num_heads) {
|
||||
// k
|
||||
constexpr int KV_VEC_SIZE = 16 / sizeof(uint8_t); // 16
|
||||
using LoadPadKVT = AlignedVector<uint8_t, KV_VEC_SIZE>;
|
||||
const uint32_t kv_head_idx = (head_idx - num_heads) % kv_num_heads;
|
||||
if (block_offset == 0) {
|
||||
// pad zero for this kv_head_idx for this block
|
||||
LoadPadKVT pad_cache_vec;
|
||||
*(reinterpret_cast<uint4*>(pad_cache_vec.val)) = make_uint4(0, 0, 0, 0);
|
||||
if (head_idx < num_heads + kv_num_heads) {
|
||||
constexpr int num_vecs_per_head_dim = HeadDim / KV_VEC_SIZE;
|
||||
constexpr int num_token_each_time = 32 / num_vecs_per_head_dim;
|
||||
const uint32_t tgt_idx =
|
||||
(block_idx * kv_num_heads + kv_head_idx) * block_size * HeadDim +
|
||||
lane_id % num_vecs_per_head_dim * KV_VEC_SIZE;
|
||||
for (int block_i = lane_id / num_vecs_per_head_dim;
|
||||
block_i < block_size;
|
||||
block_i += num_token_each_time) {
|
||||
Store<uint8_t, KV_VEC_SIZE>(pad_cache_vec,
|
||||
&key_cache[tgt_idx + block_i * HeadDim]);
|
||||
}
|
||||
} else {
|
||||
const int num_vecs_per_head_dim = block_size / KV_VEC_SIZE;
|
||||
const int num_token_each_time = 32 / num_vecs_per_head_dim;
|
||||
const uint32_t tgt_idx =
|
||||
(block_idx * kv_num_heads + kv_head_idx) * HeadDim * block_size +
|
||||
lane_id % num_vecs_per_head_dim * KV_VEC_SIZE;
|
||||
for (int block_i = lane_id / num_vecs_per_head_dim; block_i < HeadDim;
|
||||
block_i += num_token_each_time) {
|
||||
Store<uint8_t, KV_VEC_SIZE>(
|
||||
pad_cache_vec, &value_cache[tgt_idx + block_i * block_size]);
|
||||
}
|
||||
}
|
||||
__syncwarp();
|
||||
}
constexpr int K_VEC_SIZE = 4;
|
||||
constexpr int HALF_K_VEC_SIZE = 2;
|
||||
using LoadKVResT = AlignedVector<uint8_t, K_VEC_SIZE>;
|
||||
using LoadKVT = AlignedVector<uint8_t, HALF_K_VEC_SIZE>;
|
||||
using LoadT = AlignedVector<T, HALF_K_VEC_SIZE>;
|
||||
using LoadBiasT = AlignedVector<T, HALF_K_VEC_SIZE>;
|
||||
using LoadOutScaleT = AlignedVector<float, HALF_K_VEC_SIZE>;
|
||||
using LoadEmbT = AlignedVector<float, 1>;
|
||||
LoadKVResT cache_vec;
|
||||
LoadT src_vec1, src_vec2;
|
||||
LoadBiasT out_vec1, out_vec2;
|
||||
LoadEmbT cos_emb_vec1, cos_emb_vec2;
|
||||
LoadEmbT sin_emb_vec1, sin_emb_vec2;
const T* qkv_now = quant_qkv + start_token_idx * hidden_size;
|
||||
const int head_bias = lane_id / 4 * 16 + lane_id % 4 * 2;
|
||||
const int bias_idx = head_idx * HeadDim + head_bias;
|
||||
Load<T, HALF_K_VEC_SIZE>(&qkv_now[bias_idx], &src_vec1);
|
||||
Load<T, HALF_K_VEC_SIZE>(&qkv_now[bias_idx + 8], &src_vec2);
|
||||
T scale = T(1.0f);
|
||||
const int k_head_idx = head_idx - num_heads;
|
||||
const int v_head_idx = head_idx - num_heads - kv_num_heads;
|
||||
if (head_idx < num_heads + kv_num_heads) {
|
||||
const uint32_t emb_idx = write_seq_id * half_head_size + head_bias / 2;
|
||||
const uint32_t new_emb_idx = rope_3d ? emb_idx + bid * max_seq_len * HeadDim : emb_idx;
|
||||
Load<float, 1>(&cos_emb[new_emb_idx], &cos_emb_vec1);
|
||||
Load<float, 1>(&cos_emb[new_emb_idx + 4], &cos_emb_vec2);
|
||||
Load<float, 1>(&sin_emb[new_emb_idx], &sin_emb_vec1);
|
||||
Load<float, 1>(&sin_emb[new_emb_idx + 4], &sin_emb_vec2);
|
||||
}
float input_left = static_cast<float>(src_vec1[0]);
|
||||
float input_right = static_cast<float>(src_vec1[1]);
|
||||
if (head_idx < num_heads + kv_num_heads) {
|
||||
float cos_tmp = cos_emb_vec1[0];
|
||||
float sin_tmp = sin_emb_vec1[0];
|
||||
float tmp1 = input_left * cos_tmp - input_right * sin_tmp;
|
||||
float tmp2 = input_right * cos_tmp + input_left * sin_tmp;
|
||||
thread_m2 += tmp1 * tmp1 + tmp2 * tmp2;
|
||||
out_vec1[0] =
|
||||
static_cast<T>(tmp1);
|
||||
out_vec1[1] =
|
||||
static_cast<T>(tmp2);
|
||||
} else {
|
||||
out_vec1[0] = src_vec1[0];
|
||||
out_vec1[1] = src_vec1[1];
|
||||
}
// rope
|
||||
input_left = static_cast<float>(src_vec2[0]);
|
||||
input_right = static_cast<float>(src_vec2[1]);
|
||||
if (head_idx < num_heads + kv_num_heads) {
|
||||
float cos_tmp = cos_emb_vec2[0];
|
||||
float sin_tmp = sin_emb_vec2[0];
|
||||
float tmp1 = input_left * cos_tmp - input_right * sin_tmp;
|
||||
float tmp2 = input_right * cos_tmp + input_left * sin_tmp;
|
||||
thread_m2 += tmp1 * tmp1 + tmp2 * tmp2;
|
||||
out_vec2[0] = static_cast<T>(tmp1);
|
||||
out_vec2[1] = static_cast<T>(tmp2);
|
||||
} else {
|
||||
out_vec2[0] = src_vec2[0];
|
||||
out_vec2[1] = src_vec2[1];
|
||||
}
|
||||
if (k_norm_weight) {
|
||||
if (head_idx < num_heads + kv_num_heads) {
|
||||
LoadOutScaleT k_norm_vec1, k_norm_vec2;
|
||||
Load<float, HALF_K_VEC_SIZE>(&k_norm_weight[head_bias], &k_norm_vec1);
|
||||
Load<float, HALF_K_VEC_SIZE>(&k_norm_weight[head_bias + 8], &k_norm_vec2);
|
||||
// qk norm
|
||||
WelfordWarpAllReduce<float, 32>(thread_m2, &warp_m2);
|
||||
float row_variance =
|
||||
max(warp_m2 / HeadDim, 0.0f);
|
||||
float row_inv_var = Rsqrt(row_variance + rms_norm_eps);
for (int i = 0; i < HALF_K_VEC_SIZE; i++) {
|
||||
out_vec1[i] = static_cast<T>(static_cast<float>(out_vec1[i]) * row_inv_var * k_norm_vec1[i]);
|
||||
out_vec2[i] = static_cast<T>(static_cast<float>(out_vec2[i]) * row_inv_var * k_norm_vec2[i]);
|
||||
}
|
||||
}
|
||||
}
|
||||
// reduce max, 1 head per warp
|
||||
T local_max = -INFINITY;
|
||||
#pragma unroll
|
||||
for (int i = 0; i < HALF_K_VEC_SIZE; i++) {
|
||||
local_max = __hmax(local_max, __habs(out_vec1[i]));
|
||||
local_max = __hmax(local_max, __habs(out_vec2[i]));
|
||||
}
|
||||
#pragma unroll
|
||||
for (int m_offset = 16; m_offset > 0; m_offset /= 2) {
|
||||
local_max = __hmax(local_max, __shfl_xor_sync(0xffffffff, local_max, m_offset));
|
||||
}
scale = __hdiv(448, local_max);
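    // 448 is the largest finite magnitude of FP8 E4M3, so per token:
    //   scale         = 448 / absmax(k or v)   (applied before rounding to 8 bits)
    //   cache_*_scale = 1 / scale              (written by lane 0 below, used for dequant)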
if (lane_id == 0) {
|
||||
if (head_idx < num_heads + kv_num_heads) {
|
||||
cache_k_scale_now[0] = __hdiv(1, scale);
|
||||
} else {
|
||||
cache_v_scale_now[0] = __hdiv(1, scale);
|
||||
}
|
||||
}
#pragma unroll
|
||||
for (uint32_t i = 0; i < HALF_K_VEC_SIZE; i++) {
|
||||
cache_vec[i] = QuantToC8<T,true, IsFP8, RoundType>(scale, out_vec1[i], max_bound, min_bound);
|
||||
cache_vec[i + HALF_K_VEC_SIZE] = QuantToC8<T,true, IsFP8, RoundType>(scale, out_vec2[i], max_bound, min_bound);
|
||||
}
|
||||
if (head_idx < num_heads + kv_num_heads) {
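      // Key path: the swizzled offsets below scatter the 4 quantized bytes into the
      // interleaved per-block key layout that the C8 attention kernel's ldmatrix
      // fragments expect.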
|
||||
const int start_block_16 =
|
||||
block_offset / 16 * 16 + block_offset % 8 + lane_id / 4 % 2 * 8;
|
||||
const uint32_t tgt_cache_idx =
|
||||
block_idx * kv_num_heads * block_size * HeadDim +
|
||||
kv_head_idx * block_size * HeadDim + start_block_16 * HeadDim +
|
||||
lane_id / 4 / 2 * 32 + (block_offset % 16) / 8 * 16 + lane_id % 4 * 4;
|
||||
Store<uint8_t, K_VEC_SIZE>(cache_vec, &key_cache[tgt_cache_idx]);
|
||||
} else {
|
||||
const uint32_t base_tgt_cache_idx =
|
||||
block_idx * kv_num_heads * HeadDim * block_size +
|
||||
kv_head_idx * HeadDim * block_size +
|
||||
(lane_id / 4 * 16 + lane_id % 4 * 2) * block_size +
|
||||
block_offset / 16 % 2 * 8 * block_size + block_offset / 16 / 2 * 32;
|
||||
const uint32_t tgt_cache_idx1 = base_tgt_cache_idx +
|
||||
block_offset % 8 / 2 * 4 // per 4
|
||||
+ block_offset % 16 / 8 * 2 // per 2
|
||||
+ block_offset % 2; // per 1
|
||||
const uint32_t tgt_cache_idx2 = tgt_cache_idx1 + block_size;
|
||||
const uint32_t tgt_cache_idx3 = tgt_cache_idx1 + 16;
|
||||
const uint32_t tgt_cache_idx4 = tgt_cache_idx3 + block_size;
|
||||
value_cache[tgt_cache_idx1] = cache_vec[0];
|
||||
value_cache[tgt_cache_idx2] = cache_vec[1];
|
||||
value_cache[tgt_cache_idx3] = cache_vec[2];
|
||||
value_cache[tgt_cache_idx4] = cache_vec[3];
|
||||
}
|
||||
}
|
||||
}
template <typename T, int VecSize = 4, int RoundType = 0, int HeadDim = 128, bool is_scale_channel_wise=false, bool IsFP8=false>
|
||||
__global__ void append_decode_cache_int8_rope_kernel(
|
||||
const T* __restrict__ quant_qkv, // [bsz, num_heads + 2 * kv_num_heads,
@@ -572,9 +572,40 @@ void DecoderWriteCacheWithRoPEKernel(
|
||||
q_norm_weight ? q_norm_weight.get().data<float>() : nullptr,
|
||||
k_norm_weight ? k_norm_weight.get().data<float>() : nullptr,
|
||||
rms_norm_eps);
|
||||
} else if (cache_quant_type_str == "block_wise_fp8") {
|
||||
constexpr int num_warps = 4;
|
||||
const int all_warps =
|
||||
((num_heads + 2 * kv_num_heads) + num_warps - 1) / num_warps * num_warps;
|
||||
dim3 grids(bsz, all_warps / num_warps);
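    // One warp per q/k/v head: the head count is rounded up to a multiple of num_warps,
    // blockIdx.x walks the batch and blockIdx.y walks groups of num_warps heads.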
|
||||
append_decode_cache_int8_rope_qk_norm_kernel<DataType_, 4, 0, 128, false, true>
|
||||
<<<grids, num_warps * 32, 0, stream>>>(
|
||||
reinterpret_cast<const DataType_*>(qkv_ptr),
|
||||
key_cache_out->data<uint8_t>(),
|
||||
value_cache_out->data<uint8_t>(),
|
||||
reinterpret_cast<DataType_*>(qkv_out->data<T>()),
|
||||
block_tables.data<int>(),
|
||||
batch_id_per_token.data<int>(),
|
||||
cu_seqlens_q.data<int>(),
|
||||
seq_lens.data<int>(),
|
||||
seq_lens_encoder.data<int>(),
|
||||
cos_emb,
|
||||
sin_emb,
|
||||
const_cast<DataType_*>(reinterpret_cast<const DataType_*>(cache_k_scale.get().data<T>())),
|
||||
const_cast<DataType_*>(reinterpret_cast<const DataType_*>((cache_v_scale.get().data<T>()))),
|
||||
q_norm_weight.get().data<float>(),
|
||||
k_norm_weight.get().data<float>(),
|
||||
max_seq_len,
|
||||
max_blocks_per_seq,
|
||||
num_heads,
|
||||
block_size,
|
||||
127.0f,
|
||||
-127.0f,
|
||||
kv_num_heads,
|
||||
rope_3d,
|
||||
rms_norm_eps);
|
||||
  } else {
    PD_THROW(
        "append_decode_cache_rope_qk_norm just supports cache_quant_type none/block_wise_fp8");
  }
|
||||
} else {
|
||||
if (cache_quant_type_str == "none") {
|
||||
@@ -709,6 +740,37 @@ void DecoderWriteCacheWithRoPEKernel(
|
||||
stream,
|
||||
use_neox_rotary_style,
|
||||
rope_3d);
|
||||
} else if (cache_quant_type_str == "block_wise_fp8") {
|
||||
constexpr int num_warps = 4;
|
||||
const int all_warps =
|
||||
((num_heads + 2 * kv_num_heads) + num_warps - 1) / num_warps * num_warps;
|
||||
dim3 grids(bsz, all_warps / num_warps);
|
||||
append_decode_cache_int8_rope_qk_norm_kernel<DataType_, 4, 0, 128, false, true>
|
||||
<<<grids, num_warps * 32, 0, stream>>>(
|
||||
reinterpret_cast<const DataType_*>(qkv_ptr),
|
||||
key_cache_out->data<uint8_t>(),
|
||||
value_cache_out->data<uint8_t>(),
|
||||
reinterpret_cast<DataType_*>(qkv_out->data<T>()),
|
||||
block_tables.data<int>(),
|
||||
batch_id_per_token.data<int>(),
|
||||
cu_seqlens_q.data<int>(),
|
||||
seq_lens.data<int>(),
|
||||
seq_lens_encoder.data<int>(),
|
||||
cos_emb,
|
||||
sin_emb,
|
||||
const_cast<DataType_*>(reinterpret_cast<const DataType_*>(cache_k_scale.get().data<T>())),
|
||||
const_cast<DataType_*>(reinterpret_cast<const DataType_*>((cache_v_scale.get().data<T>()))),
|
||||
nullptr,
|
||||
nullptr,
|
||||
max_seq_len,
|
||||
max_blocks_per_seq,
|
||||
num_heads,
|
||||
block_size,
|
||||
127.0f,
|
||||
-127.0f,
|
||||
kv_num_heads,
|
||||
rope_3d,
|
||||
rms_norm_eps);
|
||||
} else if (cache_quant_type_str == "cache_int4_zp") {
|
||||
append_decode_cache_int4_rope(
|
||||
reinterpret_cast<const QKV_TYPE*>(qkv_ptr),
|
||||
|
@@ -1232,6 +1232,411 @@ __global__ void append_write_cache_kv_c8_qkv(
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T,
          uint32_t num_frags_y,
          uint32_t num_frags_z,
          uint32_t HEAD_DIM,
          uint32_t BLOCK_SIZE,
          uint32_t NUM_WARPS,
          bool is_need_kv_quant,
          bool IsFP8 = true>
__global__ void append_write_cache_kv_c8_qkv_dynamic(
    uint8_t *__restrict__ cache_k,
    uint8_t *__restrict__ cache_v,
    const T *__restrict__ qkv_input,
    T *__restrict__ cache_k_scales,  // [block_num, num_heads, block_size]
    T *__restrict__ cache_v_scales,  // [block_num, num_heads, block_size]
    const int *__restrict__ batch_ids,
    const int *__restrict__ tile_ids,
    const int *__restrict__ seq_lens_this_time,
    const int *__restrict__ seq_lens_decoder,
    const int *__restrict__ batch_id_per_token,
    const int *__restrict__ cu_seqlens_q,
    const int *__restrict__ block_tables,
    const int max_seq_len,
    const int max_blocks_per_seq,
    const int num_heads,
    const int kv_num_heads) {
|
||||
constexpr uint32_t num_vecs_per_head = HEAD_DIM / num_elems_per_128b<T>();
|
||||
constexpr uint32_t pad_len = BLOCK_SIZE;
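  // block_wise_fp8 append path: one BLOCK_SIZE tile of K/V is staged in shared memory,
  // then, per kv head and per token row, roughly (pseudocode sketch):
  //   scale                              = 448.f / absmax(row)   // 448 = FP8 E4M3 max
  //   cache_k_scales[block, head, token] = 1.f / scale           // stored for dequant
  //   cache_k[...]                       = fp8_round(row * scale)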
|
||||
const uint32_t btid = blockIdx.x, kv_head_idx = blockIdx.z;
|
||||
const T cache_k_scale = cache_k_scales[kv_head_idx];
|
||||
const T cache_v_scale = cache_v_scales[kv_head_idx];
|
||||
const uint32_t tid = threadIdx.x, wid = threadIdx.y;
|
||||
const uint32_t batch_id = batch_ids[btid];
|
||||
const uint32_t tile_id = tile_ids[btid];
|
||||
const uint32_t seq_len_this_time = seq_lens_this_time[batch_id];
|
||||
if (seq_len_this_time <= 0) {
|
||||
return;
|
||||
}
|
||||
const int *block_table_now = nullptr;
block_table_now = block_tables + batch_id * max_blocks_per_seq;
|
||||
|
||||
const uint32_t num_rows_per_block =
|
||||
NUM_WARPS * num_frags_z * 16; // BLOCK_SIZE
|
||||
const uint32_t start_len = seq_lens_decoder[batch_id];
|
||||
const uint32_t bf_pad_len = start_len % pad_len;
|
||||
const uint32_t start_len_pad = start_len - bf_pad_len;
|
||||
const uint32_t end_len = start_len + seq_len_this_time;
const uint32_t tile_start = start_len_pad + tile_id * num_rows_per_block;
|
||||
int block_id = __ldg(&block_table_now[tile_start / BLOCK_SIZE]);
|
||||
uint32_t chunk_start = tile_start + wid * num_frags_z * 16 + tid / 8;
|
||||
|
||||
const uint32_t start_token_idx = cu_seqlens_q[batch_id];
|
||||
const uint32_t kv_batch_stride = (num_heads + 2 * kv_num_heads) * HEAD_DIM;
|
||||
const uint32_t kv_h_stride = HEAD_DIM;
|
||||
__shared__ T k_smem_ori[num_rows_per_block * HEAD_DIM];
|
||||
__shared__ T v_smem_ori[num_rows_per_block * HEAD_DIM];
|
||||
__shared__ T v_scale_smem[BLOCK_SIZE];
|
||||
if (tile_start >= start_len) {
|
||||
constexpr int KV_VEC_SIZE = 16 / sizeof(uint8_t); // 16
|
||||
using LoadPadKVT = AlignedVector<uint8_t, KV_VEC_SIZE>;
|
||||
// pad zero for this kv_head_idx for this block
|
||||
LoadPadKVT pad_cache_vec;
|
||||
*(reinterpret_cast<uint4*>(pad_cache_vec.val)) = make_uint4(0, 0, 0, 0);
|
||||
// reset k
|
||||
constexpr int num_vecs_per_head_k = HEAD_DIM / KV_VEC_SIZE;
|
||||
constexpr int num_token_each_time_k = 32 / num_vecs_per_head_k;
|
||||
uint32_t tgt_idx =
|
||||
(block_id * kv_num_heads + kv_head_idx) * BLOCK_SIZE * HEAD_DIM +
|
||||
tid % num_vecs_per_head_k * KV_VEC_SIZE;
|
||||
for (int block_i = tid / num_vecs_per_head_k;
|
||||
block_i < BLOCK_SIZE;
|
||||
block_i += num_token_each_time_k) {
|
||||
Store<uint8_t, KV_VEC_SIZE>(pad_cache_vec,
|
||||
&cache_k[tgt_idx + block_i * HEAD_DIM]);
|
||||
}
|
||||
|
||||
// reset v
|
||||
const int num_vecs_per_head_v = BLOCK_SIZE / KV_VEC_SIZE;
|
||||
const int num_token_each_time_v = 32 / num_vecs_per_head_v;
|
||||
tgt_idx =
|
||||
(block_id * kv_num_heads + kv_head_idx) * HEAD_DIM * BLOCK_SIZE +
|
||||
tid % num_vecs_per_head_v * KV_VEC_SIZE;
|
||||
for (int block_i = tid / num_vecs_per_head_v; block_i < HEAD_DIM;
|
||||
block_i += num_token_each_time_v) {
|
||||
Store<uint8_t, KV_VEC_SIZE>(
|
||||
pad_cache_vec, &cache_v[tgt_idx + block_i * BLOCK_SIZE]);
|
||||
}
|
||||
}
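  // Tiles that do not continue a partially filled cache block zero-initialize the whole
  // block above, so slots that never receive a token dequantize to zero.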
|
||||
smem_t k_smem(k_smem_ori);
|
||||
smem_t v_smem(v_smem_ori);
|
||||
|
||||
uint32_t kv_smem_offset_w = smem_t::get_permuted_offset<num_vecs_per_head>(
|
||||
wid * num_frags_z * 16 + tid / 8, tid % 8); // 4 * 8 per warp
|
||||
|
||||
/*
|
||||
0 | 1
|
||||
2 | 3
|
||||
*/
|
||||
uint32_t k_smem_offset_r = smem_t::get_permuted_offset<num_vecs_per_head>(
|
||||
wid * num_frags_z * 16 + 8 * (tid / 16) + tid % 8, (tid % 16) / 8);
|
||||
|
||||
constexpr uint32_t num_frags_v = num_frags_y / NUM_WARPS;
|
||||
/*
|
||||
0 | 2
|
||||
1 | 3
|
||||
*/
|
||||
uint32_t v_smem_offset_r = smem_t::get_permuted_offset<num_vecs_per_head>(
|
||||
tid % 16, wid * num_frags_v * 2 + tid / 16);
|
||||
|
||||
// load kv gmem to smem
|
||||
const uint32_t real_start_token_idx = start_token_idx - bf_pad_len +
|
||||
tile_id * num_rows_per_block +
|
||||
wid * num_frags_z * 16 + tid / 8;
|
||||
uint32_t k_read_idx = real_start_token_idx * kv_batch_stride +
|
||||
(num_heads + kv_head_idx) * kv_h_stride +
|
||||
tid % 8 * num_elems_per_128b<T>();
|
||||
uint32_t v_read_idx = real_start_token_idx * kv_batch_stride +
|
||||
(num_heads + kv_num_heads + kv_head_idx) * kv_h_stride +
|
||||
tid % 8 * num_elems_per_128b<T>();
|
||||
#pragma unroll
|
||||
for (uint32_t fz = 0; fz < num_frags_z; ++fz) {
|
||||
#pragma unroll
|
||||
for (uint32_t j = 0; j < 4; ++j) {
|
||||
#pragma unroll
|
||||
for (uint32_t fy = 0; fy < num_frags_y / 4;
|
||||
++fy) { // (num_frags_y * 16) / (8 * num_elems_per_128b<T>())
|
||||
if (chunk_start >= start_len && chunk_start < end_len) {
|
||||
k_smem.load_128b_async<SharedMemFillMode::kNoFill>(
|
||||
kv_smem_offset_w, qkv_input + k_read_idx, chunk_start < end_len);
|
||||
v_smem.load_128b_async<SharedMemFillMode::kNoFill>(
|
||||
kv_smem_offset_w, qkv_input + v_read_idx, chunk_start < end_len);
|
||||
}
|
||||
kv_smem_offset_w =
|
||||
k_smem.advance_offset_by_column<8>(kv_smem_offset_w, fy);
|
||||
k_read_idx += 8 * num_elems_per_128b<T>();
|
||||
v_read_idx += 8 * num_elems_per_128b<T>();
|
||||
}
|
||||
kv_smem_offset_w =
|
||||
k_smem.advance_offset_by_row<4, num_vecs_per_head>(kv_smem_offset_w) -
|
||||
2 * num_frags_y;
|
||||
chunk_start += 4;
|
||||
k_read_idx +=
|
||||
4 * kv_batch_stride - 2 * num_frags_y * num_elems_per_128b<T>();
|
||||
v_read_idx +=
|
||||
4 * kv_batch_stride - 2 * num_frags_y * num_elems_per_128b<T>();
|
||||
}
|
||||
}
|
||||
commit_group();
|
||||
wait_group<0>();
|
||||
__syncthreads();
|
||||
|
||||
// reduce scale
|
||||
// 16 rows per warp
|
||||
uint32_t kv_reduce_frag[4];
|
||||
T *kv_reduce_frag_T = reinterpret_cast<T*>(kv_reduce_frag);
|
||||
|
||||
T k_local_max_value[num_frags_z * 2];
|
||||
T v_local_max_value[num_frags_z * 2];
|
||||
#pragma unroll
|
||||
for (int i = 0; i < num_frags_z * 2; i++) {
|
||||
k_local_max_value[i] = -INFINITY;
|
||||
}
|
||||
#pragma unroll
|
||||
for (int i = 0; i < num_frags_z * 2; i++) {
|
||||
v_local_max_value[i] = -INFINITY;
|
||||
}
|
||||
const int num_kv_heads = gridDim.z;
|
||||
const int scale_offset = block_id * num_kv_heads * BLOCK_SIZE + kv_head_idx * BLOCK_SIZE;
|
||||
T *cache_k_scale_now = cache_k_scales + scale_offset;
|
||||
T *cache_v_scale_now = cache_v_scales + scale_offset;
|
||||
// k scale
|
||||
#pragma unroll
|
||||
for (uint32_t fz = 0; fz < num_frags_z; ++fz) {
|
||||
#pragma unroll
|
||||
for (uint32_t fy = 0; fy < num_frags_y; ++fy) {
|
||||
// reduce per thread, 4 threads each row
|
||||
k_smem.ldmatrix_m8n8x4(k_smem_offset_r, kv_reduce_frag);
|
||||
#pragma unroll
|
||||
for (int i = 0; i < 4; i++) {
|
||||
k_local_max_value[fz * 2] = __hmax(__habs(kv_reduce_frag_T[i]), k_local_max_value[fz * 2]);
|
||||
}
|
||||
#pragma unroll
|
||||
for (int i = 0; i < 4; i++) {
|
||||
k_local_max_value[fz * 2 + 1] = __hmax(__habs(kv_reduce_frag_T[i + 4]), k_local_max_value[fz * 2 + 1]);
|
||||
}
|
||||
k_smem_offset_r = k_smem.advance_offset_by_column<2>(k_smem_offset_r, fy);
|
||||
}
|
||||
// reduce per row
|
||||
for (int i = 0; i < 2; i++) {
|
||||
T local_max_value = __habs(k_local_max_value[fz * 2 + i]);
|
||||
local_max_value = __hmax(local_max_value, __shfl_xor_sync(0xffffffff, local_max_value, 2));
|
||||
local_max_value = __hmax(local_max_value, __shfl_xor_sync(0xffffffff, local_max_value, 1));
|
||||
// used for quant
|
||||
k_local_max_value[fz * 2 + i] = __hdiv(448, local_max_value);
|
||||
}
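    // Each fragment row is shared by 4 lanes, so the two XOR shuffles (lane masks 2 and 1)
    // complete the per-row absmax; k_local_max_value now holds the quant scale 448 / absmax.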
|
||||
// store
|
||||
if (tid % 4 == 0) {
|
||||
const int offset_now = wid * num_frags_z * 16 + tid / 4;
|
||||
// used for dequant
|
||||
if (tile_start + offset_now >= start_len) {
|
||||
if (tile_start + offset_now < end_len) {
|
||||
cache_k_scale_now[offset_now] = __hdiv(1, k_local_max_value[fz * 2]);
|
||||
} else {
|
||||
cache_k_scale_now[offset_now] = 0;
|
||||
}
|
||||
}
|
||||
if (tile_start + offset_now + 8 >= start_len) {
|
||||
if (tile_start + offset_now + 8 < end_len) {
|
||||
cache_k_scale_now[offset_now + 8] = __hdiv(1, k_local_max_value[fz * 2 + 1]);
|
||||
} else {
|
||||
cache_k_scale_now[offset_now + 8] = 0;
|
||||
}
|
||||
}
|
||||
}
|
||||
__syncthreads();
|
||||
k_smem_offset_r -= 2 * num_frags_y; // num_frags_z = 1
|
||||
}
|
||||
// v scale
|
||||
#pragma unroll
|
||||
for (uint32_t fz = 0; fz < num_frags_z; ++fz) {
|
||||
#pragma unroll
|
||||
for (uint32_t fy = 0; fy < num_frags_y; ++fy) {
|
||||
// reduce per thread, 4 threads each row
|
||||
v_smem.ldmatrix_m8n8x4(k_smem_offset_r, kv_reduce_frag);
|
||||
#pragma unroll
|
||||
for (int i = 0; i < 4; i++) {
|
||||
v_local_max_value[fz * 2] = __hmax(__habs(kv_reduce_frag_T[i]), v_local_max_value[fz * 2]);
|
||||
}
|
||||
#pragma unroll
|
||||
for (int i = 0; i < 4; i++) {
|
||||
v_local_max_value[fz * 2 + 1] = __hmax(__habs(kv_reduce_frag_T[i + 4]), v_local_max_value[fz * 2 + 1]);
|
||||
}
|
||||
k_smem_offset_r = v_smem.advance_offset_by_column<2>(k_smem_offset_r, fy);
|
||||
}
|
||||
// reduce per row
|
||||
for (int i = 0; i < 2; i++) {
|
||||
T local_max_value = __habs(v_local_max_value[fz * 2 + i]);
|
||||
local_max_value = __hmax(local_max_value, __shfl_xor_sync(0xffffffff, local_max_value, 2));
|
||||
local_max_value = __hmax(local_max_value, __shfl_xor_sync(0xffffffff, local_max_value, 1));
|
||||
v_local_max_value[fz * 2 + i] = __hdiv(448, local_max_value);
|
||||
}
|
||||
// store
|
||||
if (tid % 4 == 0) {
|
||||
const int offset_now = wid * num_frags_z * 16 + tid / 4;
|
||||
// used for dequant
|
||||
if (tile_start + offset_now >= start_len) {
|
||||
if (tile_start + offset_now < end_len) {
|
||||
cache_v_scale_now[offset_now] = __hdiv(1, v_local_max_value[fz * 2]);
|
||||
v_scale_smem[offset_now] = v_local_max_value[fz * 2];
|
||||
} else {
|
||||
cache_v_scale_now[offset_now] = 0;
|
||||
v_scale_smem[offset_now] = 0;
|
||||
}
|
||||
}
|
||||
if (tile_start + offset_now + 8 >= start_len) {
|
||||
if (tile_start + offset_now + 8 < end_len) {
|
||||
cache_v_scale_now[offset_now + 8] = __hdiv(1, v_local_max_value[fz * 2 + 1]);
|
||||
v_scale_smem[offset_now + 8] = v_local_max_value[fz * 2 + 1];
|
||||
} else {
|
||||
cache_v_scale_now[offset_now + 8] = 0;
|
||||
v_scale_smem[offset_now + 8] = 0;
|
||||
}
|
||||
}
|
||||
}
|
||||
__syncthreads();
|
||||
k_smem_offset_r -= 2 * num_frags_y; // num_frags_z = 1
|
||||
}
|
||||
__syncthreads();
// mask, quant, store
|
||||
using LoadKVT = AlignedVector<uint8_t, 4>;
|
||||
LoadKVT cache_vec1;
|
||||
LoadKVT cache_vec2;
|
||||
|
||||
uint32_t chunk_start_k = tile_start + wid * num_frags_z * 16 + tid / 4;
|
||||
uint32_t kv_frag[4];
|
||||
const uint32_t write_n_stride = kv_num_heads * BLOCK_SIZE * HEAD_DIM;
|
||||
const uint32_t write_h_stride = BLOCK_SIZE * HEAD_DIM;
|
||||
const uint32_t write_b_stride = HEAD_DIM;
|
||||
const uint32_t write_d_stride = BLOCK_SIZE;
|
||||
uint32_t k_write_idx = block_id * write_n_stride +
|
||||
kv_head_idx * write_h_stride +
|
||||
(wid * num_frags_z * 16 + tid / 4) * write_b_stride +
|
||||
tid % 4 * 4; // 4 * int8 = 8 * int4 = 32bit
|
||||
#pragma unroll
|
||||
for (uint32_t fz = 0; fz < num_frags_z; ++fz) {
|
||||
uint32_t k_write_idx_now_z = k_write_idx + fz * 16 * write_b_stride;
|
||||
#pragma unroll
|
||||
for (uint32_t fy = 0; fy < num_frags_y; ++fy) {
|
||||
uint32_t k_write_idx_now = k_write_idx_now_z +
|
||||
fy % 2 * 8 * write_b_stride +
|
||||
fy / 2 * 32; // + fy % 2 * 16;
|
||||
// load
|
||||
k_smem.ldmatrix_m8n8x4(k_smem_offset_r, kv_frag);
|
||||
// quant
|
||||
T *k_frag_T = reinterpret_cast<T *>(kv_frag);
|
||||
if (bf_pad_len != 0) {
|
||||
Load<uint8_t, 4>(cache_k + k_write_idx_now, &cache_vec1);
|
||||
Load<uint8_t, 4>(cache_k + k_write_idx_now + 16, &cache_vec2);
|
||||
}
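      // When the block already holds earlier tokens (bf_pad_len != 0), the existing bytes
      // are read back and merged with bitwise OR below: positions outside this tile's
      // [start_len, end_len) range quantize to 0, so old values survive the merge.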
|
||||
#pragma unroll
|
||||
for (uint32_t v_id = 0; v_id < 8; ++v_id) {
|
||||
uint8_t uint_quant_value;
|
||||
if (chunk_start_k + (v_id / 4) * 8 >= start_len &&
|
||||
chunk_start_k + (v_id / 4) * 8 < end_len) {
|
||||
uint_quant_value = QuantToC8<T, is_need_kv_quant, IsFP8>(k_local_max_value[fz * 2 + v_id / 4], k_frag_T[v_id], 127.0f, -127.0f);
|
||||
} else {
|
||||
uint_quant_value = 0;
|
||||
}
|
||||
if (bf_pad_len != 0) {
|
||||
if (v_id < 4) {
|
||||
cache_vec1[v_id] |= uint_quant_value;
|
||||
} else {
|
||||
cache_vec2[v_id % 4] |= uint_quant_value;
|
||||
}
|
||||
} else {
|
||||
if (v_id < 4) {
|
||||
cache_vec1[v_id] = uint_quant_value;
|
||||
} else {
|
||||
cache_vec2[v_id - 4] = uint_quant_value;
|
||||
}
|
||||
}
|
||||
}
|
||||
// store
|
||||
Store<uint8_t, 4>(cache_vec1, cache_k + k_write_idx_now);
|
||||
Store<uint8_t, 4>(cache_vec2, cache_k + k_write_idx_now + 16);
|
||||
k_smem_offset_r = k_smem.advance_offset_by_column<2>(k_smem_offset_r, fy);
|
||||
}
|
||||
k_smem_offset_r =
|
||||
k_smem.advance_offset_by_row<16, num_vecs_per_head>(k_smem_offset_r) -
|
||||
2 * num_frags_y;
|
||||
chunk_start_k += 16;
|
||||
}
|
||||
|
||||
uint32_t chunk_start_v = tile_start + tid % 4 * 2;
|
||||
uint32_t v_write_idx = block_id * write_n_stride +
|
||||
kv_head_idx * write_h_stride +
|
||||
(wid * num_frags_v * 16 + tid / 4) * write_d_stride +
|
||||
tid % 4 * 4; // 4 * int8 = 8 * int4 = 32bit
|
||||
const uint32_t num_frags_z_v = num_frags_z * NUM_WARPS;
|
||||
T v_scales[num_frags_z_v * 4];
|
||||
for (int v_i = 0; v_i < num_frags_z_v; v_i++) {
|
||||
const int offset = v_i * 16;
|
||||
const int t_offset = tid % 4 * 2;
|
||||
v_scales[v_i * 4] = v_scale_smem[offset + t_offset];
|
||||
v_scales[v_i * 4 + 1] = v_scale_smem[offset + t_offset + 1];
|
||||
v_scales[v_i * 4 + 2] = v_scale_smem[offset + t_offset + 8];
|
||||
v_scales[v_i * 4 + 3] = v_scale_smem[offset + t_offset + 9];
|
||||
}
|
||||
|
||||
#pragma unroll
|
||||
for (uint32_t fy = 0; fy < num_frags_v; ++fy) {
|
||||
uint32_t v_write_idx_now_v = v_write_idx + fy * 16 * write_d_stride;
|
||||
#pragma unroll
|
||||
for (uint32_t fz = 0; fz < num_frags_z_v; ++fz) {
|
||||
uint32_t v_write_idx_now = v_write_idx_now_v +
|
||||
fz % 2 * 8 * write_d_stride +
|
||||
fz / 2 * 32; // + fz % 2 * 16;
|
||||
// load
|
||||
v_smem.ldmatrix_m8n8x4_trans(v_smem_offset_r, kv_frag);
|
||||
// quant
|
||||
T *v_frag_T = reinterpret_cast<T *>(kv_frag);
|
||||
if (bf_pad_len != 0) {
|
||||
Load<uint8_t, 4>(cache_v + v_write_idx_now, &cache_vec1);
|
||||
Load<uint8_t, 4>(cache_v + v_write_idx_now + 16, &cache_vec2);
|
||||
}
|
||||
#pragma unroll
|
||||
for (uint32_t v_id = 0; v_id < 8; ++v_id) {
|
||||
uint8_t uint_quant_value;
|
||||
if (chunk_start_v + v_id % 2 + (v_id % 4) / 2 * 8 >= start_len &&
|
||||
chunk_start_v + v_id % 2 + (v_id % 4) / 2 * 8 < end_len) {
|
||||
uint_quant_value = QuantToC8<T, is_need_kv_quant, IsFP8>(v_scales[fz * 4 + v_id % 4], v_frag_T[v_id], 127.0f, -127.0f);
|
||||
// store now
|
||||
} else {
|
||||
uint_quant_value = 0;
|
||||
}
|
||||
if (bf_pad_len != 0) {
|
||||
if (v_id < 4) {
|
||||
cache_vec1[v_id] |= uint_quant_value;
|
||||
} else {
|
||||
cache_vec2[v_id % 4] |= uint_quant_value;
|
||||
}
|
||||
} else {
|
||||
if (v_id < 4) {
|
||||
cache_vec1[v_id] = uint_quant_value;
|
||||
} else {
|
||||
cache_vec2[v_id % 4] = uint_quant_value;
|
||||
}
|
||||
}
|
||||
}
|
||||
// store
|
||||
Store<uint8_t, 4>(cache_vec1, cache_v + v_write_idx_now);
|
||||
Store<uint8_t, 4>(cache_vec2, cache_v + v_write_idx_now + 16);
|
||||
chunk_start_v += 16;
|
||||
v_smem_offset_r =
|
||||
k_smem.advance_offset_by_row<16, num_vecs_per_head>(v_smem_offset_r);
|
||||
}
|
||||
v_smem_offset_r = k_smem.advance_offset_by_column<2>(
|
||||
v_smem_offset_r, wid * num_frags_v + fy) -
|
||||
16 * num_frags_z_v * num_vecs_per_head;
|
||||
chunk_start_v -= 16 * num_frags_z_v;
|
||||
}
|
||||
}
|
||||
|
||||
// Write Cache KV in Append
|
||||
template <typename T,
|
||||
uint32_t num_frags_y,
|
||||
@@ -2006,10 +2411,11 @@ void CascadeAppendWriteCacheKVC8QKV(
|
||||
int num_blocks_x_cpu,
|
||||
int max_seq_len,
|
||||
bool is_scale_channel_wise,
|
||||
const bool is_fp8,
|
||||
const std::string& cache_quant_type,
|
||||
cudaStream_t &stream,
|
||||
paddle::Tensor *cache_k_out,
|
||||
paddle::Tensor *cache_v_out) {
|
||||
using NV_TYPE = typename cascade_attn_type_traits<T>::type;
|
||||
auto max_blocks_per_seq = meta_data.max_blocks_per_seq;
|
||||
auto num_tokens = meta_data.token_nums;
|
||||
auto num_heads = meta_data.q_num_heads;
|
||||
@@ -2027,49 +2433,77 @@ void CascadeAppendWriteCacheKVC8QKV(
|
||||
dim3 blocks(32, num_warps);
|
||||
|
||||
const uint32_t smem_size = (BLOCK_SIZE * HEAD_DIM) * sizeof(T) * 2;
|
||||
auto kernel_fn = append_write_cache_kv_c8_qkv<T,
|
||||
num_frags_y,
|
||||
num_frags_z,
|
||||
HEAD_DIM,
|
||||
BLOCK_SIZE,
|
||||
num_warps,
|
||||
true, false>;
|
||||
if (is_fp8) {
|
||||
kernel_fn = append_write_cache_kv_c8_qkv<T,
|
||||
num_frags_y,
|
||||
num_frags_z,
|
||||
HEAD_DIM,
|
||||
BLOCK_SIZE,
|
||||
num_warps,
|
||||
true, true>;
|
||||
if (cache_quant_type != "block_wise_fp8") {
|
||||
auto kernel_fn = append_write_cache_kv_c8_qkv<T,
|
||||
num_frags_y,
|
||||
num_frags_z,
|
||||
HEAD_DIM,
|
||||
BLOCK_SIZE,
|
||||
num_warps,
|
||||
true, false>;
|
||||
if (cache_quant_type == "cache_fp8") {
|
||||
kernel_fn = append_write_cache_kv_c8_qkv<T,
|
||||
num_frags_y,
|
||||
num_frags_z,
|
||||
HEAD_DIM,
|
||||
BLOCK_SIZE,
|
||||
num_warps,
|
||||
true, true>;
|
||||
}
|
||||
if (is_scale_channel_wise) {
|
||||
kernel_fn = append_write_cache_kv_c8_qkv<T,
|
||||
num_frags_y,
|
||||
num_frags_z,
|
||||
HEAD_DIM,
|
||||
BLOCK_SIZE,
|
||||
num_warps,
|
||||
false>;
|
||||
}
|
||||
cudaFuncSetAttribute(
|
||||
kernel_fn, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size);
|
||||
kernel_fn<<<grids, blocks, 0, stream>>>(cache_k_out->data<uint8_t>(),
|
||||
cache_v_out->data<uint8_t>(),
|
||||
qkv.data<T>(),
|
||||
cache_k_scale.data<T>(),
|
||||
cache_v_scale.data<T>(),
|
||||
batch_ids.data<int>(),
|
||||
tile_ids_per_batch.data<int>(),
|
||||
seq_lens_this_time.data<int>(),
|
||||
seq_lens_decoder.data<int>(),
|
||||
batch_id_per_token.data<int>(),
|
||||
cu_seqlens_q.data<int>(),
|
||||
block_table.data<int>(),
|
||||
max_seq_len,
|
||||
max_blocks_per_seq,
|
||||
num_heads,
|
||||
kv_num_heads);
|
||||
} else {
|
||||
auto kernel_fn = append_write_cache_kv_c8_qkv_dynamic<NV_TYPE,
|
||||
num_frags_y,
|
||||
num_frags_z,
|
||||
HEAD_DIM,
|
||||
BLOCK_SIZE,
|
||||
num_warps,
|
||||
true, true>;
|
||||
cudaFuncSetAttribute(
|
||||
kernel_fn, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size);
|
||||
kernel_fn<<<grids, blocks, 0, stream>>>(cache_k_out->data<uint8_t>(),
|
||||
cache_v_out->data<uint8_t>(),
|
||||
reinterpret_cast<const NV_TYPE*>(qkv.data<T>()),
|
||||
const_cast<NV_TYPE*>(reinterpret_cast<const NV_TYPE*>(cache_k_scale.data<T>())),
|
||||
const_cast<NV_TYPE*>(reinterpret_cast<const NV_TYPE*>(cache_v_scale.data<T>())),
|
||||
batch_ids.data<int>(),
|
||||
tile_ids_per_batch.data<int>(),
|
||||
seq_lens_this_time.data<int>(),
|
||||
seq_lens_decoder.data<int>(),
|
||||
batch_id_per_token.data<int>(),
|
||||
cu_seqlens_q.data<int>(),
|
||||
block_table.data<int>(),
|
||||
max_seq_len,
|
||||
max_blocks_per_seq,
|
||||
num_heads,
|
||||
kv_num_heads);
|
||||
}
|
||||
if (is_scale_channel_wise) {
|
||||
kernel_fn = append_write_cache_kv_c8_qkv<T,
|
||||
num_frags_y,
|
||||
num_frags_z,
|
||||
HEAD_DIM,
|
||||
BLOCK_SIZE,
|
||||
num_warps,
|
||||
false>;
|
||||
}
|
||||
cudaFuncSetAttribute(
|
||||
kernel_fn, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size);
|
||||
kernel_fn<<<grids, blocks, 0, stream>>>(cache_k_out->data<uint8_t>(),
|
||||
cache_v_out->data<uint8_t>(),
|
||||
qkv.data<T>(),
|
||||
cache_k_scale.data<T>(),
|
||||
cache_v_scale.data<T>(),
|
||||
batch_ids.data<int>(),
|
||||
tile_ids_per_batch.data<int>(),
|
||||
seq_lens_this_time.data<int>(),
|
||||
seq_lens_decoder.data<int>(),
|
||||
batch_id_per_token.data<int>(),
|
||||
cu_seqlens_q.data<int>(),
|
||||
block_table.data<int>(),
|
||||
max_seq_len,
|
||||
max_blocks_per_seq,
|
||||
num_heads,
|
||||
kv_num_heads);
|
||||
}
|
||||
|
||||
template <typename T, uint32_t HEAD_DIM, uint32_t BLOCK_SIZE>
|
||||
|
@@ -167,7 +167,7 @@ void EncoderWriteCacheWithRopeKernel(
|
||||
stream,
|
||||
key_cache_out,
|
||||
value_cache_out);
|
||||
  } else if (cache_quant_type_str == "cache_int8" or cache_quant_type_str == "cache_fp8" or cache_quant_type_str == "block_wise_fp8") {
|
||||
DISPATCH_HEAD_DIM(
|
||||
head_dim, HEAD_DIM, {DISPATCH_BLOCK_SIZE(block_size, BLOCK_SIZE, {
|
||||
CascadeAppendWriteCacheKVC8QKV<T, HEAD_DIM, BLOCK_SIZE>(
|
||||
@@ -187,7 +187,7 @@ void EncoderWriteCacheWithRopeKernel(
|
||||
num_blocks,
|
||||
max_seq_len,
|
||||
is_scale_channel_wise,
|
||||
cache_quant_type_str == "cache_fp8",
|
||||
cache_quant_type_str,
|
||||
stream,
|
||||
key_cache_out,
|
||||
value_cache_out);
|
||||
|
@@ -1000,7 +1000,7 @@ std::vector<paddle::Tensor> GQARopeWriteCacheKernel(
|
||||
stream,
|
||||
const_cast<paddle::Tensor*>(&key_cache),
|
||||
const_cast<paddle::Tensor*>(&value_cache));
|
||||
  } else if (cache_quant_type == "cache_int8" || cache_quant_type == "cache_fp8" || cache_quant_type == "block_wise_fp8") {
|
||||
CascadeAppendWriteCacheKVC8QKV<data_t, 128, 64>(
|
||||
meta_data,
|
||||
*const_cast<paddle::Tensor*>(&key_cache),
|
||||
@@ -1018,7 +1018,7 @@ std::vector<paddle::Tensor> GQARopeWriteCacheKernel(
|
||||
kv_num_blocks_data,
|
||||
max_seq_len,
|
||||
false, // is_scale_channel_wise
|
||||
cache_quant_type == "cache_fp8", // is_fp8
|
||||
cache_quant_type,
|
||||
stream,
|
||||
const_cast<paddle::Tensor*>(&key_cache),
|
||||
const_cast<paddle::Tensor*>(&value_cache));
|
||||
|
@@ -18,6 +18,168 @@
|
||||
#include "mma_tensor_op.cuh"
|
||||
#include "utils.cuh"
|
||||
|
||||
template <typename T, int VecSize = 1, typename InT = T>
__global__ void append_speculate_cache_T_rope_qk_norm_kernel(
    const InT* __restrict__ qkv,   // [token_num, num_heads + 2 * gqa_group_size, head_size]
    T* __restrict__ key_cache,     // [num_blocks, gqa_group_size, block_size, head_size // 2]
    T* __restrict__ value_cache,   // [num_blocks, gqa_group_size, block_size, head_size // 2]
    T* __restrict__ q_out,
    const int* __restrict__ block_tables,        // [bsz, max_blocks_per_seq]
    const int* __restrict__ batch_id_per_token,  // [num_tokens]
    const int* __restrict__ cu_seqlens_q,
    const int* __restrict__ seq_lens_decoder,    // [bsz]
    const float* __restrict__ cos_emb,
    const float* __restrict__ sin_emb,
    const float* qkv_out_scales,  // [(num_heads + 2 * gqa_group_size) * head_size]
    const T* qkv_biases,          // [num_head + 2 * gqa_group_size, dim_head]
    const int max_seq_len,
    const int max_blocks_per_seq,
    const int num_heads,
    const int output_inner_dim,
    const int head_size,
    const int block_size,
    const int elem_cnt,
    const int gqa_group_size,
    const float* q_norm_weight,
    const float* k_norm_weight,
    const float rms_norm_eps,
    const bool rope_3d) {
|
||||
using LoadT = AlignedVector<T, VecSize>;
|
||||
using LoadFloat = AlignedVector<float, VecSize>;
|
||||
using LoadInT = AlignedVector<InT, VecSize>;
|
||||
constexpr int HalfVecSize = VecSize / 2;
|
||||
using LoadEmbT = AlignedVector<float, HalfVecSize>;
|
||||
LoadInT src_vec;
|
||||
LoadFloat scale_vec;
|
||||
LoadT bias_vec;
|
||||
LoadEmbT cos_emb_vec;
|
||||
LoadEmbT sin_emb_vec;
|
||||
LoadFloat tmp_vec;
|
||||
LoadFloat q_norm_vec;
|
||||
LoadFloat k_norm_vec;
int64_t global_warp_idx = blockDim.y * blockIdx.x + threadIdx.y;
|
||||
int64_t all_warp_num = gridDim.x * blockDim.y;
|
||||
int64_t all_head_dim = elem_cnt / head_size;
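  // One warp per (token, head) row: global_warp_idx grid-strides over
  // token_num * (num_heads + 2 * gqa_group_size) rows, with the 32 lanes of a warp
  // covering head_size elements VecSize at a time.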
|
||||
|
||||
const int64_t hidden_size = (num_heads + 2 * gqa_group_size) * head_size;
|
||||
const int half_head_size = head_size / 2;
|
||||
for (int global_hi = global_warp_idx; global_hi < all_head_dim; global_hi += all_warp_num) {
|
||||
int64_t linear_index = global_hi * head_size + threadIdx.x * VecSize;
|
||||
const int token_id = linear_index / hidden_size;
|
||||
const int ori_bi = batch_id_per_token[token_id];
|
||||
if (seq_lens_decoder[ori_bi] == 0) continue;
|
||||
const int bias = linear_index % hidden_size;
|
||||
const int hi = bias / head_size; // q + k + v
|
||||
const int h_bias = bias % head_size;
|
||||
const int start_token_idx = cu_seqlens_q[ori_bi];
|
||||
const int write_seq_id =
|
||||
seq_lens_decoder[ori_bi] + token_id - start_token_idx;
|
||||
if (write_seq_id == 0) continue;
|
||||
|
||||
const int* block_table_now = block_tables + ori_bi * max_blocks_per_seq;
|
||||
const int block_idx = block_table_now[write_seq_id / block_size];
|
||||
if (block_idx < 0) {
|
||||
printf(
|
||||
"Fatal Error!!!, block idx %d when write_seq_id is %d\n some key var "
|
||||
"%d %d %d %d\n",
|
||||
block_idx,
|
||||
write_seq_id,
|
||||
ori_bi,
|
||||
seq_lens_decoder[ori_bi],
|
||||
token_id,
|
||||
cu_seqlens_q[ori_bi]);
|
||||
}
|
||||
const int block_offset = write_seq_id % block_size;
|
||||
|
||||
const int write_q_idx =
|
||||
token_id * output_inner_dim * head_size + hi * head_size + h_bias;
|
||||
|
||||
const int bias_idx = hi * head_size + h_bias;
|
||||
Load<InT, VecSize>(&qkv[linear_index], &src_vec);
|
||||
if (qkv_biases) {
|
||||
Load<T, VecSize>(&qkv_biases[bias_idx], &bias_vec);
|
||||
}
|
||||
if (qkv_out_scales) {
|
||||
Load<float, VecSize>(&qkv_out_scales[bias_idx], &scale_vec);
|
||||
}
|
||||
if (hi < num_heads + gqa_group_size) {
|
||||
// q k rope
|
||||
const int64_t emb_idx = write_seq_id * half_head_size + h_bias / 2;
|
||||
uint32_t new_emb_idx = rope_3d ? emb_idx + ori_bi * max_seq_len * head_size : emb_idx;
|
||||
Load<float, HalfVecSize>(&cos_emb[new_emb_idx], &cos_emb_vec);
|
||||
Load<float, HalfVecSize>(&sin_emb[new_emb_idx], &sin_emb_vec);
|
||||
}
|
||||
float thread_m2 = 0.0f;
|
||||
float warp_m2 = 0.0f;
|
||||
#pragma unroll
|
||||
for (int i = 0; i < HalfVecSize; i++) {
|
||||
// add_bias + rope
|
||||
float input_left = static_cast<float>(src_vec[2 * i]);
|
||||
float input_right = static_cast<float>(src_vec[2 * i + 1]);
|
||||
if (qkv_out_scales) {
|
||||
input_left *= scale_vec[2 * i];
|
||||
input_right *= scale_vec[2 * i + 1];
|
||||
}
|
||||
if (qkv_biases) {
|
||||
input_left = input_left + static_cast<float>(bias_vec[2 * i]);
|
||||
input_right = input_right + static_cast<float>(bias_vec[2 * i + 1]);
|
||||
}
|
||||
if (hi < num_heads + gqa_group_size) {
|
||||
const float cos_tmp = cos_emb_vec[i];
|
||||
const float sin_tmp = sin_emb_vec[i];
|
||||
float tmp1 = input_left * cos_tmp - input_right * sin_tmp;
|
||||
float tmp2 = input_right * cos_tmp + input_left * sin_tmp;
|
||||
thread_m2 += tmp1 * tmp1 + tmp2 * tmp2;
|
||||
tmp_vec[2 * i] = tmp1;
|
||||
tmp_vec[2 * i + 1] = tmp2;
|
||||
} else {
|
||||
bias_vec[2 * i] = static_cast<T>(input_left);
|
||||
bias_vec[2 * i + 1] = static_cast<T>(input_right);
|
||||
}
|
||||
}
|
||||
if (hi < (num_heads + gqa_group_size)) {
|
||||
WelfordWarpAllReduce<float, 32>(thread_m2, &warp_m2);
|
||||
float row_variance =
|
||||
max(warp_m2 / head_size, 0.0f);
|
||||
float row_inv_var = Rsqrt(row_variance + rms_norm_eps);
|
||||
if (hi < num_heads) {
|
||||
Load<float, VecSize>(&q_norm_weight[threadIdx.x * VecSize], &q_norm_vec);
|
||||
#pragma unroll
|
||||
for (int i = 0; i < VecSize; i++) {
|
||||
bias_vec[i] = static_cast<T>(tmp_vec[i] * row_inv_var * q_norm_vec[i]);
|
||||
}
|
||||
} else {
|
||||
Load<float, VecSize>(&k_norm_weight[threadIdx.x * VecSize], &k_norm_vec);
|
||||
#pragma unroll
|
||||
for (int i = 0; i < VecSize; i++) {
|
||||
bias_vec[i] = static_cast<T>(tmp_vec[i] * row_inv_var * k_norm_vec[i]);
|
||||
}
|
||||
}
|
||||
}
|
||||
if (hi < num_heads) {
|
||||
// write q
|
||||
Store<T, VecSize>(bias_vec, &q_out[write_q_idx]);
|
||||
} else {
|
||||
// write k/v
|
||||
const int kv_head_idx = (hi - num_heads) % gqa_group_size;
|
||||
const int tgt_idx = (block_idx * gqa_group_size * block_size * head_size +
|
||||
kv_head_idx * block_size * head_size +
|
||||
block_offset * head_size + h_bias);
|
||||
// write
|
||||
if (hi < num_heads + gqa_group_size) {
|
||||
Store<T, VecSize>(bias_vec, &key_cache[tgt_idx]);
|
||||
} else {
|
||||
Store<T, VecSize>(bias_vec, &value_cache[tgt_idx]);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template <int VecSize = 4, int HeadDim = 128>
|
||||
__global__ void append_clear_cache_int8_block(
|
||||
uint8_t* __restrict__ key_cache, // [num_blocks, gqa_group_size,
|
||||
@@ -193,7 +355,8 @@ __global__ void append_speculate_cache_rope_kernel(
|
||||
const int head_size,
|
||||
const int block_size,
|
||||
const int elem_cnt,
|
||||
const int gqa_group_size) {
|
||||
const int gqa_group_size,
|
||||
const bool rope_3d) {
|
||||
using LoadT = AlignedVector<T, VecSize>;
|
||||
using LoadFloat = AlignedVector<float, VecSize>;
|
||||
using LoadInT = AlignedVector<InT, VecSize>;
|
||||
@@ -253,8 +416,9 @@ __global__ void append_speculate_cache_rope_kernel(
|
||||
if (hi < num_heads + gqa_group_size) {
|
||||
// q k rope
|
||||
const int64_t emb_idx = write_seq_id * half_head_size + h_bias / 2;
|
||||
Load<float, HalfVecSize>(&cos_emb[emb_idx], &cos_emb_vec);
|
||||
Load<float, HalfVecSize>(&sin_emb[emb_idx], &sin_emb_vec);
|
||||
int64_t new_emb_idx = rope_3d ? emb_idx + ori_bi * max_seq_len * head_size : emb_idx;
|
||||
Load<float, HalfVecSize>(&cos_emb[new_emb_idx], &cos_emb_vec);
|
||||
Load<float, HalfVecSize>(&sin_emb[new_emb_idx], &sin_emb_vec);
|
||||
}
|
||||
#pragma unroll
|
||||
for (int i = 0; i < HalfVecSize; i++) {
|
||||
@@ -326,7 +490,8 @@ __global__ void append_speculate_cache_neox_rope_kernel(
|
||||
const int head_size,
|
||||
const int block_size,
|
||||
const int elem_cnt,
|
||||
const int gqa_group_size) {
|
||||
const int gqa_group_size,
|
||||
const bool rope_3d) {
|
||||
using LoadT = AlignedVector<T, VecSize>;
|
||||
using LoadFloat = AlignedVector<float, VecSize>;
|
||||
using LoadInT = AlignedVector<InT, VecSize>;
|
||||
@@ -390,8 +555,9 @@ __global__ void append_speculate_cache_neox_rope_kernel(
|
||||
if (hi < num_heads + gqa_group_size) {
|
||||
// q k rope
|
||||
const int64_t emb_idx = write_seq_id * head_size + h_bias;
|
||||
Load<float, VecSize>(&cos_emb[emb_idx], &cos_emb_vec);
|
||||
Load<float, VecSize>(&sin_emb[emb_idx], &sin_emb_vec);
|
||||
int64_t new_emb_idx = rope_3d ? emb_idx + ori_bi * max_seq_len * head_size * 2: emb_idx;
|
||||
Load<float, VecSize>(&cos_emb[new_emb_idx], &cos_emb_vec);
|
||||
Load<float, VecSize>(&sin_emb[new_emb_idx], &sin_emb_vec);
|
||||
}
|
||||
#pragma unroll
|
||||
for (int i = 0; i < VecSize; i++) {
|
||||
@@ -476,7 +642,8 @@ __global__ void append_speculate_cache_int8_rope_kernel(
|
||||
const int block_size,
|
||||
const float max_bound,
|
||||
const float min_bound,
|
||||
const int gqa_group_size) {
|
||||
const int gqa_group_size,
|
||||
const bool rope_3d) {
|
||||
static_assert(HeadDim == 128, "just support HeadDim be 128 now!");
|
||||
static_assert(VecSize == 4, "just support VecSize be 4 now, 32 * 4!");
|
||||
constexpr int NUM_WARPS = 4;
|
||||
@@ -522,8 +689,9 @@ __global__ void append_speculate_cache_int8_rope_kernel(
|
||||
|
||||
// q rope
|
||||
const uint32_t emb_idx = write_seq_id * half_head_size + head_bias / 2;
|
||||
Load<float, HalfVecSize>(&cos_emb[emb_idx], &cos_emb_vec);
|
||||
Load<float, HalfVecSize>(&sin_emb[emb_idx], &sin_emb_vec);
|
||||
uint32_t new_emb_idx = rope_3d ? emb_idx + bid * max_seq_len * HeadDim : emb_idx;
|
||||
Load<float, HalfVecSize>(&cos_emb[new_emb_idx], &cos_emb_vec);
|
||||
Load<float, HalfVecSize>(&sin_emb[new_emb_idx], &sin_emb_vec);
|
||||
if (qkv_out_scales) {
|
||||
Load<float, VecSize>(&qkv_out_scales[bias_idx], &out_scale_vec);
|
||||
}
|
||||
@@ -583,10 +751,11 @@ __global__ void append_speculate_cache_int8_rope_kernel(
|
||||
T scale;
|
||||
if (head_idx < num_heads + gqa_group_size) {
|
||||
const uint32_t emb_idx = write_seq_id * half_head_size + head_bias / 2;
|
||||
Load<float, 1>(&cos_emb[emb_idx], &cos_emb_vec1);
|
||||
Load<float, 1>(&cos_emb[emb_idx + 4], &cos_emb_vec2);
|
||||
Load<float, 1>(&sin_emb[emb_idx], &sin_emb_vec1);
|
||||
Load<float, 1>(&sin_emb[emb_idx + 4], &sin_emb_vec2);
|
||||
uint32_t new_emb_idx = rope_3d ? emb_idx + bid * max_seq_len * HeadDim : emb_idx;
|
||||
Load<float, 1>(&cos_emb[new_emb_idx], &cos_emb_vec1);
|
||||
Load<float, 1>(&cos_emb[new_emb_idx + 4], &cos_emb_vec2);
|
||||
Load<float, 1>(&sin_emb[new_emb_idx], &sin_emb_vec1);
|
||||
Load<float, 1>(&sin_emb[new_emb_idx + 4], &sin_emb_vec2);
|
||||
scale = __ldg(&cache_k_scales[kv_head_idx]);
|
||||
} else {
|
||||
scale = __ldg(&cache_v_scales[kv_head_idx]);
|
||||
@@ -708,7 +877,8 @@ __global__ void append_speculate_cache_int8_neox_rope_kernel(
|
||||
const int block_size,
|
||||
const float max_bound,
|
||||
const float min_bound,
|
||||
const int gqa_group_size) {
|
||||
const int gqa_group_size,
|
||||
const bool rope_3d) {
|
||||
static_assert(HeadDim == 128, "just support HeadDim be 128 now!");
|
||||
static_assert(VecSize == 4, "just support VecSize be 4 now, 32 * 4!");
|
||||
constexpr int NUM_WARPS = 4;
|
||||
@@ -757,8 +927,9 @@ __global__ void append_speculate_cache_int8_neox_rope_kernel(
|
||||
|
||||
// q rope
|
||||
const uint32_t emb_idx = write_seq_id * HeadDim + head_bias;
|
||||
Load<float, VecSize>(&cos_emb[emb_idx], &cos_emb_vec);
|
||||
Load<float, VecSize>(&sin_emb[emb_idx], &sin_emb_vec);
|
||||
uint32_t new_emb_idx = rope_3d ? emb_idx + bid * max_seq_len * HeadDim * 2 : emb_idx;
|
||||
Load<float, VecSize>(&cos_emb[new_emb_idx], &cos_emb_vec);
|
||||
Load<float, VecSize>(&sin_emb[new_emb_idx], &sin_emb_vec);
|
||||
if (qkv_out_scales) {
|
||||
Load<float, VecSize>(&qkv_out_scales[bias_idx_left],
|
||||
&left_out_scale_vec);
|
||||
@@ -853,10 +1024,11 @@ __global__ void append_speculate_cache_int8_neox_rope_kernel(
|
||||
|
||||
T scale;
|
||||
const uint32_t emb_idx = write_seq_id * HeadDim + head_bias;
|
||||
Load<float, HALF_K_VEC_SIZE>(&cos_emb[emb_idx], &cos_emb_vec1);
|
||||
Load<float, HALF_K_VEC_SIZE>(&cos_emb[emb_idx + 8], &cos_emb_vec2);
|
||||
Load<float, HALF_K_VEC_SIZE>(&sin_emb[emb_idx], &sin_emb_vec1);
|
||||
Load<float, HALF_K_VEC_SIZE>(&sin_emb[emb_idx + 8], &sin_emb_vec2);
|
||||
uint32_t new_emb_idx = rope_3d ? emb_idx + bid * max_seq_len * HeadDim * 2 : emb_idx;
|
||||
Load<float, HALF_K_VEC_SIZE>(&cos_emb[new_emb_idx], &cos_emb_vec1);
|
||||
Load<float, HALF_K_VEC_SIZE>(&cos_emb[new_emb_idx + 8], &cos_emb_vec2);
|
||||
Load<float, HALF_K_VEC_SIZE>(&sin_emb[new_emb_idx], &sin_emb_vec1);
|
||||
Load<float, HALF_K_VEC_SIZE>(&sin_emb[new_emb_idx + 8], &sin_emb_vec2);
|
||||
scale = __ldg(&cache_k_scales[kv_head_idx]);
|
||||
#pragma unroll
|
||||
for (int i = 0; i < HALF_K_VEC_SIZE; i++) {
|
||||
@@ -1088,7 +1260,8 @@ __global__ void append_speculate_cache_int4_rope_kernel(
|
||||
const int block_size,
|
||||
const float max_bound,
|
||||
const float min_bound,
|
||||
const int gqa_group_size) {
|
||||
const int gqa_group_size,
|
||||
const bool rope_3d) {
|
||||
static_assert(HeadDim == 128, "just support HeadDim be 128 now!");
|
||||
static_assert(VecSize == 4, "just support VecSize be 4 now, 32 * 4!");
|
||||
constexpr int NUM_WARPS = 4;
|
||||
@@ -1145,8 +1318,9 @@ __global__ void append_speculate_cache_int4_rope_kernel(
|
||||
// Load<float, VecSize>(&qkv_out_scales[bias_idx], &out_scale_vec);
|
||||
// q rope
|
||||
const uint32_t emb_idx = write_seq_id * half_head_size + head_bias / 2;
|
||||
Load<float, HalfVecSize>(&cos_emb[emb_idx], &cos_emb_vec);
|
||||
Load<float, HalfVecSize>(&sin_emb[emb_idx], &sin_emb_vec);
|
||||
uint32_t new_emb_idx = rope_3d ? emb_idx + bid * max_seq_len * HeadDim : emb_idx;
|
||||
Load<float, HalfVecSize>(&cos_emb[new_emb_idx], &cos_emb_vec);
|
||||
Load<float, HalfVecSize>(&sin_emb[new_emb_idx], &sin_emb_vec);
|
||||
#pragma unroll
|
||||
for (int i = 0; i < HalfVecSize; i++) {
|
||||
// dequant + add_bias + rope
|
||||
@@ -1235,10 +1409,11 @@ __global__ void append_speculate_cache_int4_rope_kernel(
|
||||
// &out_scale_vec2);
|
||||
if (head_idx < num_heads + gqa_group_size) {
|
||||
const uint32_t emb_idx = write_seq_id * half_head_size + head_bias / 2;
|
||||
Load<float, 1>(&cos_emb[emb_idx], &cos_emb_vec1);
|
||||
Load<float, 1>(&cos_emb[emb_idx + 4], &cos_emb_vec2);
|
||||
Load<float, 1>(&sin_emb[emb_idx], &sin_emb_vec1);
|
||||
Load<float, 1>(&sin_emb[emb_idx + 4], &sin_emb_vec2);
|
||||
uint32_t new_emb_idx = rope_3d ? emb_idx + bid * max_seq_len * HeadDim : emb_idx;
|
||||
Load<float, 1>(&cos_emb[new_emb_idx], &cos_emb_vec1);
|
||||
Load<float, 1>(&cos_emb[new_emb_idx + 4], &cos_emb_vec2);
|
||||
Load<float, 1>(&sin_emb[new_emb_idx], &sin_emb_vec1);
|
||||
Load<float, 1>(&sin_emb[new_emb_idx + 4], &sin_emb_vec2);
|
||||
Load<T, HALF_K_VEC_SIZE>(&cache_k_scales[cache_idx], &scale_vec1);
|
||||
Load<T, HALF_K_VEC_SIZE>(&cache_k_scales[cache_idx + 8], &scale_vec2);
|
||||
Load<T, HALF_K_VEC_SIZE>(&cache_k_zero_points[cache_idx], &zp_vec1);
|
||||
@@ -1431,7 +1606,8 @@ __global__ void append_speculate_cache_int4_neox_rope_kernel(
|
||||
const int block_size,
|
||||
const float max_bound,
|
||||
const float min_bound,
|
||||
const int gqa_group_size) {
|
||||
const int gqa_group_size,
|
||||
const bool rope_3d) {
|
||||
static_assert(HeadDim == 128, "just support HeadDim be 128 now!");
|
||||
static_assert(VecSize == 4, "just support VecSize be 4 now, 32 * 4!");
|
||||
constexpr int NUM_WARPS = 4;
|
||||
@@ -1581,10 +1757,11 @@ __global__ void append_speculate_cache_int4_neox_rope_kernel(
|
||||
&right_out_scale_vec2);
|
||||
|
||||
const uint32_t emb_idx = write_seq_id * HeadDim + head_bias;
|
||||
Load<float, HALF_K_VEC_SIZE>(&cos_emb[emb_idx], &cos_emb_vec1);
|
||||
Load<float, HALF_K_VEC_SIZE>(&cos_emb[emb_idx + 8], &cos_emb_vec2);
|
||||
Load<float, HALF_K_VEC_SIZE>(&sin_emb[emb_idx], &sin_emb_vec1);
|
||||
Load<float, HALF_K_VEC_SIZE>(&sin_emb[emb_idx + 8], &sin_emb_vec2);
|
||||
uint32_t new_emb_idx = rope_3d ? emb_idx + bid * max_seq_len * HeadDim : emb_idx;
|
||||
Load<float, HALF_K_VEC_SIZE>(&cos_emb[new_emb_idx], &cos_emb_vec1);
|
||||
Load<float, HALF_K_VEC_SIZE>(&cos_emb[new_emb_idx + 8], &cos_emb_vec2);
|
||||
Load<float, HALF_K_VEC_SIZE>(&sin_emb[new_emb_idx], &sin_emb_vec1);
|
||||
Load<float, HALF_K_VEC_SIZE>(&sin_emb[new_emb_idx + 8], &sin_emb_vec2);
|
||||
Load<T, HALF_K_VEC_SIZE>(&cache_k_scales[left_cache_idx],
|
||||
&left_scale_vec1);
|
||||
Load<T, HALF_K_VEC_SIZE>(&cache_k_scales[left_cache_idx + 8],
|
||||
|
@@ -15,6 +15,78 @@
|
||||
#include "speculate_write_cache_with_rope_kernel.h"
|
||||
#include "utils.cuh"
|
||||
|
||||
template <typename T, typename QKV_TYPE>
|
||||
void append_speculate_cache_rope_qk_norm(const QKV_TYPE* qkv,
|
||||
T* key_cache,
|
||||
T* value_cache,
|
||||
T* qkv_out,
|
||||
const int* block_tables,
|
||||
const int* batch_id_per_token,
|
||||
const int* cu_seqlens_q,
|
||||
const int* seq_lens,
|
||||
const int* seq_lens_encoder,
|
||||
const float* cos_emb,
|
||||
const float* sin_emb,
|
||||
const float* qkv_out_scales,
|
||||
const T* qkv_biases,
|
||||
const int max_seq_len,
|
||||
const int max_blocks_per_seq,
|
||||
const int num_heads,
|
||||
const int kv_num_heads,
|
||||
const int dim_head,
|
||||
const int block_size,
|
||||
const int bsz,
|
||||
const int token_num,
|
||||
const cudaStream_t& stream,
|
||||
const bool use_neox_style,
|
||||
const float* q_norm_weight,
|
||||
const float* k_norm_weight,
|
||||
const float rms_norm_eps,
|
||||
const bool rope_3d) {
|
||||
int output_inner_dim = num_heads + 2 * kv_num_heads;
|
||||
const uint32_t elem_nums =
|
||||
use_neox_style ? token_num * (num_heads + 2 * kv_num_heads) * dim_head / 2
|
||||
: token_num * (num_heads + 2 * kv_num_heads) * dim_head;
|
||||
constexpr int HEAD_DIM = 128;
constexpr int PackSize = HEAD_DIM / kWarpSize;
|
||||
const int pack_num = elem_nums / PackSize;
|
||||
const int blocksize = 128;
|
||||
int grid_size = 1;
|
||||
GetNumBlocks<128>(pack_num, &grid_size);
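  // Launch shape: PackSize = HEAD_DIM / kWarpSize elements per lane, so one warp spans a
  // 128-wide head; each block packs blocksize / kWarpSize such warps and GetNumBlocks
  // sizes the grid from the total number of packs.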
|
||||
if (use_neox_style) {
|
||||
PD_THROW(
|
||||
"append_speculate_cache_rope_qk_norm not support neox rope yet");
|
||||
} else {
|
||||
dim3 block_dim(kWarpSize, blocksize / kWarpSize, 1);
|
||||
append_speculate_cache_T_rope_qk_norm_kernel<T, PackSize>
|
||||
<<<grid_size, block_dim, 0, stream>>>(qkv,
|
||||
key_cache,
|
||||
value_cache,
|
||||
qkv_out,
|
||||
block_tables,
|
||||
batch_id_per_token,
|
||||
cu_seqlens_q,
|
||||
seq_lens,
|
||||
cos_emb,
|
||||
sin_emb,
|
||||
qkv_out_scales,
|
||||
qkv_biases,
|
||||
max_seq_len,
|
||||
max_blocks_per_seq,
|
||||
num_heads,
|
||||
output_inner_dim,
|
||||
dim_head,
|
||||
block_size,
|
||||
elem_nums,
|
||||
kv_num_heads,
|
||||
q_norm_weight,
|
||||
k_norm_weight,
|
||||
rms_norm_eps,
|
||||
rope_3d);
|
||||
}
|
||||
}

// rope + write
template <typename T, typename QKV_TYPE>
void append_speculate_cache_rope(const QKV_TYPE* qkv,
@@ -39,7 +111,8 @@ void append_speculate_cache_rope(const QKV_TYPE* qkv,
    const int bsz,
    const int token_num,
    const cudaStream_t& stream,
    const bool use_neox_style) {
    const bool use_neox_style,
    const bool rope_3d) {
  int output_inner_dim = num_heads + 2 * kv_num_heads;

  const uint32_t elem_nums =
@@ -73,7 +146,8 @@ void append_speculate_cache_rope(const QKV_TYPE* qkv,
            dim_head,
            block_size,
            elem_nums,
            kv_num_heads);
            kv_num_heads,
            rope_3d);
  } else {
    append_speculate_cache_rope_kernel<T, PackSize>
        <<<grid_size, threads_per_block, 0, stream>>>(
@@ -96,7 +170,8 @@ void append_speculate_cache_rope(const QKV_TYPE* qkv,
            dim_head,
            block_size,
            elem_nums,
            kv_num_heads);
            kv_num_heads,
            rope_3d);
  }
}

@@ -125,7 +200,8 @@ void append_speculate_cache_int8_rope(const QKV_TYPE* qkv,
    const int bsz,
    const int token_num,
    const cudaStream_t& stream,
    const bool use_neox_style) {
    const bool use_neox_style,
    const bool rope_3d) {
  constexpr int num_warps = 4;
  const int all_warps =
      ((num_heads + 2 * kv_num_heads) + num_warps - 1) / num_warps * num_warps;
@@ -167,7 +243,8 @@ void append_speculate_cache_int8_rope(const QKV_TYPE* qkv,
            block_size,
            127.0f,
            -127.0f,
            kv_num_heads);
            kv_num_heads,
            rope_3d);
  } else {
    append_speculate_cache_int8_rope_kernel<T, 4, 0, 128, QKV_TYPE, IsFP8>
        <<<grids, num_warps * 32, 0, stream>>>(qkv,
@@ -191,7 +268,8 @@ void append_speculate_cache_int8_rope(const QKV_TYPE* qkv,
            block_size,
            127.0f,
            -127.0f,
            kv_num_heads);
            kv_num_heads,
            rope_3d);
  }
}

@@ -222,7 +300,8 @@ void append_speculate_cache_int4_rope(const QKV_TYPE* qkv,
    const int bsz,
    const int token_num,
    const cudaStream_t& stream,
    const bool use_neox_style) {
    const bool use_neox_style,
    const bool rope_3d) {
  constexpr int num_warps = 4;
  const int all_warps =
      ((num_heads + 2 * kv_num_heads) + num_warps - 1) / num_warps * num_warps;
@@ -266,7 +345,8 @@ void append_speculate_cache_int4_rope(const QKV_TYPE* qkv,
            block_size,
            7.0f,
            -8.0f,
            kv_num_heads);
            kv_num_heads,
            rope_3d);
  } else {
    append_speculate_cache_int4_rope_kernel<T, 4>
        <<<grids, num_warps * 32, 0, stream>>>(qkv,
@@ -292,7 +372,8 @@ void append_speculate_cache_int4_rope(const QKV_TYPE* qkv,
            block_size,
            7.0f,
            -8.0f,
            kv_num_heads);
            kv_num_heads,
            rope_3d);
  }
}
template <typename T, typename QKV_TYPE>
@@ -313,11 +394,15 @@ void SpeculateWriteCacheWithRoPEKernel(
    const paddle::optional<paddle::Tensor>& cache_v_zp,
    const std::string& cache_quant_type_str,
    const bool use_neox_rotary_style,
    const bool rope_3d,
    const int max_seq_len,
    cudaStream_t& stream,
    paddle::Tensor* qkv_out,
    paddle::Tensor* key_cache_out,
    paddle::Tensor* value_cache_out) {
    paddle::Tensor* value_cache_out,
    const paddle::optional<paddle::Tensor>& q_norm_weight,
    const paddle::optional<paddle::Tensor>& k_norm_weight,
    const float rms_norm_eps) {
  typedef cascade_attn_type_traits<T> traits_;
  typedef cascade_attn_type_traits<QKV_TYPE> qkt_nv_type_;
  typedef typename traits_::type DataType_;
@@ -342,142 +427,185 @@ void SpeculateWriteCacheWithRoPEKernel(
        ? rotary_embs.get().data<float>() + max_seq_len * dim_head
        : rotary_embs.get().data<float>() + max_seq_len * dim_head / 2;
  }
  if (cache_quant_type_str == "none") {
    append_speculate_cache_rope(
        reinterpret_cast<const QKV_TYPE*>(qkv_ptr),
        reinterpret_cast<DataType_*>(key_cache_out->data<T>()),
        reinterpret_cast<DataType_*>(value_cache_out->data<T>()),
        reinterpret_cast<DataType_*>(qkv_out->data<T>()),
        block_tables.data<int>(), batch_id_per_token.data<int>(),
        cu_seqlens_q.data<int>(), seq_lens.data<int>(), seq_lens_encoder.data<int>(),
        cos_emb, sin_emb,
        qkv_out_scales ? qkv_out_scales.get().data<float>() : nullptr,
        qkv_biases ? reinterpret_cast<DataType_*>(const_cast<T*>(qkv_biases.get().data<T>())) : nullptr,
        max_seq_len, max_blocks_per_seq, num_heads, kv_num_heads, dim_head,
        block_size, bsz, token_nums, stream, use_neox_rotary_style);
  } else if (cache_quant_type_str == "cache_int8") {
    append_speculate_cache_int8_rope(
        reinterpret_cast<const QKV_TYPE*>(qkv_ptr),
        key_cache_out->data<uint8_t>(),
        value_cache_out->data<uint8_t>(),
        reinterpret_cast<DataType_*>(qkv_out->data<T>()),
        block_tables.data<int>(), batch_id_per_token.data<int>(),
        cu_seqlens_q.data<int>(), seq_lens.data<int>(), seq_lens_encoder.data<int>(),
        cos_emb, sin_emb,
        qkv_out_scales ? qkv_out_scales.get().data<float>() : nullptr,
        qkv_biases ? reinterpret_cast<DataType_*>(const_cast<T*>(qkv_biases.get().data<T>())) : nullptr,
        cache_k_scale ? reinterpret_cast<DataType_*>(const_cast<T*>(cache_k_scale.get().data<T>())) : nullptr,
        cache_v_scale ? reinterpret_cast<DataType_*>(const_cast<T*>(cache_v_scale.get().data<T>())) : nullptr,
        max_seq_len, max_blocks_per_seq, num_heads, kv_num_heads, dim_head,
        block_size, bsz, token_nums, stream, use_neox_rotary_style);
  } else if (cache_quant_type_str == "cache_fp8") {
    append_speculate_cache_int8_rope<DataType_, QKV_TYPE, true>(
        reinterpret_cast<const QKV_TYPE*>(qkv_ptr),
        key_cache_out->data<uint8_t>(),
        value_cache_out->data<uint8_t>(),
        reinterpret_cast<DataType_*>(qkv_out->data<T>()),
        block_tables.data<int>(), batch_id_per_token.data<int>(),
        cu_seqlens_q.data<int>(), seq_lens.data<int>(), seq_lens_encoder.data<int>(),
        cos_emb, sin_emb,
        qkv_out_scales ? qkv_out_scales.get().data<float>() : nullptr,
        qkv_biases ? reinterpret_cast<DataType_*>(const_cast<T*>(qkv_biases.get().data<T>())) : nullptr,
        cache_k_scale ? reinterpret_cast<DataType_*>(const_cast<T*>(cache_k_scale.get().data<T>())) : nullptr,
        cache_v_scale ? reinterpret_cast<DataType_*>(const_cast<T*>(cache_v_scale.get().data<T>())) : nullptr,
        max_seq_len, max_blocks_per_seq, num_heads, kv_num_heads, dim_head,
        block_size, bsz, token_nums, stream, use_neox_rotary_style);
  } else if (cache_quant_type_str == "cache_int4_zp") {
    append_speculate_cache_int4_rope(
        reinterpret_cast<const QKV_TYPE*>(qkv_ptr),
        key_cache_out->data<uint8_t>(),
        value_cache_out->data<uint8_t>(),
        reinterpret_cast<DataType_*>(const_cast<T*>(qkv_out->data<T>())),
        block_tables.data<int>(), batch_id_per_token.data<int>(),
        cu_seqlens_q.data<int>(), seq_lens.data<int>(), seq_lens_encoder.data<int>(),
        cos_emb, sin_emb,
        qkv_out_scales ? qkv_out_scales.get().data<float>() : nullptr,
        qkv_biases ? reinterpret_cast<DataType_*>(const_cast<T*>(qkv_biases.get().data<T>())) : nullptr,
        cache_k_scale ? reinterpret_cast<DataType_*>(const_cast<T*>(cache_k_scale.get().data<T>())) : nullptr,
        cache_v_scale ? reinterpret_cast<DataType_*>(const_cast<T*>(cache_v_scale.get().data<T>())) : nullptr,
        cache_k_zp ? reinterpret_cast<DataType_*>(const_cast<T*>(cache_k_zp.get().data<T>())) : nullptr,
        cache_v_zp ? reinterpret_cast<DataType_*>(const_cast<T*>(cache_v_zp.get().data<T>())) : nullptr,
        max_seq_len, max_blocks_per_seq, num_heads, kv_num_heads, dim_head,
        block_size, bsz, token_nums, stream, use_neox_rotary_style);
  if (q_norm_weight && k_norm_weight) {
    if (cache_quant_type_str == "none") {
      append_speculate_cache_rope_qk_norm(
          reinterpret_cast<const QKV_TYPE*>(qkv_ptr),
          reinterpret_cast<DataType_*>(key_cache_out->data<T>()),
          reinterpret_cast<DataType_*>(value_cache_out->data<T>()),
          reinterpret_cast<DataType_*>(qkv_out->data<T>()),
          block_tables.data<int>(), batch_id_per_token.data<int>(),
          cu_seqlens_q.data<int>(), seq_lens.data<int>(), seq_lens_encoder.data<int>(),
          cos_emb, sin_emb,
          qkv_out_scales ? qkv_out_scales.get().data<float>() : nullptr,
          qkv_biases ? reinterpret_cast<DataType_*>(const_cast<T*>(qkv_biases.get().data<T>())) : nullptr,
          max_seq_len, max_blocks_per_seq, num_heads, kv_num_heads, dim_head,
          block_size, bsz, token_nums, stream, use_neox_rotary_style,
          reinterpret_cast<const float*>(q_norm_weight.get().data<float>()),
          reinterpret_cast<const float*>(k_norm_weight.get().data<float>()),
          rms_norm_eps,
          rope_3d);
    } else {
      PD_THROW("append_decode_cache_rope_qk_norm not support cachekv quant yet");
    }

  } else {
    PD_THROW(
        "cache_quant_type_str should be one of [none, cache_int8, "
        "cache_int4_zp]");
    if (cache_quant_type_str == "none") {
      append_speculate_cache_rope(
          reinterpret_cast<const QKV_TYPE*>(qkv_ptr),
          reinterpret_cast<DataType_*>(key_cache_out->data<T>()),
          reinterpret_cast<DataType_*>(value_cache_out->data<T>()),
          reinterpret_cast<DataType_*>(qkv_out->data<T>()),
          block_tables.data<int>(), batch_id_per_token.data<int>(),
          cu_seqlens_q.data<int>(), seq_lens.data<int>(), seq_lens_encoder.data<int>(),
          cos_emb, sin_emb,
          qkv_out_scales ? qkv_out_scales.get().data<float>() : nullptr,
          qkv_biases ? reinterpret_cast<DataType_*>(const_cast<T*>(qkv_biases.get().data<T>())) : nullptr,
          max_seq_len, max_blocks_per_seq, num_heads, kv_num_heads, dim_head,
          block_size, bsz, token_nums, stream, use_neox_rotary_style,
          rope_3d);
    } else if (cache_quant_type_str == "cache_int8") {
      append_speculate_cache_int8_rope(
          reinterpret_cast<const QKV_TYPE*>(qkv_ptr),
          key_cache_out->data<uint8_t>(),
          value_cache_out->data<uint8_t>(),
          reinterpret_cast<DataType_*>(qkv_out->data<T>()),
          block_tables.data<int>(), batch_id_per_token.data<int>(),
          cu_seqlens_q.data<int>(), seq_lens.data<int>(), seq_lens_encoder.data<int>(),
          cos_emb, sin_emb,
          qkv_out_scales ? qkv_out_scales.get().data<float>() : nullptr,
          qkv_biases ? reinterpret_cast<DataType_*>(const_cast<T*>(qkv_biases.get().data<T>())) : nullptr,
          cache_k_scale ? reinterpret_cast<DataType_*>(const_cast<T*>(cache_k_scale.get().data<T>())) : nullptr,
          cache_v_scale ? reinterpret_cast<DataType_*>(const_cast<T*>(cache_v_scale.get().data<T>())) : nullptr,
          max_seq_len, max_blocks_per_seq, num_heads, kv_num_heads, dim_head,
          block_size, bsz, token_nums, stream, use_neox_rotary_style,
          rope_3d);
    } else if (cache_quant_type_str == "cache_fp8") {
      append_speculate_cache_int8_rope<DataType_, QKV_TYPE, true>(
          reinterpret_cast<const QKV_TYPE*>(qkv_ptr),
          key_cache_out->data<uint8_t>(),
          value_cache_out->data<uint8_t>(),
          reinterpret_cast<DataType_*>(qkv_out->data<T>()),
          block_tables.data<int>(), batch_id_per_token.data<int>(),
          cu_seqlens_q.data<int>(), seq_lens.data<int>(), seq_lens_encoder.data<int>(),
          cos_emb, sin_emb,
          qkv_out_scales ? qkv_out_scales.get().data<float>() : nullptr,
          qkv_biases ? reinterpret_cast<DataType_*>(const_cast<T*>(qkv_biases.get().data<T>())) : nullptr,
          cache_k_scale ? reinterpret_cast<DataType_*>(const_cast<T*>(cache_k_scale.get().data<T>())) : nullptr,
          cache_v_scale ? reinterpret_cast<DataType_*>(const_cast<T*>(cache_v_scale.get().data<T>())) : nullptr,
          max_seq_len, max_blocks_per_seq, num_heads, kv_num_heads, dim_head,
          block_size, bsz, token_nums, stream, use_neox_rotary_style,
          rope_3d);
    } else if (cache_quant_type_str == "cache_int4_zp") {
      append_speculate_cache_int4_rope(
          reinterpret_cast<const QKV_TYPE*>(qkv_ptr),
          key_cache_out->data<uint8_t>(),
          value_cache_out->data<uint8_t>(),
          reinterpret_cast<DataType_*>(const_cast<T*>(qkv_out->data<T>())),
          block_tables.data<int>(), batch_id_per_token.data<int>(),
          cu_seqlens_q.data<int>(), seq_lens.data<int>(), seq_lens_encoder.data<int>(),
          cos_emb, sin_emb,
          qkv_out_scales ? qkv_out_scales.get().data<float>() : nullptr,
          qkv_biases ? reinterpret_cast<DataType_*>(const_cast<T*>(qkv_biases.get().data<T>())) : nullptr,
          cache_k_scale ? reinterpret_cast<DataType_*>(const_cast<T*>(cache_k_scale.get().data<T>())) : nullptr,
          cache_v_scale ? reinterpret_cast<DataType_*>(const_cast<T*>(cache_v_scale.get().data<T>())) : nullptr,
          cache_k_zp ? reinterpret_cast<DataType_*>(const_cast<T*>(cache_k_zp.get().data<T>())) : nullptr,
          cache_v_zp ? reinterpret_cast<DataType_*>(const_cast<T*>(cache_v_zp.get().data<T>())) : nullptr,
          max_seq_len, max_blocks_per_seq, num_heads, kv_num_heads, dim_head,
          block_size, bsz, token_nums, stream, use_neox_rotary_style,
          rope_3d);
    } else {
      PD_THROW(
          "cache_quant_type_str should be one of [none, cache_int8, "
          "cache_int4_zp]");
    }
  }
}
@@ -500,11 +628,15 @@ template void SpeculateWriteCacheWithRoPEKernel<paddle::bfloat16, int>(
    const paddle::optional<paddle::Tensor>& cache_v_zp,
    const std::string& cache_quant_type_str,
    const bool use_neox_rotary_style,
    const bool rope_3d,
    const int max_seq_len,
    cudaStream_t& stream,
    paddle::Tensor* qkv_out,
    paddle::Tensor* key_cache_out,
    paddle::Tensor* value_cache_out);
    paddle::Tensor* value_cache_out,
    const paddle::optional<paddle::Tensor>& q_norm_weight,
    const paddle::optional<paddle::Tensor>& k_norm_weight,
    const float rms_norm_eps);

template void
SpeculateWriteCacheWithRoPEKernel<paddle::bfloat16, paddle::bfloat16>(
@@ -526,11 +658,15 @@ SpeculateWriteCacheWithRoPEKernel<paddle::bfloat16, paddle::bfloat16>(
    const paddle::optional<paddle::Tensor>& cache_v_zp,
    const std::string& cache_quant_type_str,
    const bool use_neox_rotary_style,
    const bool rope_3d,
    const int max_seq_len,
    cudaStream_t& stream,
    paddle::Tensor* qkv_out,
    paddle::Tensor* key_cache_out,
    paddle::Tensor* value_cache_out);
    paddle::Tensor* value_cache_out,
    const paddle::optional<paddle::Tensor>& q_norm_weight,
    const paddle::optional<paddle::Tensor>& k_norm_weight,
    const float rms_norm_eps);

template void SpeculateWriteCacheWithRoPEKernel<paddle::float16, int>(
    const AppendAttnMetaData& meta_data,
@@ -551,11 +687,15 @@ template void SpeculateWriteCacheWithRoPEKernel<paddle::float16, int>(
    const paddle::optional<paddle::Tensor>& cache_v_zp,
    const std::string& cache_quant_type_str,
    const bool use_neox_rotary_style,
    const bool rope_3d,
    const int max_seq_len,
    cudaStream_t& stream,
    paddle::Tensor* qkv_out,
    paddle::Tensor* key_cache_out,
    paddle::Tensor* value_cache_out);
    paddle::Tensor* value_cache_out,
    const paddle::optional<paddle::Tensor>& q_norm_weight,
    const paddle::optional<paddle::Tensor>& k_norm_weight,
    const float rms_norm_eps);

template void
@@ -578,8 +718,12 @@ SpeculateWriteCacheWithRoPEKernel<paddle::float16, paddle::float16>(
    const paddle::optional<paddle::Tensor>& cache_v_zp,
    const std::string& cache_quant_type_str,
    const bool use_neox_rotary_style,
    const bool rope_3d,
    const int max_seq_len,
    cudaStream_t& stream,
    paddle::Tensor* qkv_out,
    paddle::Tensor* key_cache_out,
    paddle::Tensor* value_cache_out);
    paddle::Tensor* value_cache_out,
    const paddle::optional<paddle::Tensor>& q_norm_weight,
    const paddle::optional<paddle::Tensor>& k_norm_weight,
    const float rms_norm_eps);

@@ -35,8 +35,12 @@ void SpeculateWriteCacheWithRoPEKernel(
    const paddle::optional<paddle::Tensor>& cache_v_zp,
    const std::string& cache_quant_type_str,
    const bool use_neox_rotary_style,
    const bool rope_3d,
    const int max_seq_len,
    cudaStream_t& stream,
    paddle::Tensor* qkv_out,
    paddle::Tensor* key_cache_out,
    paddle::Tensor* value_cache_out);
    paddle::Tensor* value_cache_out,
    const paddle::optional<paddle::Tensor>& q_norm_weight,
    const paddle::optional<paddle::Tensor>& k_norm_weight,
    const float rms_norm_eps);
@@ -56,6 +56,7 @@ CascadeAppendAttentionC8Kernel<paddle::bfloat16, paddle::bfloat16, false>(
    const bool causal,
    const bool is_decoder,
    const bool enable_prefill,
    const std::string& cache_quant_type_str,
    cudaStream_t& stream,
    paddle::Tensor* out);

@@ -103,5 +104,6 @@ CascadeAppendAttentionC8Kernel<paddle::bfloat16, paddle::bfloat16, true>(
    const bool causal,
    const bool is_decoder,
    const bool enable_prefill,
    const std::string& cache_quant_type_str,
    cudaStream_t& stream,
    paddle::Tensor* out);

@@ -54,6 +54,7 @@ template void CascadeAppendAttentionC8Kernel<paddle::bfloat16, paddle::float8_e4
    const bool causal,
    const bool is_decoder,
    const bool enable_prefill,
    const std::string& cache_quant_type_str,
    cudaStream_t& stream,
    paddle::Tensor* out);

@@ -98,5 +99,6 @@ template void CascadeAppendAttentionC8Kernel<paddle::bfloat16, paddle::float8_e4
    const bool causal,
    const bool is_decoder,
    const bool enable_prefill,
    const std::string& cache_quant_type_str,
    cudaStream_t& stream,
    paddle::Tensor* out);

@@ -54,6 +54,7 @@ template void CascadeAppendAttentionC8Kernel<paddle::bfloat16, int8_t, false>(
    const bool causal,
    const bool is_decoder,
    const bool enable_prefill,
    const std::string& cache_quant_type_str,
    cudaStream_t& stream,
    paddle::Tensor* out);

@@ -100,5 +101,6 @@ template void CascadeAppendAttentionC8Kernel<paddle::bfloat16, int8_t, true>(
    const bool causal,
    const bool is_decoder,
    const bool enable_prefill,
    const std::string& cache_quant_type_str,
    cudaStream_t& stream,
    paddle::Tensor* out);

@@ -54,6 +54,7 @@ template void CascadeAppendAttentionC8Kernel<paddle::float16, paddle::float16, f
    const bool causal,
    const bool is_decoder,
    const bool enable_prefill,
    const std::string& cache_quant_type_str,
    cudaStream_t& stream,
    paddle::Tensor* out);

@@ -100,5 +101,6 @@ template void CascadeAppendAttentionC8Kernel<paddle::float16, paddle::float16, t
    const bool causal,
    const bool is_decoder,
    const bool enable_prefill,
    const std::string& cache_quant_type_str,
    cudaStream_t& stream,
    paddle::Tensor* out);

@@ -54,6 +54,7 @@ template void CascadeAppendAttentionC8Kernel<paddle::float16, paddle::float8_e4m
    const bool causal,
    const bool is_decoder,
    const bool enable_prefill,
    const std::string& cache_quant_type_str,
    cudaStream_t& stream,
    paddle::Tensor* out);

@@ -99,5 +100,6 @@ template void CascadeAppendAttentionC8Kernel<paddle::float16, paddle::float8_e4m
    const bool causal,
    const bool is_decoder,
    const bool enable_prefill,
    const std::string& cache_quant_type_str,
    cudaStream_t& stream,
    paddle::Tensor* out);

@@ -54,6 +54,7 @@ template void CascadeAppendAttentionC8Kernel<paddle::float16, int8_t, false>(
    const bool causal,
    const bool is_decoder,
    const bool enable_prefill,
    const std::string& cache_quant_type_str,
    cudaStream_t& stream,
    paddle::Tensor* out);

@@ -99,5 +100,6 @@ template void CascadeAppendAttentionC8Kernel<paddle::float16, int8_t, true>(
    const bool causal,
    const bool is_decoder,
    const bool enable_prefill,
    const std::string& cache_quant_type_str,
    cudaStream_t& stream,
    paddle::Tensor* out);
@@ -441,6 +441,15 @@ __forceinline__ __host__ __device__ void vec_cast<nv_bfloat16, float>(
    PD_THROW("not support the group_size", group_size); \
  }

#define DISPATCH_DyCfp8(is_dynamic_cfp8, IsDynamicC8, ...) \
  if (is_dynamic_cfp8) {                                   \
    constexpr bool IsDynamicC8 = true;                     \
    __VA_ARGS__                                            \
  } else {                                                 \
    constexpr bool IsDynamicC8 = false;                    \
    __VA_ARGS__                                            \
  }
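DISPATCH_DyCfp8 follows the same pattern as the other dispatch helpers in this header: it turns a runtime flag into a constexpr bool so the body can use it as a template argument. A purely illustrative call site (the kernel name and arguments below are placeholders, not code from this PR):

// Illustrative only: is_dynamic_cfp8 is a runtime bool; inside each branch
// IsDynamicC8 is a compile-time constant and can parameterize a template.
DISPATCH_DyCfp8(is_dynamic_cfp8, IsDynamicC8, {
  some_attention_kernel<T, IsDynamicC8><<<grid, block, 0, stream>>>(/* args */);
});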

#define DISPATCH_MLA_GROUP_SIZE(group_size, GROUP_SIZE, ...) \
  if (group_size == 8) {                                     \
    constexpr size_t GROUP_SIZE = 8;                         \

@@ -378,9 +378,11 @@ void RecoverDecodeTask(const paddle::Tensor &stop_flags,
    const paddle::Tensor &step_seq_lens_decoder,
    const paddle::Tensor &block_tables,
    const paddle::Tensor &is_block_step,
    const int block_size);

    const paddle::optional<paddle::Tensor> &draft_tokens,
    const paddle::optional<paddle::Tensor> &step_draft_tokens,
    const paddle::optional<paddle::Tensor> &step_seq_lens_this_time,
    const int block_size,
    const int max_draft_tokens);

paddle::Tensor
GroupSwigluWithMasked(const paddle::Tensor &fc1_out_tensor,
@@ -707,6 +709,22 @@ void SpeculateSaveWithOutputMsgStatic(const paddle::Tensor& accept_tokens,
void SpeculateClearAcceptNums(const paddle::Tensor& accept_num,
    const paddle::Tensor& seq_lens_decoder);

void SpeculateScheduleCache(const paddle::Tensor &draft_tokens,
    const paddle::Tensor &block_tables,
    const paddle::Tensor &stop_flags,
    const paddle::Tensor &seq_lens_this_time,
    const paddle::Tensor &seq_lens_decoder,
    const paddle::Tensor &step_seq_lens_decoder,
    const paddle::Tensor &step_draft_tokens,
    const paddle::Tensor &step_seq_lens_this_time,
    const paddle::Tensor &accept_num,
    const paddle::Tensor &accept_tokens,
    const paddle::Tensor &is_block_step,
    const paddle::Tensor &not_need_stop,
    const paddle::Tensor &stop_nums,
    const int block_size,
    const int max_draft_tokens);

void NgramMatch(const paddle::Tensor &input_ids,
    const paddle::Tensor &input_ids_len,
    const paddle::Tensor &pre_ids,
@@ -750,6 +768,7 @@ void DraftModelPreprocess(const paddle::Tensor& draft_tokens,
    const paddle::Tensor& seq_lens_decoder,
    const paddle::Tensor& step_idx,
    const paddle::Tensor& not_need_stop,
    const paddle::Tensor& is_block_step,
    const paddle::Tensor& batch_drop,
    const paddle::Tensor& pre_ids,
    const paddle::Tensor& accept_tokens,
@@ -763,7 +782,8 @@ void DraftModelPreprocess(const paddle::Tensor& draft_tokens,
    const paddle::Tensor& base_model_draft_tokens,
    const int max_draft_token,
    const bool truncate_first_token,
    const bool splitwise_prefill);
    const bool splitwise_prefill,
    const bool kvcache_scheduler_v1);

void DraftModelUpdate(const paddle::Tensor& inter_next_tokens,
@@ -1228,6 +1248,8 @@ PYBIND11_MODULE(fastdeploy_ops, m) {

  m.def("speculate_clear_accept_nums",&SpeculateClearAcceptNums, "speculate_clear_accept_nums function");

  m.def("speculate_schedule_cache",&SpeculateScheduleCache, "SpeculateScheduleCache function");

  m.def("ngram_match", &NgramMatch, "ngram_match function");

  m.def("hybird_mtp_ngram", &HybridMtpNgram, "ngram_match_mixed function");

@@ -39,9 +39,6 @@ void GetOutputTopK(const paddle::Tensor& x,
    int k,
    int64_t rank_id,
    bool wait_flag) {
  if (rank_id > 0) {
    return;
  }

  static struct msgdata msg_rcv;
  int msg_queue_id = 1;
@@ -33,6 +33,11 @@
      __VA_ARGS__ \
      break; \
    } \
    case 3: { \
      constexpr size_t NUM_EXPERTS_PER_RANK = 3; \
      __VA_ARGS__ \
      break; \
    } \
    case 6: { \
      constexpr size_t NUM_EXPERTS_PER_RANK = 6; \
      __VA_ARGS__ \
@@ -448,137 +453,71 @@ void EPMoeDispatchKernel(const paddle::Tensor& input,
  auto place = input.place();
  const int gridx = min(132 * 8, num_rows);
  if (moe_quant_type == "w4a8") {
    if (num_experts_per_rank == 8) {
      permute_x_kernel<data_t, int8_t, 8><<<gridx, 512, 0, stream>>>(
          input.data<data_t>(), topk_ids.data<int64_t>(), topk_weights.data<float>(),
          token_nums_per_expert.data<int>(),
          up_gate_proj_in_scale ? up_gate_proj_in_scale.get().data<float>() : nullptr,
          moe_topk, num_rows, token_nums_this_rank, hidden_size,
          permute_input->data<int8_t>(), permute_indices_per_token->data<int>(),
          dst_weights->data<float>(), dst_indices->data<int>(), cumsum_idx_gpu->data<int>(),
          token_nums_per_expert_cumsum->data<int64_t>(), expert_idx_per_token->data<int64_t>(),
          127.0,
          -127.0
      );
    } else if (num_experts_per_rank == 16) {
      permute_x_kernel<data_t, int8_t, 16><<<gridx, 512, 0, stream>>>(
          input.data<data_t>(), topk_ids.data<int64_t>(), topk_weights.data<float>(),
          token_nums_per_expert.data<int>(),
          up_gate_proj_in_scale ? up_gate_proj_in_scale.get().data<float>() : nullptr,
          moe_topk, num_rows, token_nums_this_rank, hidden_size,
          permute_input->data<int8_t>(), permute_indices_per_token->data<int>(),
          dst_weights->data<float>(), dst_indices->data<int>(), cumsum_idx_gpu->data<int>(),
          token_nums_per_expert_cumsum->data<int64_t>(), expert_idx_per_token->data<int64_t>(),
          127.0,
          -127.0
      );
    }
    DISPATCH_NUM_EXPERTS_PER_RANK(num_experts_per_rank, NUM_EXPERTS_PER_RANK,
      permute_x_kernel<data_t, int8_t, NUM_EXPERTS_PER_RANK><<<gridx, 512, 0, stream>>>(
          input.data<data_t>(), topk_ids.data<int64_t>(), topk_weights.data<float>(),
          token_nums_per_expert.data<int>(),
          up_gate_proj_in_scale ? up_gate_proj_in_scale.get().data<float>() : nullptr,
          moe_topk, num_rows, token_nums_this_rank, hidden_size,
          permute_input->data<int8_t>(), permute_indices_per_token->data<int>(),
          dst_weights->data<float>(), dst_indices->data<int>(), cumsum_idx_gpu->data<int>(),
          token_nums_per_expert_cumsum->data<int64_t>(), expert_idx_per_token->data<int64_t>(),
          127.0,
          -127.0
      );)
  } else if (moe_quant_type == "w4afp8") {
    if (num_experts_per_rank == 8) {
      permute_x_kernel<data_t, data_t_fp8, 8, 512><<<gridx, 512, 0, stream>>>(
          input.data<data_t>(), topk_ids.data<int64_t>(), topk_weights.data<float>(),
          token_nums_per_expert.data<int>(),
          up_gate_proj_in_scale ? up_gate_proj_in_scale.get().data<float>() : nullptr,
          moe_topk, num_rows, token_nums_this_rank, hidden_size,
          permute_input->data<data_t_fp8>(), permute_indices_per_token->data<int>(),
          dst_weights->data<float>(), dst_indices->data<int>(), cumsum_idx_gpu->data<int>(),
          token_nums_per_expert_cumsum->data<int64_t>(), expert_idx_per_token->data<int64_t>(),
          448.0f,
          -448.0f
      );
    } else if (num_experts_per_rank == 16) {
      permute_x_kernel<data_t, data_t_fp8, 16, 512><<<gridx, 512, 0, stream>>>(
          input.data<data_t>(), topk_ids.data<int64_t>(), topk_weights.data<float>(),
          token_nums_per_expert.data<int>(),
          up_gate_proj_in_scale ? up_gate_proj_in_scale.get().data<float>() : nullptr,
          moe_topk, num_rows, token_nums_this_rank, hidden_size,
          permute_input->data<data_t_fp8>(), permute_indices_per_token->data<int>(),
          dst_weights->data<float>(), dst_indices->data<int>(), cumsum_idx_gpu->data<int>(),
          token_nums_per_expert_cumsum->data<int64_t>(), expert_idx_per_token->data<int64_t>(),
          448.0f,
          -448.0f
      );
    }
    DISPATCH_NUM_EXPERTS_PER_RANK(num_experts_per_rank, NUM_EXPERTS_PER_RANK,
      permute_x_kernel<data_t, data_t_fp8, NUM_EXPERTS_PER_RANK, 512><<<gridx, 512, 0, stream>>>(
          input.data<data_t>(), topk_ids.data<int64_t>(), topk_weights.data<float>(),
          token_nums_per_expert.data<int>(),
          up_gate_proj_in_scale ? up_gate_proj_in_scale.get().data<float>() : nullptr,
          moe_topk, num_rows, token_nums_this_rank, hidden_size,
          permute_input->data<data_t_fp8>(), permute_indices_per_token->data<int>(),
          dst_weights->data<float>(), dst_indices->data<int>(), cumsum_idx_gpu->data<int>(),
          token_nums_per_expert_cumsum->data<int64_t>(), expert_idx_per_token->data<int64_t>(),
          448.0f,
          -448.0f
      );)
  } else {
    if (num_experts_per_rank == 8) {
      permute_x_kernel<data_t, data_t, 8><<<gridx, 512, 0, stream>>>(
          input.data<data_t>(), topk_ids.data<int64_t>(), topk_weights.data<float>(),
          token_nums_per_expert.data<int>(),
          up_gate_proj_in_scale ? up_gate_proj_in_scale.get().data<float>() : nullptr,
          moe_topk, num_rows, token_nums_this_rank, hidden_size,
          permute_input->data<data_t>(), permute_indices_per_token->data<int>(),
          dst_weights->data<float>(), dst_indices->data<int>(), cumsum_idx_gpu->data<int>(),
          token_nums_per_expert_cumsum->data<int64_t>(), expert_idx_per_token->data<int64_t>(),
          127.0,
          -127.0
      );
    } else if (num_experts_per_rank == 16) {
      permute_x_kernel<data_t, data_t, 16><<<gridx, 512, 0, stream>>>(
          input.data<data_t>(), topk_ids.data<int64_t>(), topk_weights.data<float>(),
          token_nums_per_expert.data<int>(),
          up_gate_proj_in_scale ? up_gate_proj_in_scale.get().data<float>() : nullptr,
          moe_topk, num_rows, token_nums_this_rank, hidden_size,
          permute_input->data<data_t>(), permute_indices_per_token->data<int>(),
          dst_weights->data<float>(), dst_indices->data<int>(), cumsum_idx_gpu->data<int>(),
          token_nums_per_expert_cumsum->data<int64_t>(), expert_idx_per_token->data<int64_t>(),
          127.0,
          -127.0
      );
    }
    DISPATCH_NUM_EXPERTS_PER_RANK(num_experts_per_rank, NUM_EXPERTS_PER_RANK,
      permute_x_kernel<data_t, data_t, NUM_EXPERTS_PER_RANK><<<gridx, 512, 0, stream>>>(
          input.data<data_t>(), topk_ids.data<int64_t>(), topk_weights.data<float>(),
          token_nums_per_expert.data<int>(),
          up_gate_proj_in_scale ? up_gate_proj_in_scale.get().data<float>() : nullptr,
          moe_topk, num_rows, token_nums_this_rank, hidden_size,
          permute_input->data<data_t>(), permute_indices_per_token->data<int>(),
          dst_weights->data<float>(), dst_indices->data<int>(), cumsum_idx_gpu->data<int>(),
          token_nums_per_expert_cumsum->data<int64_t>(), expert_idx_per_token->data<int64_t>(),
          127.0,
          -127.0
      );)
  }
}
@@ -15,31 +15,72 @@
#include "helper.h"

__global__ void recover_decode_task(bool *stop_flags,
    int *seq_lens_this_time,
    int *seq_lens_encoder,
    int *seq_lens_decoder,
    int *step_seq_lens_decoder,
    int *block_tables,
    bool *is_block_step,
    const int bsz,
    const int block_num_per_seq,
    const int block_size) {
    int *seq_lens_this_time,
    int *seq_lens_encoder,
    int *seq_lens_decoder,
    int *step_seq_lens_decoder,
    int *block_tables,
    bool *is_block_step,
    const int bsz,
    const int block_num_per_seq,
    const int block_size) {
  int thread_idx = threadIdx.x;
  if (thread_idx < bsz) {
    if(is_block_step[thread_idx] == true) {
      int *block_table_now = block_tables + thread_idx * block_num_per_seq;
      if (block_table_now[step_seq_lens_decoder[thread_idx] / block_size] != -1) {
        // can be recovered for decoding
        is_block_step[thread_idx] = false;
        seq_lens_this_time[thread_idx]= 1;
        stop_flags[thread_idx] = false;
        seq_lens_encoder[thread_idx] = 0;
        seq_lens_decoder[thread_idx] = step_seq_lens_decoder[thread_idx];
      }
      // can be recovered for decoding
      is_block_step[thread_idx] = false;
      seq_lens_this_time[thread_idx]= 1;
      stop_flags[thread_idx] = false;
      seq_lens_encoder[thread_idx] = 0;
      seq_lens_decoder[thread_idx] = step_seq_lens_decoder[thread_idx];

    }
  }
}
}

__global__ void recover_spec_decode_task(bool *stop_flags,
    int *seq_lens_this_time,
    int *seq_lens_encoder,
    int *seq_lens_decoder,
    int *step_seq_lens_decoder,
    int *block_tables,
    bool *is_block_step,
    int64_t *draft_tokens,
    const int64_t *step_draft_tokens,
    const int *step_seq_lens_this_time,
    const int bsz,
    const int block_num_per_seq,
    const int block_size,
    const int draft_tokens_len,
    const int num_extra_tokens) {
  int thread_idx = threadIdx.x;
  if (thread_idx < bsz) {
    if(is_block_step[thread_idx] == true) {
      int *block_table_now = block_tables + thread_idx * block_num_per_seq;
      int max_possible_block_idx = (step_seq_lens_decoder[thread_idx] + num_extra_tokens) / block_size;
      max_possible_block_idx = min(max_possible_block_idx, block_num_per_seq);
      if (block_table_now[max_possible_block_idx] != -1) {
        // can be recovered for decoding
        int64_t *draft_tokens_now = draft_tokens + thread_idx * draft_tokens_len;
        const int64_t *step_draft_tokens_now = step_draft_tokens + thread_idx * draft_tokens_len;
        is_block_step[thread_idx] = false;
        seq_lens_this_time[thread_idx] = step_seq_lens_this_time[thread_idx];
        stop_flags[thread_idx] = false;
        seq_lens_encoder[thread_idx] = 0;
        seq_lens_decoder[thread_idx] = step_seq_lens_decoder[thread_idx];
        for (int i = 0; i < seq_lens_this_time[thread_idx]; i++) {
          draft_tokens_now[i] = step_draft_tokens_now[i];
        }

      }
    }
  }
}
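A quick worked example of the recovery guard above, with illustrative numbers only (the host passes num_extra_tokens = max_draft_tokens * 2 + 1, as in RecoverDecodeTask below):

// Illustrative arithmetic for recover_spec_decode_task's guard (assumed values):
//   block_size = 64, max_draft_tokens = 5  ->  num_extra_tokens = 2 * 5 + 1 = 11
//   step_seq_lens_decoder[tid] = 120       ->  probe = (120 + 11) / 64 = 2
// Decoding is resumed only if block_table_now[min(probe, block_num_per_seq)] != -1,
// i.e. the block that the next speculative step could touch is already allocated.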

void RecoverDecodeTask(const paddle::Tensor &stop_flags,
    const paddle::Tensor &seq_lens_this_time,
    const paddle::Tensor &seq_lens_encoder,
@@ -47,7 +88,11 @@ void RecoverDecodeTask(const paddle::Tensor &stop_flags,
    const paddle::Tensor &step_seq_lens_decoder,
    const paddle::Tensor &block_tables,
    const paddle::Tensor &is_block_step,
    const int block_size) {
    const paddle::optional<paddle::Tensor> &draft_tokens,
    const paddle::optional<paddle::Tensor> &step_draft_tokens,
    const paddle::optional<paddle::Tensor> &step_seq_lens_this_time,
    const int block_size,
    const int max_draft_tokens) {
#ifdef PADDLE_WITH_CUSTOM_DEVICE
  auto dev_ctx = static_cast<const phi::CustomContext*>(paddle::experimental::DeviceContextPool::Instance().Get(seq_lens_this_time.place()));
  auto cu_stream = dev_ctx->stream();
@@ -56,17 +101,38 @@ void RecoverDecodeTask(const paddle::Tensor &stop_flags,
#endif
  const int bsz = seq_lens_this_time.shape()[0];
  const int block_num_per_seq = block_tables.shape()[1];
  recover_decode_task<<<1, 1024, 0, cu_stream>>>(
      const_cast<bool *>(stop_flags.data<bool>()),
      const_cast<int *>(seq_lens_this_time.data<int>()),
      const_cast<int *>(seq_lens_encoder.data<int>()),
      const_cast<int *>(seq_lens_decoder.data<int>()),
      const_cast<int *>(step_seq_lens_decoder.data<int>()),
      const_cast<int *>(block_tables.data<int>()),
      const_cast<bool *>(is_block_step.data<bool>()),
      bsz, block_num_per_seq, block_size);
  if (draft_tokens) {
    const int draft_tokens_len = draft_tokens.get_ptr()->shape()[1];
    recover_spec_decode_task<<<1, 1024, 0, cu_stream>>>(
        const_cast<bool *>(stop_flags.data<bool>()),
        const_cast<int *>(seq_lens_this_time.data<int>()),
        const_cast<int *>(seq_lens_encoder.data<int>()),
        const_cast<int *>(seq_lens_decoder.data<int>()),
        const_cast<int *>(step_seq_lens_decoder.data<int>()),
        const_cast<int *>(block_tables.data<int>()),
        const_cast<bool *>(is_block_step.data<bool>()),
        const_cast<int64_t *>(draft_tokens.get_ptr()->data<int64_t>()),
        step_draft_tokens.get_ptr()->data<int64_t>(),
        step_seq_lens_this_time.get_ptr()->data<int>(),
        bsz, block_num_per_seq, block_size, draft_tokens_len,
        max_draft_tokens * 2 + 1);

  } else {
    recover_decode_task<<<1, 1024, 0, cu_stream>>>(
        const_cast<bool *>(stop_flags.data<bool>()),
        const_cast<int *>(seq_lens_this_time.data<int>()),
        const_cast<int *>(seq_lens_encoder.data<int>()),
        const_cast<int *>(seq_lens_decoder.data<int>()),
        const_cast<int *>(step_seq_lens_decoder.data<int>()),
        const_cast<int *>(block_tables.data<int>()),
        const_cast<bool *>(is_block_step.data<bool>()),
        bsz, block_num_per_seq, block_size);
  }
}

PD_BUILD_STATIC_OP(recover_decode_task)
@@ -76,8 +142,11 @@ PD_BUILD_STATIC_OP(recover_decode_task)
    "seq_lens_decoder",
    "step_seq_lens_decoder",
    "block_tables",
    "is_block_step"})
    .Attrs({"block_size: int"})
    "is_block_step",
    paddle::Optional("draft_tokens"),
    paddle::Optional("step_draft_tokens"),
    paddle::Optional("step_seq_lens_this_time")})
    .Attrs({"block_size: int", "max_draft_tokens: int"})
    .Outputs({"seq_lens_this_time_out",
              "seq_lens_encoder_out",
              "seq_lens_decoder_out",
@@ -15,7 +15,48 @@
#include "helper.h"
#include "paddle/extension.h"

template <int THREADBLOCK_SIZE, bool TRCUNCATE_FIRST_TOKEN>

#define DISPATCH_BLOCKSIZE(BLOCK_SIZE, ...)   \
  do {                                        \
    constexpr int BlockSize = BLOCK_SIZE;     \
    __VA_ARGS__;                              \
  } while (0)

#define DISPATCH_TRUNCATE_FIRST_TOKEN(truncate_first_token, TRUNCATE_FIRST_TOKEN, ...) \
  do {                                                  \
    if (truncate_first_token) {                         \
      constexpr bool TRUNCATE_FIRST_TOKEN = true;       \
      __VA_ARGS__;                                      \
    } else {                                            \
      constexpr bool TRUNCATE_FIRST_TOKEN = false;      \
      __VA_ARGS__;                                      \
    }                                                   \
  } while (0)

#define DISPATCH_KVCACHE_SCHEDULER(kvcache_scheduler_v1, KVCACHE_SCHEDULER_V1, ...) \
  do {                                                  \
    if (kvcache_scheduler_v1) {                         \
      constexpr bool KVCACHE_SCHEDULER_V1 = true;       \
      __VA_ARGS__;                                      \
    } else {                                            \
      constexpr bool KVCACHE_SCHEDULER_V1 = false;      \
      __VA_ARGS__;                                      \
    }                                                   \
  } while (0)

#define DISPATCH_SPLITWISE_PREFILL(splitwise_prefill, SPLITWISE_PREFILL, ...) \
  do {                                                  \
    if (splitwise_prefill) {                            \
      constexpr bool SPLITWISE_PREFILL = true;          \
      __VA_ARGS__;                                      \
    } else {                                            \
      constexpr bool SPLITWISE_PREFILL = false;         \
      __VA_ARGS__;                                      \
    }                                                   \
  } while (0)
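Each of these helpers expands a runtime flag into two constexpr branches, which is what lets the nested DispatchRunner below (see its body further down in this diff) replace the hand-written true/false call chains. As a rough illustration of the expansion only, DISPATCH_TRUNCATE_FIRST_TOKEN(truncate_first_token, TRUNCATE_FIRST_TOKEN, { body }) expands to:

// Expansion sketch (not code from this PR):
// do {
//   if (truncate_first_token) {
//     constexpr bool TRUNCATE_FIRST_TOKEN = true;
//     { body };   // body may use TRUNCATE_FIRST_TOKEN as a template argument
//   } else {
//     constexpr bool TRUNCATE_FIRST_TOKEN = false;
//     { body };
//   }
// } while (0);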

template <int THREADBLOCK_SIZE, bool TRUNCATE_FIRST_TOKEN, bool KVCACHE_SCHEDULER_V1>
__global__ void process_splitwise_prefill(
    int64_t* draft_tokens,
    int64_t* input_ids,
@@ -25,6 +66,7 @@ __global__ void process_splitwise_prefill(
    int* seq_lens_decoder,
    int64_t* step_idx,
    bool* not_need_stop,
    bool* is_block_step,
    bool* batch_drop,
    int64_t* pre_ids,
    const int64_t* accept_tokens,
@@ -58,7 +100,7 @@ __global__ void process_splitwise_prefill(
      stop_flags[tid] = false;
      int64_t base_model_first_token = accept_tokens_now[0];
      int position = seq_len_encoder;
      if (TRCUNCATE_FIRST_TOKEN) {
      if (TRUNCATE_FIRST_TOKEN) {
        input_ids_now[position - 1] = base_model_first_token;
        seq_lens_this_time[tid] = seq_len_encoder;
      } else {
@@ -84,7 +126,7 @@ __global__ void process_splitwise_prefill(

template <int THREADBLOCK_SIZE, bool TRCUNCATE_FIRST_TOKEN>
template <int THREADBLOCK_SIZE, bool TRUNCATE_FIRST_TOKEN, bool KVCACHE_SCHEDULER_V1>
__global__ void draft_model_preprocess_kernel(
    int64_t* draft_tokens,
    int64_t* input_ids,
@@ -94,6 +136,7 @@ __global__ void draft_model_preprocess_kernel(
    int* seq_lens_decoder,
    int64_t* step_idx,
    bool* not_need_stop,
    bool* is_block_step,
    bool* batch_drop,
    int64_t* pre_ids,
    const int64_t* accept_tokens,
@@ -134,14 +177,26 @@ __global__ void draft_model_preprocess_kernel(
      base_model_draft_tokens_now[i] = -1;
    }

    if (base_model_stop_flags[tid] && base_model_is_block_step[tid]) {
      batch_drop[tid] = true;
      stop_flags[tid] = true;
    // 1. process block_step situation
    // -- In v0 mode, block_step will drop mtp query.
    // -- In v1 mode, block_step will continue to infer.
    if constexpr(KVCACHE_SCHEDULER_V1) {
      if (base_model_stop_flags[tid] && base_model_is_block_step[tid]) {
        stop_flags[tid] = true;
        is_block_step[tid] = true;
        // Need to continue infer
      }
    } else {
      if (base_model_stop_flags[tid] && base_model_is_block_step[tid]) {
        batch_drop[tid] = true;
        stop_flags[tid] = true;
      }
    }

    // 2. process normal query, not in any special case.
    if (!(base_model_stop_flags[tid] || batch_drop[tid])) {
      not_stop_flag = 1;
      // 1. first token
      // prefill generation
      if (seq_lens_encoder[tid] > 0) {
        // Can be extended to first few tokens
        int seq_len_encoder = seq_lens_encoder[tid];
@@ -149,14 +204,20 @@ __global__ void draft_model_preprocess_kernel(
        int64_t base_model_first_token = accept_tokens_now[0];
        pre_ids_now[0] = base_model_first_token;
        int position = seq_len_encoder;
        if (TRCUNCATE_FIRST_TOKEN) {
        if (TRUNCATE_FIRST_TOKEN) {
          input_ids_now[position - 1] = base_model_first_token;
          seq_lens_this_time[tid] = seq_len_encoder;
        } else {
          input_ids_now[position] = base_model_first_token;
          seq_lens_this_time[tid] = seq_len_encoder + 1;
        }
      } else {
      } else {  // decode generation
        if constexpr (KVCACHE_SCHEDULER_V1) {
          // 3. try to recover mtp infer in V1 mode
          if (!base_model_is_block_step[tid] && is_block_step[tid]) {
            is_block_step[tid] = false;
          }
        }
        if (stop_flags[tid]) {
          stop_flags[tid] = false;
          // TODO: check
@@ -189,99 +250,8 @@ __global__ void draft_model_preprocess_kernel(
  }
}

template <bool TRCUNCATE_FIRST_TOKEN>
void DispatchRunner(
    const cudaStream_t& stream,
    int64_t* draft_tokens,
    int64_t* input_ids,
    bool* stop_flags,
    int* seq_lens_this_time,
    int* seq_lens_encoder,
    int* seq_lens_decoder,
    int64_t* step_idx,
    bool* not_need_stop,
    bool* batch_drop,
    int64_t* pre_ids,
    const int64_t* accept_tokens,
    const int* accept_num,
    const int* base_model_seq_lens_this_time,
    const int* base_model_seq_lens_encoder,
    const int* base_model_seq_lens_decoder,
    const int64_t* base_model_step_idx,
    const bool* base_model_stop_flags,
    const bool* base_model_is_block_step,
    int64_t* base_model_draft_tokens,
    const int bsz,
    const int num_model_step,
    const int accept_tokens_len,
    const int draft_tokens_len,
    const int input_ids_len,
    const int base_model_draft_tokens_len,
    const int pre_ids_len,
    const bool splitwise_prefill) {
  constexpr int BlockSize = 512;
  if (splitwise_prefill) {
    process_splitwise_prefill<BlockSize, TRCUNCATE_FIRST_TOKEN>
        <<<1, BlockSize, 0, stream>>>(
            draft_tokens, input_ids, stop_flags,
            seq_lens_this_time, seq_lens_encoder, seq_lens_decoder,
            step_idx, not_need_stop, batch_drop, pre_ids,
            accept_tokens, accept_num,
            base_model_seq_lens_this_time, base_model_seq_lens_encoder,
            base_model_seq_lens_decoder, base_model_step_idx,
            base_model_stop_flags, base_model_is_block_step,
            base_model_draft_tokens,
            bsz, num_model_step, accept_tokens_len, draft_tokens_len,
            input_ids_len, base_model_draft_tokens_len, pre_ids_len);
  } else {
    draft_model_preprocess_kernel<BlockSize, TRCUNCATE_FIRST_TOKEN>
        <<<1, BlockSize, 0, stream>>>(
            draft_tokens, input_ids, stop_flags,
            seq_lens_this_time, seq_lens_encoder, seq_lens_decoder,
            step_idx, not_need_stop, batch_drop, pre_ids,
            accept_tokens, accept_num,
            base_model_seq_lens_this_time, base_model_seq_lens_encoder,
            base_model_seq_lens_decoder, base_model_step_idx,
            base_model_stop_flags, base_model_is_block_step,
            base_model_draft_tokens,
            bsz, num_model_step, accept_tokens_len, draft_tokens_len,
            input_ids_len, base_model_draft_tokens_len, pre_ids_len);
  }
}

void DispatchTokenMode(
void DispatchRunner(
    const cudaStream_t &stream,
    int64_t* draft_tokens,
    int64_t* input_ids,
@@ -291,6 +261,7 @@ void DispatchTokenMode(
    int* seq_lens_decoder,
    int64_t* step_idx,
    bool* not_need_stop,
    bool* is_block_step,
    bool* batch_drop,
    int64_t* pre_ids,
    const int64_t* accept_tokens,
@@ -310,75 +281,79 @@ void DispatchTokenMode(
    const int base_model_draft_tokens_len,
    const int pre_ids_len,
    const bool truncate_first_token,
    const bool splitwise_prefill) {
  if (truncate_first_token) {
    DispatchRunner<true>(
        stream, draft_tokens, input_ids, stop_flags,
        seq_lens_this_time, seq_lens_encoder, seq_lens_decoder,
        step_idx, not_need_stop, batch_drop, pre_ids,
        accept_tokens, accept_num,
        base_model_seq_lens_this_time, base_model_seq_lens_encoder,
        base_model_seq_lens_decoder, base_model_step_idx,
        base_model_stop_flags, base_model_is_block_step,
        base_model_draft_tokens,
        bsz, num_model_step, accept_tokens_len, draft_tokens_len,
        input_ids_len, base_model_draft_tokens_len, pre_ids_len,
        splitwise_prefill
    );
  } else {
    DispatchRunner<false>(
        stream, draft_tokens, input_ids, stop_flags,
        seq_lens_this_time, seq_lens_encoder, seq_lens_decoder,
        step_idx, not_need_stop, batch_drop, pre_ids,
        accept_tokens, accept_num,
        base_model_seq_lens_this_time, base_model_seq_lens_encoder,
        base_model_seq_lens_decoder, base_model_step_idx,
        base_model_stop_flags, base_model_is_block_step,
        base_model_draft_tokens,
        bsz, num_model_step, accept_tokens_len, draft_tokens_len,
        input_ids_len, base_model_draft_tokens_len, pre_ids_len,
        splitwise_prefill
    );
  }
    const bool splitwise_prefill,
    const bool kvcache_scheduler_v1) {
  DISPATCH_BLOCKSIZE(512, {
    DISPATCH_TRUNCATE_FIRST_TOKEN(truncate_first_token, TRUNCATE_FIRST_TOKEN, {
      DISPATCH_KVCACHE_SCHEDULER(kvcache_scheduler_v1, KVCACHE_SCHEDULER_V1, {
        DISPATCH_SPLITWISE_PREFILL(splitwise_prefill, SPLITWISE_PREFILL, {
          if constexpr (SPLITWISE_PREFILL) {
            process_splitwise_prefill<BlockSize, TRUNCATE_FIRST_TOKEN, KVCACHE_SCHEDULER_V1>
                <<<1, BlockSize, 0, stream>>>(
                    draft_tokens, input_ids, stop_flags,
                    seq_lens_this_time, seq_lens_encoder, seq_lens_decoder,
                    step_idx, not_need_stop, is_block_step, batch_drop, pre_ids,
                    accept_tokens, accept_num,
                    base_model_seq_lens_this_time, base_model_seq_lens_encoder,
                    base_model_seq_lens_decoder, base_model_step_idx,
                    base_model_stop_flags, base_model_is_block_step,
                    base_model_draft_tokens,
                    bsz, num_model_step, accept_tokens_len, draft_tokens_len,
                    input_ids_len, base_model_draft_tokens_len, pre_ids_len);
          } else {
            draft_model_preprocess_kernel<BlockSize, TRUNCATE_FIRST_TOKEN, KVCACHE_SCHEDULER_V1>
                <<<1, BlockSize, 0, stream>>>(
                    draft_tokens, input_ids, stop_flags,
                    seq_lens_this_time, seq_lens_encoder, seq_lens_decoder,
                    step_idx, not_need_stop, is_block_step, batch_drop, pre_ids,
                    accept_tokens, accept_num,
                    base_model_seq_lens_this_time, base_model_seq_lens_encoder,
                    base_model_seq_lens_decoder, base_model_step_idx,
                    base_model_stop_flags, base_model_is_block_step,
                    base_model_draft_tokens,
                    bsz, num_model_step, accept_tokens_len, draft_tokens_len,
                    input_ids_len, base_model_draft_tokens_len, pre_ids_len);
          }
        });
      });
    });
  });
}

void DraftModelPreprocess(const paddle::Tensor& draft_tokens,
    const paddle::Tensor& input_ids,
    const paddle::Tensor& stop_flags,
@@ -387,6 +362,7 @@ void DraftModelPreprocess(const paddle::Tensor& draft_tokens,
    const paddle::Tensor& seq_lens_decoder,
    const paddle::Tensor& step_idx,
    const paddle::Tensor& not_need_stop,
    const paddle::Tensor& is_block_step,
    const paddle::Tensor& batch_drop,
    const paddle::Tensor& pre_ids,
    const paddle::Tensor& accept_tokens,
@@ -400,7 +376,8 @@ void DraftModelPreprocess(const paddle::Tensor& draft_tokens,
    const paddle::Tensor& base_model_draft_tokens,
    const int num_model_step,
    const bool truncate_first_token,
    const bool splitwise_prefill) {
    const bool splitwise_prefill,
    const bool kvcache_scheduler_v1) {
  int real_bsz = seq_lens_this_time.shape()[0];
  int accept_tokens_len = accept_tokens.shape()[1];
  int input_ids_len = input_ids.shape()[1];
@@ -412,36 +389,38 @@ void DraftModelPreprocess(const paddle::Tensor& draft_tokens,
  auto not_need_stop_gpu =
      not_need_stop.copy_to(seq_lens_this_time.place(), false);

  DispatchTokenMode(
      cu_stream,
      const_cast<int64_t*>(draft_tokens.data<int64_t>()),
      const_cast<int64_t*>(input_ids.data<int64_t>()),
      const_cast<bool*>(stop_flags.data<bool>()),
      const_cast<int*>(seq_lens_this_time.data<int>()),
      const_cast<int*>(seq_lens_encoder.data<int>()),
      const_cast<int*>(seq_lens_decoder.data<int>()),
      const_cast<int64_t*>(step_idx.data<int64_t>()),
      const_cast<bool*>(not_need_stop_gpu.data<bool>()),
      const_cast<bool*>(batch_drop.data<bool>()),
      const_cast<int64_t*>(pre_ids.data<int64_t>()),
      accept_tokens.data<int64_t>(),
      accept_num.data<int>(),
      base_model_seq_lens_this_time.data<int>(),
      base_model_seq_lens_encoder.data<int>(),
      base_model_seq_lens_decoder.data<int>(),
      base_model_step_idx.data<int64_t>(),
      base_model_stop_flags.data<bool>(),
      base_model_is_block_step.data<bool>(),
      const_cast<int64_t*>(base_model_draft_tokens.data<int64_t>()),
      real_bsz, num_model_step, accept_tokens_len, draft_tokens_len,
      input_ids_len, base_model_draft_tokens_len, pre_ids_len,
      truncate_first_token,
      splitwise_prefill);
  DispatchRunner(
      cu_stream,
      const_cast<int64_t*>(draft_tokens.data<int64_t>()),
      const_cast<int64_t*>(input_ids.data<int64_t>()),
      const_cast<bool*>(stop_flags.data<bool>()),
      const_cast<int*>(seq_lens_this_time.data<int>()),
      const_cast<int*>(seq_lens_encoder.data<int>()),
      const_cast<int*>(seq_lens_decoder.data<int>()),
      const_cast<int64_t*>(step_idx.data<int64_t>()),
      const_cast<bool*>(not_need_stop_gpu.data<bool>()),
      const_cast<bool*>(is_block_step.data<bool>()),
      const_cast<bool*>(batch_drop.data<bool>()),
      const_cast<int64_t*>(pre_ids.data<int64_t>()),
      accept_tokens.data<int64_t>(),
      accept_num.data<int>(),
      base_model_seq_lens_this_time.data<int>(),
      base_model_seq_lens_encoder.data<int>(),
      base_model_seq_lens_decoder.data<int>(),
      base_model_step_idx.data<int64_t>(),
      base_model_stop_flags.data<bool>(),
      base_model_is_block_step.data<bool>(),
      const_cast<int64_t*>(base_model_draft_tokens.data<int64_t>()),
      real_bsz, num_model_step, accept_tokens_len, draft_tokens_len,
      input_ids_len, base_model_draft_tokens_len, pre_ids_len,
      truncate_first_token,
      splitwise_prefill,
      kvcache_scheduler_v1);

  auto not_need_stop_cpu =
      not_need_stop_gpu.copy_to(not_need_stop.place(), false);
@@ -459,6 +438,7 @@ PD_BUILD_STATIC_OP(draft_model_preprocess)
    "seq_lens_decoder",
    "step_idx",
    "not_need_stop",
    "is_block_step",
    "batch_drop",
    "pre_ids",
    "accept_tokens",
@@ -480,7 +460,7 @@ PD_BUILD_STATIC_OP(draft_model_preprocess)
    "not_need_stop_out",
    "batch_drop_out",
    "pre_ids_out"})
    .Attrs({"num_model_step: int", "truncate_first_token: bool", "splitwise_prefill: bool"})
    .Attrs({"num_model_step: int", "truncate_first_token: bool", "splitwise_prefill: bool", "kvcache_scheduler_v1: bool"})
    .SetInplaceMap({{"draft_tokens", "draft_tokens_out"},
                    {"input_ids", "input_ids_out"},
                    {"stop_flags", "stop_flags_out"},
@@ -0,0 +1,176 @@
// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "helper.h"

template <int THREADBLOCK_SIZE>
__global__ void speculate_schedula_cache(
    const int64_t *draft_tokens,
    int *block_tables,
    bool *stop_flags,
    int *seq_lens_this_time,
    int *seq_lens_decoder,
    int *step_seq_lens_decoder,
    int64_t *step_draft_tokens,
    int *step_seq_lens_this_time,
    int *accept_num,
    int64_t *accept_tokens,
    bool *is_block_step,
    bool *not_need_stop,
    const int64_t *stop_nums,
    const int real_bsz,
    const int max_bsz,
    const int max_next_step_tokens,
    const int draft_tokens_len,
    const int accept_tokens_len,
    const int block_size,
    const int block_num_per_seq) {
  const int bid = threadIdx.x;
  int stop_flag_now_int = 0;
  if (bid < real_bsz) {
    if (!stop_flags[bid]) {
      const int64_t *draft_tokens_now = draft_tokens + bid * draft_tokens_len;
      int64_t *step_draft_tokens_now = step_draft_tokens + bid * draft_tokens_len;
      int *block_table_now = block_tables + bid * block_num_per_seq;
      int64_t *accept_tokens_now = accept_tokens + bid * accept_tokens_len;
      const int max_possible_block_idx = (seq_lens_decoder[bid] + max_next_step_tokens) / block_size;
      if (max_possible_block_idx < block_num_per_seq && block_table_now[max_possible_block_idx] == -1) {
        is_block_step[bid] = true;
        step_seq_lens_this_time[bid] = seq_lens_this_time[bid];
        seq_lens_this_time[bid] = 0;
        stop_flags[bid] = true;
        stop_flag_now_int = 1;
        step_seq_lens_decoder[bid] = seq_lens_decoder[bid];
        seq_lens_decoder[bid] = 0;
        accept_num[bid] = 0;
        for (int i = 0; i < accept_tokens_len; i++) {
          accept_tokens_now[i] = -1;
        }
        for (int i = 0; i < draft_tokens_len; i++) {
          step_draft_tokens_now[i] = draft_tokens_now[i];
        }
      }
    } else {
      stop_flag_now_int = 1;
    }
  } else if (bid >= real_bsz && bid < max_bsz) {
    stop_flag_now_int = 1;
  }
  __syncthreads();
  typedef cub::BlockReduce<int64_t, THREADBLOCK_SIZE> BlockReduce;
  __shared__ typename BlockReduce::TempStorage temp_storage;

  // printf("stop_flag_now_int %d \n", stop_flag_now_int);
  int64_t stop_sum = BlockReduce(temp_storage).Sum(stop_flag_now_int);

  if (threadIdx.x == 0) {
    // printf("stop_sum %d \n", stop_sum);
    not_need_stop[0] = stop_sum < stop_nums[0];
  }
}
|
||||
|
||||
void SpeculateScheduleCache(const paddle::Tensor &draft_tokens,
|
||||
const paddle::Tensor &block_tables,
|
||||
const paddle::Tensor &stop_flags,
|
||||
const paddle::Tensor &seq_lens_this_time,
|
||||
const paddle::Tensor &seq_lens_decoder,
|
||||
const paddle::Tensor &step_seq_lens_decoder,
|
||||
const paddle::Tensor &step_draft_tokens,
|
||||
const paddle::Tensor &step_seq_lens_this_time,
|
||||
const paddle::Tensor &accept_num,
|
||||
const paddle::Tensor &accept_tokens,
|
||||
const paddle::Tensor &is_block_step,
|
||||
const paddle::Tensor ¬_need_stop,
|
||||
const paddle::Tensor &stop_nums,
|
||||
const int block_size,
|
||||
const int max_draft_tokens) {
|
||||
const int real_bsz = seq_lens_this_time.shape()[0];
|
||||
const int max_bsz = stop_flags.shape()[0];
|
||||
const int accept_tokens_len = accept_tokens.shape()[1];
|
||||
const int draft_token_len = draft_tokens.shape()[1];
|
||||
const int block_num_per_seq = block_tables.shape()[1];
|
||||
|
||||
constexpr int BlockSize = 512;
|
||||
const int max_next_step_tokens = 2 * max_draft_tokens + 2;
|
||||
|
||||
auto not_need_stop_gpu = not_need_stop.copy_to(stop_flags.place(), false);
|
||||
speculate_schedula_cache<BlockSize><<<1, BlockSize, 0, seq_lens_this_time.stream()>>>(
|
||||
draft_tokens.data<int64_t>(),
|
||||
const_cast<int *>(block_tables.data<int>()),
|
||||
const_cast<bool *>(stop_flags.data<bool>()),
|
||||
const_cast<int *>(seq_lens_this_time.data<int>()),
|
||||
const_cast<int *>(seq_lens_decoder.data<int>()),
|
||||
const_cast<int *>(step_seq_lens_decoder.data<int>()),
|
||||
const_cast<int64_t *>(step_draft_tokens.data<int64_t>()),
|
||||
const_cast<int *>(step_seq_lens_this_time.data<int>()),
|
||||
const_cast<int *>(accept_num.data<int>()),
|
||||
const_cast<int64_t *>(accept_tokens.data<int64_t>()),
|
||||
const_cast<bool *>(is_block_step.data<bool>()),
|
||||
const_cast<bool *>(not_need_stop_gpu.data<bool>()),
|
||||
stop_nums.data<int64_t>(),
|
||||
real_bsz,
|
||||
max_bsz,
|
||||
max_next_step_tokens,
|
||||
draft_token_len,
|
||||
accept_tokens_len,
|
||||
block_size,
|
||||
block_num_per_seq
|
||||
);
|
||||
|
||||
auto not_need_stop_cpu =
|
||||
not_need_stop_gpu.copy_to(not_need_stop.place(), true);
|
||||
bool *not_need_stop_data = const_cast<bool *>(not_need_stop.data<bool>());
|
||||
not_need_stop_data[0] = not_need_stop_cpu.data<bool>()[0];
|
||||
}
|
||||
|
||||
PD_BUILD_STATIC_OP(speculate_schedule_cache)
|
||||
.Inputs({"draft_tokens",
|
||||
"block_tables",
|
||||
"stop_flags",
|
||||
"seq_lens_this_time",
|
||||
"seq_lens_decoder",
|
||||
"step_seq_lens_decoder",
|
||||
"step_draft_tokens",
|
||||
"step_seq_lens_this_time",
|
||||
"accept_num",
|
||||
"accept_tokens",
|
||||
"is_block_step",
|
||||
"not_need_stop",
|
||||
"stop_nums"})
|
||||
.Attrs({"block_size: int", "max_draft_tokens: int"})
|
||||
.Outputs({"draft_tokens_out",
|
||||
"block_tables_out",
|
||||
"stop_flags_out",
|
||||
"seq_lens_this_time_out",
|
||||
"seq_lens_decoder_out",
|
||||
"step_seq_lens_decoder_out",
|
||||
"step_draft_tokens_out",
|
||||
"step_seq_lens_this_time_out",
|
||||
"accept_num_out",
|
||||
"accept_tokens_out",
|
||||
"is_block_step_out",
|
||||
"not_need_stop_out"})
|
||||
.SetInplaceMap({{"draft_tokens", "draft_tokens_out"},
|
||||
{"block_tables", "block_tables_out"},
|
||||
{"stop_flags", "stop_flags_out"},
|
||||
{"seq_lens_this_time", "seq_lens_this_time_out"},
|
||||
{"seq_lens_decoder", "seq_lens_decoder_out"},
|
||||
{"step_seq_lens_decoder", "step_seq_lens_decoder_out"},
|
||||
{"step_draft_tokens", "step_draft_tokens_out"},
|
||||
{"step_seq_lens_this_time", "step_seq_lens_this_time_out"},
|
||||
{"accept_num", "accept_num_out"},
|
||||
{"accept_tokens", "accept_tokens_out"},
|
||||
{"is_block_step", "is_block_step_out"},
|
||||
{"not_need_stop", "not_need_stop_out"},})
|
||||
.SetKernelFn(PD_KERNEL(SpeculateScheduleCache));
|
@@ -38,14 +38,20 @@ __device__ int64_t topp_sampling_kernel(const int64_t *candidate_ids,
|
||||
const int tid = threadIdx.x;
|
||||
|
||||
float sum_scores = 0.0f;
|
||||
float rand_top_p = curand_uniform(dev_curand_states + tid) * topp;
|
||||
for (int i = 0; i < candidate_len; i++) {
|
||||
sum_scores += candidate_scores[i];
|
||||
}
|
||||
float tgt_topp = sum_scores < topp ? sum_scores : topp;
|
||||
|
||||
sum_scores = 0.0f;
|
||||
float rand_top_p = curand_uniform(dev_curand_states + tid) * tgt_topp;
|
||||
for (int i = 0; i < candidate_len; i++) {
|
||||
sum_scores += candidate_scores[i];
|
||||
if (rand_top_p <= sum_scores) {
|
||||
return candidate_ids[i];
|
||||
return candidate_ids[i];
|
||||
}
|
||||
}
|
||||
return candidate_ids[0];
|
||||
return candidate_ids[0];
|
||||
}
|
||||
|
||||
__global__ void setup_kernel(curandState_t *state, const uint64_t seed,
|
||||
|
@@ -467,6 +467,9 @@ __global__ void KeMatrixTopPBeamTopKFt(
|
||||
break;
|
||||
}
|
||||
}
|
||||
if (top_p_value == 1.0 && actual_candidates_lens[token_id] == 0){
|
||||
actual_candidates_lens[token_id] = max_cadidate_len;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
@@ -75,12 +75,8 @@ void DisPatchW4AFp8Gemm(
|
||||
const int64_t K,
|
||||
cudaStream_t stream) {
|
||||
|
||||
int kBlockN = (max_tokens + 15) / 16 * 16;
|
||||
int kBlockN = 256;
|
||||
int TailN = 0;
|
||||
if (kBlockN > 256) {
|
||||
TailN = kBlockN % 256;
|
||||
kBlockN = 256;
|
||||
}
|
||||
if constexpr (std::is_same_v<OutputType, cutlass::bfloat16_t>) {
|
||||
GEMM_SWITCH_BF16(
|
||||
M, K, batch_size, token_padding_size, kBlockN, TailN,
|
||||
|
@@ -88,6 +88,8 @@ gemm_case = [
|
||||
[8192, 3584, 8, 2048], # eb45T ffn1
|
||||
[7168, 8192, 8, 0], # eb45T ffn2
|
||||
[7168, 8192, 8, 2048], # eb45T ffn2
|
||||
[1792, 8192, 64, 0], # eb45t ffn1
|
||||
[8192, 896, 64, 0], # eb45t ffn2
|
||||
]
|
||||
|
||||
dtype = ["BF16"]
|
||||
|
@@ -12,8 +12,8 @@ rm -rf "$THIRDPARTY_DIR"
|
||||
mkdir -p "$THIRDPARTY_DIR" || exit 1
|
||||
|
||||
if [ "$1" == "stable" ]; then
|
||||
version_xvllm="20250710"
|
||||
version_xtdk="3.2.40.1"
|
||||
version_xvllm="20250902"
|
||||
version_xtdk="3.4.40.1"
|
||||
else
|
||||
version_xvllm="latest"
|
||||
version_xtdk="latest"
|
||||
|
@@ -18,6 +18,13 @@ This project implements an efficient **Speculative Decoding** inference framewor
- ⏳ Coming Soon: Support Chunk-prefill
- ⏳ Coming Soon: Multi-layer MTP Layer

- **Decoding with Hybrid MTP and Ngram Methods (Hybrid-MTP-with-Ngram)**

  - Overview: A hybrid method combining MTP and Ngram. First, MTP generates N draft tokens, then Ngram matching is used to supplement additional draft tokens (see the sketch below).

  - Use Cases: Suitable when higher draft token coverage is required, leveraging both MTP’s generation capability and the efficiency of Ngram matching.

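To make the hybrid flow concrete, here is a minimal sketch in plain Python (illustrative only, not FastDeploy's implementation; the function name and arguments are invented for this example): MTP supplies the first drafts, and n-gram matches top the list up to the requested length.

```
# Illustrative sketch only -- not FastDeploy's actual hybrid implementation.
def hybrid_drafts(mtp_drafts, ngram_drafts, num_speculative_tokens):
    """Take MTP drafts first, then fill the remainder from n-gram matches."""
    drafts = list(mtp_drafts)[:num_speculative_tokens]
    drafts += list(ngram_drafts)[: num_speculative_tokens - len(drafts)]
    return drafts


# With num_model_steps=2 and num_speculative_tokens=5: 2 MTP drafts plus up to 3 n-gram drafts.
print(hybrid_drafts([11, 12], [21, 22, 23, 24], 5))  # -> [11, 12, 21, 22, 23]
```
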
---

### Coming Soon
@@ -132,7 +139,13 @@ python -m fastdeploy.entrypoints.openai.api_server \
--scheduler-password "scheduler_mtp" \
--speculative-config '{"method": "mtp", "num_speculative_tokens": 1, "model": "${path_to_mtp_model}"}' &
```
## Decoding with Hybrid MTP and Ngram Methods

When starting the service, you only need to modify the --speculative-config option.
For example, use MTP to generate two draft tokens, and then append three additional draft tokens from Ngram matching:
```
--speculative-config '{"method": "mtp", "num_model_steps": 2, "mtp_strategy": "with_ngram", "num_speculative_tokens": 5, "model": "'$model_path'/mtp"}'
```
## 🧠 Using Ngram-Based Decoding

This method uses an n-gram sliding window to match the prompt and generated tokens to predict draft tokens. It is particularly effective in scenarios with high input-output overlap (e.g., code completion, document search).

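As a rough illustration of that sliding-window matching (a hedged sketch in plain Python, not the lookup FastDeploy actually ships; the function name and defaults are invented here):

```
# Illustrative n-gram draft lookup -- not FastDeploy's implementation.
from typing import List


def ngram_draft_tokens(token_ids: List[int], window: int = 3, max_draft: int = 5) -> List[int]:
    """Match the trailing n-gram against earlier context and return the tokens that followed it."""
    if len(token_ids) <= window:
        return []
    tail = token_ids[-window:]
    # Scan backwards so the most recent earlier occurrence wins.
    for start in range(len(token_ids) - window - 1, -1, -1):
        if token_ids[start : start + window] == tail:
            follow = token_ids[start + window : start + window + max_draft]
            if follow:
                return follow
    return []


# The trailing window [2, 3, 4] also appears earlier, so the tokens after it become drafts.
print(ngram_draft_tokens([1, 2, 3, 4, 5, 6, 7, 2, 3, 4]))  # -> [5, 6, 7, 2, 3]
```
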
@@ -72,5 +72,7 @@ environment_variables: dict[str, Callable[[], Any]] = {
"FD_USE_DEEP_GEMM":
lambda: bool(int(os.getenv("FD_USE_DEEP_GEMM", "0"))),

# Whether to use Machete for wint4 dense GEMM.
"FD_USE_MACHETE": lambda: os.getenv("FD_USE_MACHETE", "0"),
}
```

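A brief usage note (an assumption based on how these entries call os.getenv, not something stated elsewhere in this diff): the switches are read lazily from the process environment, so they must be set before the engine is created, e.g.:

```
# Assumed usage sketch: set the flags in the environment seen by the FastDeploy process
# before the engine is created, since each entry is read lazily via os.getenv(...).
import os

os.environ["FD_USE_MACHETE"] = "1"    # enable the Machete wint4 dense GEMM backend
os.environ["FD_USE_DEEP_GEMM"] = "0"  # keep the DeepGEMM FP8 MoE path disabled (its default)
```
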
@@ -14,6 +14,9 @@
- ⏳ Coming Soon: Chunk Prefill compatibility
- ⏳ Coming Soon: Multi-layer MTP layers

- **Decoding with Hybrid MTP and Ngram Methods (Hybrid-MTP-with-Ngram)**
  - Overview: a hybrid of the MTP and Ngram methods; MTP first produces N draft tokens, then Ngram matching supplements additional draft tokens.
  - Use cases: suitable when more draft tokens are needed, balancing MTP's generation capability with the efficiency of Ngram matching.
---

### ⏳ Planned
@@ -110,7 +113,12 @@ python -m fastdeploy.entrypoints.openai.api_server \
--scheduler-password "scheduler_mtp" \
--speculative-config '{"method": "mtp", "num_speculative_tokens": 1, "model": "${path_to_mtp_model}"}' &
```
## Decoding with Hybrid MTP and Ngram Methods

When starting the service, only the --speculative-config option needs to change. For example, use MTP to produce two draft tokens, then append three additional draft tokens from Ngram matching:
```
--speculative-config '{"method": "mtp", "num_model_steps": 2, "mtp_strategy": "with_ngram", "num_speculative_tokens": 5, "model": "'$model_path'/mtp"}'

```
## 🧠 Decoding with Ngram

This algorithm matches against the prompt and already generated tokens through an n-gram window to produce draft tokens. It suits scenarios with large input-output overlap, such as code completion and document retrieval.
> Uses 4×H100; WINT4 quantization

@@ -71,5 +71,8 @@ environment_variables: dict[str, Callable[[], Any]] = {
# Whether to use the DeepGEMM backend for FP8 blockwise MoE.
"FD_USE_DEEP_GEMM":
lambda: bool(int(os.getenv("FD_USE_DEEP_GEMM", "0"))),

# Whether to use the Machete backend for wint4 GEMM.
"FD_USE_MACHETE": lambda: os.getenv("FD_USE_MACHETE", "0"),
}
```
|
||||
|
@@ -163,7 +163,7 @@ class CacheMessager:
|
||||
try:
|
||||
prefilled_step_idx_data = np.zeros(shape=[1], dtype=np.int32)
|
||||
prefilled_layer_idx_data = np.zeros(shape=[1], dtype=np.int32)
|
||||
prefilled_layer_name = f"splitwise_complete_prefilled_step_{self.dp_rank_id}.{self.gpu_id}"
|
||||
prefilled_layer_name = f"splitwise_complete_prefilled_layer_{self.dp_rank_id}.{self.gpu_id}"
|
||||
prefilled_step_name = f"splitwise_complete_prefilled_step_{self.dp_rank_id}.{self.gpu_id}"
|
||||
step_shm_value = IPCSignal(
|
||||
name=f"splitwise_complete_prefilled_step_{self.dp_rank_id}",
|
||||
|
@@ -257,7 +257,12 @@ class PrefixCacheManager:
|
||||
Check if num_blocks gpu blocks can be allocated.
|
||||
"""
|
||||
if len(self.gpu_free_block_list) < num_blocks:
|
||||
return False
|
||||
if self.cache_config.enable_prefix_caching:
|
||||
self.free_block_ids(num_blocks)
|
||||
if len(self.gpu_free_block_list) < num_blocks:
|
||||
return False
|
||||
else:
|
||||
return True
|
||||
else:
|
||||
return True
|
||||
|
||||
@@ -448,7 +453,7 @@ class PrefixCacheManager:
|
||||
"""
|
||||
return (input_token_num + block_size - 1) // block_size
|
||||
|
||||
def update_cache_blocks(self, task, block_size):
|
||||
def update_cache_blocks(self, task, block_size, num_computed_tokens):
|
||||
"""
|
||||
update cache blocks for a task.
|
||||
# TODO(chengyanfu): support async update
|
||||
@@ -459,12 +464,19 @@ class PrefixCacheManager:
|
||||
"""
|
||||
try:
|
||||
req_id = task.request_id
|
||||
num_cached_tokens = task.num_cached_tokens
|
||||
block_tables = task.block_tables
|
||||
|
||||
last_node, input_ids = self.cache_info[req_id]
|
||||
left_input_ids = input_ids[num_cached_tokens:]
|
||||
last_node, num_cached_tokens = self.cache_info[req_id]
|
||||
if isinstance(task.prompt_token_ids, np.ndarray):
|
||||
prompt_token_ids = task.prompt_token_ids.tolist()
|
||||
else:
|
||||
prompt_token_ids = task.prompt_token_ids
|
||||
input_ids = prompt_token_ids + task.output_token_ids
|
||||
can_cache_computed_tokens = num_computed_tokens - num_computed_tokens % block_size
|
||||
left_input_ids = input_ids[num_cached_tokens:can_cache_computed_tokens]
|
||||
gpu_extra_block_ids = block_tables[num_cached_tokens // block_size :]
|
||||
if req_id in self.leaf_req_map[last_node]: # delete old leaf record, update later
|
||||
self.leaf_req_map[last_node].remove(req_id)
|
||||
|
||||
with self.request_release_lock:
|
||||
current_time = time.time()
|
||||
@@ -480,7 +492,8 @@ class PrefixCacheManager:
|
||||
)
|
||||
self.req_leaf_map[req_id] = leaf_node
|
||||
self.leaf_req_map[leaf_node].add(req_id)
|
||||
self.cache_info[req_id] = (leaf_node, input_ids)
|
||||
self.cache_info[req_id] = (leaf_node, can_cache_computed_tokens)
|
||||
task.cached_block_num = can_cache_computed_tokens // block_size
|
||||
except Exception as e:
|
||||
logger.error(f"update_cache_blocks, error: {type(e)} {e}, {str(traceback.format_exc())}")
|
||||
raise e
|
||||
@@ -508,7 +521,11 @@ class PrefixCacheManager:
|
||||
hit_info["gpu_cache_blocks"] = 0
|
||||
hit_info["cpu_cache_blocks"] = 0
|
||||
self.metrics.req_count += 1
|
||||
input_ids = task.prompt_token_ids
|
||||
if isinstance(task.prompt_token_ids, np.ndarray):
|
||||
prompt_token_ids = task.prompt_token_ids.tolist()
|
||||
else:
|
||||
prompt_token_ids = task.prompt_token_ids
|
||||
input_ids = prompt_token_ids + task.output_token_ids
|
||||
req_id = task.request_id
|
||||
logger.info(f"request_match_blocks: start to allocate blocks for req_id {req_id}")
|
||||
input_token_num = len(input_ids)
|
||||
@@ -546,9 +563,6 @@ class PrefixCacheManager:
|
||||
"request_match_blocks: Not enough GPU memory to allocate cache for matched CPU Cache"
|
||||
)
|
||||
|
||||
# record request cache info
|
||||
self.cache_info[req_id] = (match_block_node, input_ids)
|
||||
|
||||
# 3. update metrics
|
||||
matched_token_num = gpu_match_token_num + cpu_match_token_num
|
||||
common_block_ids = match_gpu_block_ids + gpu_recv_block_ids
|
||||
@@ -571,6 +585,9 @@ class PrefixCacheManager:
|
||||
# set leaf node temporarily, then update it in update_cache_blocks
|
||||
self.req_leaf_map[req_id] = match_block_node
|
||||
self.leaf_req_map[match_block_node].add(req_id)
|
||||
# record request cache info
|
||||
self.cache_info[req_id] = (match_block_node, matched_token_num)
|
||||
task.cached_block_num = matched_token_num // block_size
|
||||
return common_block_ids, matched_token_num, hit_info
|
||||
except Exception as e:
|
||||
logger.error(f"request_match_blocks: request_block_ids: error: {type(e)} {e}")
|
||||
@@ -687,6 +704,11 @@ class PrefixCacheManager:
|
||||
"""
|
||||
return self.executor_pool.submit(self.release_block_ids, task)
|
||||
|
||||
def free_block_ids(self, need_block_num):
|
||||
self.free_block_ids_async(need_block_num)
|
||||
while (self.gpu_free_task_future is not None) and (not self.gpu_free_task_future.done()):
|
||||
time.sleep(0.001)
|
||||
|
||||
def release_block_ids(self, task):
|
||||
"""
|
||||
release block ids
|
||||
@@ -1108,15 +1130,6 @@ class PrefixCacheManager:
|
||||
node.req_id_set.add(req_id)
|
||||
node = node.parent
|
||||
|
||||
def decrease_request_share_count(self, req_id):
|
||||
"""
|
||||
Decrease node shared count
|
||||
"""
|
||||
node, input_ids = self.cache_info[req_id]
|
||||
while node != self.radix_tree_root:
|
||||
node.decrement_shared_count()
|
||||
node = node.parent
|
||||
|
||||
def build_path(
|
||||
self,
|
||||
req_id,
|
||||
|
@@ -62,6 +62,7 @@ class ErnieArchitectures:
|
||||
"""Helper class for ERNIE architecture check."""
|
||||
|
||||
ARCHITECTURES = {
|
||||
"Ernie4_5ForCausalLM", # 0.3B-PT
|
||||
"Ernie4_5_ForCausalLM",
|
||||
"Ernie4_5_MoeForCausalLM",
|
||||
"Ernie4_5_VLMoeForConditionalGeneration",
|
||||
@@ -131,6 +132,7 @@ class ModelConfig:
|
||||
self.eos_tokens_lens: int = 2
|
||||
self.lm_head_fp32: bool = False
|
||||
self.model_format = "auto"
|
||||
self.num_nextn_predict_layers = 0
|
||||
for key, value in args.items():
|
||||
if hasattr(self, key):
|
||||
setattr(self, key, value)
|
||||
@@ -154,9 +156,7 @@ class ModelConfig:
|
||||
if hasattr(self, "vision_config"):
|
||||
self.vision_config = PretrainedConfig.from_dict(self.vision_config)
|
||||
|
||||
self.ori_vocab_size = self.vocab_size
|
||||
if ErnieArchitectures.contains_ernie_arch(self.architectures):
|
||||
self.ori_vocab_size = args.get("ori_vocab_size", self.ori_vocab_size)
|
||||
self.ori_vocab_size = args.get("ori_vocab_size", self.vocab_size)
|
||||
|
||||
architectures = self.architectures[0]
|
||||
if MultimodalRegistry.contains_model(architectures):
|
||||
@@ -294,6 +294,8 @@ class ParallelConfig:
|
||||
self.engine_pid: Optional[int] = None
|
||||
# Do profile or not
|
||||
self.do_profile: bool = False
|
||||
# Use internode_ll_two_stage or not
|
||||
self.use_internode_ll_two_stage: bool = False
|
||||
|
||||
self.max_num_batched_tokens: int = 2048
|
||||
# splitwise role
|
||||
@@ -338,23 +340,29 @@ class ParallelConfig:
|
||||
else:
|
||||
self.pd_disaggregation_mode = "None"
|
||||
|
||||
def set_tp_group(self):
|
||||
def set_communicate_group(self):
|
||||
# different tp group id
|
||||
# prevent different tp_groups using the same group_id
|
||||
tp_gid_offset = envs.FD_TP_GROUP_GID_OFFSET
|
||||
dist.collective._set_custom_gid(self.data_parallel_rank + tp_gid_offset)
|
||||
|
||||
self.tp_group = dist.new_group(
|
||||
range(
|
||||
self.data_parallel_rank * self.tensor_parallel_size,
|
||||
(self.data_parallel_rank + 1) * self.tensor_parallel_size,
|
||||
)
|
||||
)
|
||||
dist.collective._set_custom_gid(None)
|
||||
# same ep group id
|
||||
# (TODO:gaoziyuan move this gid config to ep.py)
|
||||
dist.collective._set_custom_gid(self.data_parallel_size + tp_gid_offset)
|
||||
if self.enable_expert_parallel:
|
||||
dist.collective._set_custom_gid(self.data_parallel_size + tp_gid_offset)
|
||||
self.ep_group = dist.new_group(range(self.expert_parallel_size))
|
||||
dist.collective._set_custom_gid(None)
|
||||
|
||||
logger.info(
|
||||
f"data_parallel_size: {self.data_parallel_size}, tensor_parallel_size: {self.tensor_parallel_size}, expert_parallel_size: {self.expert_parallel_size}, data_parallel_rank: {self.data_parallel_rank}, tensor_parallel_rank: {self.tensor_parallel_rank}, expert_parallel_rank: {self.expert_parallel_rank}, tp_group: {self.tp_group}."
|
||||
)
|
||||
dist.collective._set_custom_gid(None)
|
||||
|
||||
def print(self):
|
||||
"""
|
||||
@@ -580,6 +588,9 @@ class GraphOptimizationConfig:
|
||||
Thus this flag cannot be used together with splitting_ops."""
|
||||
self.full_cuda_graph: bool = True
|
||||
|
||||
""" Whether to use shared memory pool for multi capture_size """
|
||||
self.use_unique_memory_pool: bool = False
|
||||
|
||||
self.max_capture_size: int = None
|
||||
self.real_shape_to_captured_size: dict[int, int] = None
|
||||
# CINN Config ...
|
||||
@@ -829,6 +840,7 @@ class LoadConfig:
|
||||
load_strategy: Specifies the weight loading method when enabled:
|
||||
- 'ipc': Real-time IPC streaming with automatic resharding
|
||||
- 'ipc_snapshot': Load from disk snapshot of IPC weights
|
||||
- 'meta': Only model meta messages
|
||||
- None: No dynamic loading
|
||||
"""
|
||||
|
||||
@@ -839,7 +851,7 @@ class LoadConfig:
|
||||
self.load_choices: Union[str, LoadChoices] = LoadChoices.DEFAULT.value
|
||||
self.use_fastsafetensor = int(envs.FD_USE_FASTSAFETENSOR) == 1
|
||||
self.dynamic_load_weight: bool = False
|
||||
self.load_strategy: Optional[Literal["ipc", "ipc_snapshot"]] = None
|
||||
self.load_strategy: Optional[Literal["ipc", "ipc_snapshot", "meta", "normal"]] = "normal"
|
||||
for key, value in args.items():
|
||||
if hasattr(self, key):
|
||||
setattr(self, key, value)
|
||||
@@ -890,7 +902,7 @@ class CacheConfig:
|
||||
else:
|
||||
self.kv_cache_ratio = 0.75
|
||||
self.enc_dec_block_num = 0 if current_platform.is_iluvatar() else 2
|
||||
self.prealloc_dec_block_slot_num_threshold = 5
|
||||
self.prealloc_dec_block_slot_num_threshold = 12
|
||||
self.cache_dtype = "bfloat16"
|
||||
self.model_cfg = None
|
||||
self.enable_chunked_prefill = False
|
||||
@@ -1197,12 +1209,10 @@ class FDConfig:
|
||||
|
||||
num_ranks = self.parallel_config.tensor_parallel_size * self.parallel_config.data_parallel_size
|
||||
self.max_chips_per_node = 16 if current_platform.is_iluvatar() else 8
|
||||
if num_ranks > self.max_chips_per_node:
|
||||
if num_ranks > self.max_chips_per_node and self.load_config.load_strategy != "meta":
|
||||
self.worker_num_per_node = self.max_chips_per_node
|
||||
nnode = ceil_div(num_ranks, self.worker_num_per_node)
|
||||
assert nnode == self.nnode, f"nnode: {nnode}, but got {self.nnode}"
|
||||
|
||||
# assert nnode == self.nnode, f"nnode: {nnode}, but got {self.nnode}"
|
||||
else:
|
||||
self.worker_num_per_node = num_ranks
|
||||
|
||||
@@ -1233,23 +1243,17 @@ class FDConfig:
|
||||
|
||||
self.paddle_commit_id = paddle.version.commit
|
||||
|
||||
if self.cache_config.enable_chunked_prefill:
|
||||
self.force_chunked_prefill = int(envs.FD_FORCE_CHUNKED_PREFILL)
|
||||
if (
|
||||
self.speculative_config is not None
|
||||
and self.speculative_config.method in ["mtp"]
|
||||
and not self.force_chunked_prefill
|
||||
):
|
||||
self.cache_config.enable_chunked_prefill = False
|
||||
|
||||
if self.max_num_batched_tokens is None:
|
||||
if self.cache_config.enable_chunked_prefill:
|
||||
self.max_num_batched_tokens = 2048
|
||||
else:
|
||||
if not int(os.getenv("ENABLE_V1_KVCACHE_SCHEDULER", "0")):
|
||||
if int(envs.ENABLE_V1_KVCACHE_SCHEDULER):
|
||||
if paddle.is_compiled_with_xpu():
|
||||
self.max_num_batched_tokens = self.max_model_len
|
||||
else:
|
||||
self.max_num_batched_tokens = 8192 # if set to max_model_len, it's easy to be OOM
|
||||
else:
|
||||
if self.cache_config.enable_chunked_prefill:
|
||||
self.max_num_batched_tokens = 2048
|
||||
else:
|
||||
self.max_num_batched_tokens = self.max_model_len
|
||||
|
||||
if self.long_prefill_token_threshold == 0:
|
||||
self.long_prefill_token_threshold = int(self.max_model_len * 0.04)
|
||||
@@ -1293,13 +1297,9 @@ class FDConfig:
|
||||
f"be less than or equal to max_num_partial_prefills: {self.max_num_partial_prefills}"
|
||||
)
|
||||
assert self.splitwise_role in ["mixed", "prefill", "decode"]
|
||||
# TODO(@wufeisheng): TP and EP need to be supported simultaneously.
|
||||
assert (self.parallel_config.tensor_parallel_size == 1 and self.parallel_config.expert_parallel_size >= 1) or (
|
||||
self.parallel_config.tensor_parallel_size >= 1 and self.parallel_config.expert_parallel_size == 1
|
||||
), "TP and EP cannot be enabled at the same time"
|
||||
|
||||
if not self.cache_config.enable_chunked_prefill:
|
||||
if not int(os.getenv("ENABLE_V1_KVCACHE_SCHEDULER", "0")):
|
||||
if not envs.ENABLE_V1_KVCACHE_SCHEDULER:
|
||||
assert self.max_num_batched_tokens >= self.max_model_len, (
|
||||
f"max_num_batched_tokens: {self.max_num_batched_tokens} "
|
||||
f"should be larger than or equal to max_model_len: {self.max_model_len}"
|
||||
|
@@ -14,12 +14,15 @@
|
||||
# limitations under the License.
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import json
|
||||
import os
|
||||
from dataclasses import asdict, dataclass
|
||||
from dataclasses import fields as dataclass_fields
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
import paddle
|
||||
|
||||
from fastdeploy import envs
|
||||
from fastdeploy.config import (
|
||||
CacheConfig,
|
||||
EarlyStopConfig,
|
||||
@@ -131,7 +134,7 @@ class EngineArgs:
|
||||
"""
|
||||
dynamic load weight
|
||||
"""
|
||||
load_strategy: str = "ipc_snapshot"
|
||||
load_strategy: str = "normal"
|
||||
"""
|
||||
dynamic load weight strategy
|
||||
"""
|
||||
@@ -162,8 +165,7 @@ class EngineArgs:
|
||||
"""
|
||||
Ratio of tokens to process in a block.
|
||||
"""
|
||||
|
||||
prealloc_dec_block_slot_num_threshold: int = 5
|
||||
prealloc_dec_block_slot_num_threshold: int = 12
|
||||
"""
|
||||
Token slot threshold for preallocating decoder blocks.
|
||||
"""
|
||||
@@ -198,6 +200,11 @@ class EngineArgs:
|
||||
Flag to enable the custom all-reduce kernel.
|
||||
"""
|
||||
|
||||
use_internode_ll_two_stage: bool = False
|
||||
"""
|
||||
Flag to use the internode_ll_two_stage kernel.
|
||||
"""
|
||||
|
||||
engine_worker_queue_port: str = "8002"
|
||||
"""
|
||||
Port for worker queue communication.
|
||||
@@ -243,7 +250,7 @@ class EngineArgs:
|
||||
Ports for rdma communication.
|
||||
"""
|
||||
|
||||
enable_chunked_prefill: bool = True
|
||||
enable_chunked_prefill: bool = False
|
||||
"""
|
||||
Flag to enable chunked prefilling.
|
||||
"""
|
||||
@@ -385,13 +392,27 @@ class EngineArgs:
|
||||
"""
|
||||
if not self.tokenizer:
|
||||
self.tokenizer = self.model
|
||||
if self.splitwise_role == "decode":
|
||||
self.enable_prefix_caching = False
|
||||
if self.speculative_config is not None:
|
||||
self.enable_prefix_caching = False
|
||||
if self.enable_mm:
|
||||
self.enable_prefix_caching = False
|
||||
if not current_platform.is_cuda():
|
||||
self.enable_prefix_caching = False
|
||||
if self.dynamic_load_weight:
|
||||
self.enable_prefix_caching = False
|
||||
if self.enable_logprob:
|
||||
if self.speculative_config is not None:
|
||||
raise NotImplementedError("Logprob does not support speculation_config.")
|
||||
if self.enable_expert_parallel:
|
||||
raise NotImplementedError("Logprob does not support enable_expert_parallel.")
|
||||
if not current_platform.is_cuda():
|
||||
raise NotImplementedError("Only CUDA platform supports logprob.")
|
||||
if self.splitwise_role != "mixed":
|
||||
envs.ENABLE_V1_KVCACHE_SCHEDULER = 0
|
||||
if not current_platform.is_cuda():
|
||||
envs.ENABLE_V1_KVCACHE_SCHEDULER = 0
|
||||
if self.guided_decoding_backend != "off":
|
||||
envs.ENABLE_V1_KVCACHE_SCHEDULER = 0
|
||||
|
||||
@staticmethod
|
||||
def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
|
||||
@@ -613,6 +634,12 @@ class EngineArgs:
|
||||
default=EngineArgs.disable_custom_all_reduce,
|
||||
help="Flag to disable custom all-reduce.",
|
||||
)
|
||||
parallel_group.add_argument(
|
||||
"--use-internode-ll-two-stage",
|
||||
action="store_true",
|
||||
default=EngineArgs.use_internode_ll_two_stage,
|
||||
help="Flag to use the internode_ll_two_stage kernel.",
|
||||
)
|
||||
parallel_group.add_argument(
|
||||
"--max-num-seqs",
|
||||
type=int,
|
||||
@@ -685,7 +712,7 @@ class EngineArgs:
|
||||
cache_group.add_argument(
|
||||
"--prealloc-dec-block-slot-num-threshold",
|
||||
type=int,
|
||||
default=5,
|
||||
default=12,
|
||||
help="Number of token slot threadshold to allocate next blocks for decoding.",
|
||||
)
|
||||
|
||||
@@ -715,7 +742,7 @@ class EngineArgs:
|
||||
perf_group = parser.add_argument_group("Performance Tuning")
|
||||
perf_group.add_argument(
|
||||
"--enable-prefix-caching",
|
||||
action="store_true",
|
||||
action=argparse.BooleanOptionalAction,
|
||||
default=EngineArgs.enable_prefix_caching,
|
||||
help="Flag to enable prefix caching.",
|
||||
)
|
||||
@@ -981,14 +1008,35 @@ class EngineArgs:
|
||||
|
||||
if not model_cfg.is_unified_ckpt and hasattr(model_cfg, "tensor_parallel_size"):
|
||||
self.tensor_parallel_size = model_cfg.tensor_parallel_size
|
||||
|
||||
speculative_cfg = self.create_speculative_config()
|
||||
if not self.enable_chunked_prefill:
|
||||
if (
|
||||
current_platform.is_cuda()
|
||||
and self.splitwise_role == "mixed"
|
||||
and (speculative_cfg is None or speculative_cfg.method not in ["mtp"])
|
||||
):
|
||||
# default enable chunked prefill
|
||||
self.enable_chunked_prefill = True
|
||||
|
||||
self.disable_chunked_prefill = int(envs.FD_DISABLE_CHUNKED_PREFILL)
|
||||
if self.disable_chunked_prefill:
|
||||
self.enable_chunked_prefill = False
|
||||
|
||||
if self.max_num_batched_tokens is None:
|
||||
if self.enable_chunked_prefill:
|
||||
self.max_num_batched_tokens = 2048
|
||||
else:
|
||||
if not int(os.getenv("ENABLE_V1_KVCACHE_SCHEDULER", "0")):
|
||||
if int(envs.ENABLE_V1_KVCACHE_SCHEDULER):
|
||||
if paddle.is_compiled_with_xpu():
|
||||
self.max_num_batched_tokens = self.max_model_len
|
||||
else:
|
||||
self.max_num_batched_tokens = 8192 # if set to max_model_len, it's easy to be OOM
|
||||
if speculative_cfg is not None and speculative_cfg.method is not None:
|
||||
self.max_num_batched_tokens = self.max_model_len
|
||||
else:
|
||||
self.max_num_batched_tokens = 8192 # if set to max_model_len, it's easy to be OOM
|
||||
else:
|
||||
if self.enable_chunked_prefill:
|
||||
self.max_num_batched_tokens = 2048
|
||||
else:
|
||||
self.max_num_batched_tokens = self.max_model_len
|
||||
|
||||
all_dict = asdict(self)
|
||||
all_dict["model_cfg"] = model_cfg
|
||||
@@ -996,7 +1044,6 @@ class EngineArgs:
|
||||
load_cfg = LoadConfig(all_dict)
|
||||
parallel_cfg = ParallelConfig(all_dict)
|
||||
scheduler_cfg = self.create_scheduler_config()
|
||||
speculative_cfg = self.create_speculative_config()
|
||||
graph_opt_cfg = self.create_graph_optimization_config()
|
||||
graph_opt_cfg.update_use_cudagraph(self.use_cudagraph)
|
||||
moba_attention_config = self.create_moba_attention_config()
|
||||
|
@@ -33,16 +33,19 @@ from opentelemetry import trace
|
||||
from fastdeploy.engine.request import Request, RequestOutput
|
||||
from fastdeploy.engine.resource_manager import ResourceManager
|
||||
from fastdeploy.engine.sched.resource_manager_v1 import ResourceManagerV1
|
||||
from fastdeploy.input.preprocess import InputPreprocessor
|
||||
from fastdeploy.inter_communicator import (
|
||||
EngineCacheQueue,
|
||||
EngineWorkerQueue,
|
||||
IPCSignal,
|
||||
ZmqClient,
|
||||
ZmqIpcServer,
|
||||
ZmqTcpServer,
|
||||
)
|
||||
from fastdeploy.metrics.metrics import main_process_metrics
|
||||
from fastdeploy.metrics.trace_util import start_span, start_span_request
|
||||
from fastdeploy.model_executor.guided_decoding import schema_checker
|
||||
from fastdeploy.output.token_processor import TokenProcessor
|
||||
from fastdeploy.splitwise.internal_adapter_utils import InternalAdapter
|
||||
from fastdeploy.splitwise.splitwise_connector import SplitwiseConnector
|
||||
from fastdeploy.utils import EngineError, envs, llm_logger
|
||||
|
||||
@@ -62,7 +65,6 @@ class EngineSevice:
|
||||
self.cfg = cfg
|
||||
|
||||
self.scheduler = cfg.scheduler_config.scheduler()
|
||||
|
||||
if envs.ENABLE_V1_KVCACHE_SCHEDULER:
|
||||
self.resource_manager = ResourceManagerV1(
|
||||
cfg.max_num_seqs,
|
||||
@@ -128,6 +130,17 @@ class EngineSevice:
|
||||
self.token_processor.tasks_queue = self.engine_worker_queue
|
||||
self.token_processor.run()
|
||||
|
||||
def create_data_processor(self):
|
||||
self.input_processor = InputPreprocessor(
|
||||
self.cfg.tokenizer,
|
||||
self.cfg.reasoning_parser,
|
||||
self.cfg.limit_mm_per_prompt,
|
||||
self.cfg.mm_processor_kwargs,
|
||||
self.cfg.model_config.enable_mm,
|
||||
self.cfg.tool_parser,
|
||||
)
|
||||
self.data_processor = self.input_processor.create_processor()
|
||||
|
||||
def _init_worker_monitor_signals(self):  # exist_task_signal lets each worker process detect whether there is a new task to handle
|
||||
current_suffix = int(self.cfg.engine_worker_queue_port[self.cfg.parallel_config.local_data_parallel_id])
|
||||
llm_logger.info(f"current_suffix: {current_suffix}")
|
||||
@@ -225,7 +238,8 @@ class EngineSevice:
|
||||
client_id=0,
|
||||
local_data_parallel_size=self.cfg.parallel_config.data_parallel_size,
|
||||
local_data_parallel_id=min(
|
||||
self.cfg.worker_num_per_node * self.cfg.node_rank + self.cfg.parallel_config.local_data_parallel_id,
|
||||
self.cfg.worker_num_per_node // self.cfg.parallel_config.tensor_parallel_size * self.cfg.node_rank
|
||||
+ self.cfg.parallel_config.local_data_parallel_id,
|
||||
self.cfg.parallel_config.data_parallel_size - 1,
|
||||
),
|
||||
)
|
||||
@@ -526,9 +540,14 @@ class EngineSevice:
|
||||
self.cfg.max_prefill_batch,
|
||||
)
|
||||
|
||||
self.resource_manager.check_and_free_block_tables()
|
||||
if self.cfg.model_config.enable_mm:
|
||||
self.resource_manager.check_and_free_block_tables()
|
||||
available_blocks = self.resource_manager.available_block_num()
|
||||
else:
|
||||
available_blocks = self.cfg.cache_config.max_block_num_per_seq
|
||||
|
||||
tasks = self.scheduler.get_requests(
|
||||
available_blocks=self.resource_manager.available_block_num(),
|
||||
available_blocks=available_blocks,
|
||||
block_size=self.cfg.cache_config.block_size,
|
||||
reserved_output_blocks=self.cfg.cache_config.enc_dec_block_num,
|
||||
max_num_batched_tokens=self.cfg.max_model_len,
|
||||
@@ -552,8 +571,6 @@ class EngineSevice:
|
||||
get_request_pool.submit(_fetch_request)
|
||||
# 2. Schedule requests
|
||||
tasks = self.resource_manager.schedule()
|
||||
main_process_metrics.num_requests_waiting.dec(len(tasks))
|
||||
main_process_metrics.num_requests_running.inc(len(tasks))
|
||||
# 3. Send to engine
|
||||
if tasks:
|
||||
self.resource_manager.get_real_bsz()
|
||||
@@ -568,10 +585,21 @@ class EngineSevice:
|
||||
def start_zmq_service(self, api_server_pid=None):
|
||||
if api_server_pid is None:
|
||||
return
|
||||
self.api_server_pid = api_server_pid
|
||||
self.zmq_server = ZmqClient(name=api_server_pid, mode=zmq.PULL)
|
||||
self.zmq_server.start_server()
|
||||
self.zmq_server.create_router()
|
||||
|
||||
if envs.FD_ENABLE_INTERNAL_ADAPTER:
|
||||
self.recv_request_server = ZmqTcpServer(port=envs.FD_ZMQ_RECV_REQUEST_SERVER_PORT, mode=zmq.PULL)
|
||||
self.send_response_server = ZmqTcpServer(port=envs.FD_ZMQ_SEND_RESPONSE_SERVER_PORT, mode=zmq.ROUTER)
|
||||
self.external_adapter = InternalAdapter(
|
||||
cfg=self.cfg, engine=self, dp_rank=self.cfg.parallel_config.local_data_parallel_id
|
||||
)
|
||||
else:
|
||||
self.recv_request_server = ZmqIpcServer(name=api_server_pid, mode=zmq.PULL)
|
||||
self.send_response_server = ZmqIpcServer(name=api_server_pid, mode=zmq.ROUTER)
|
||||
self.recv_result_handle_thread = threading.Thread(
|
||||
target=self.send_response_server.recv_result_handle, daemon=True
|
||||
)
|
||||
self.recv_result_handle_thread.start()
|
||||
|
||||
time.sleep(3)
|
||||
self.insert_task_to_scheduler_thread = threading.Thread(target=self._insert_zmq_task_to_scheduler, daemon=True)
|
||||
self.insert_task_to_scheduler_thread.start()
|
||||
@@ -585,9 +613,9 @@ class EngineSevice:
|
||||
try:
|
||||
block = True if len(added_requests) == 0 else False
|
||||
if not self.cfg.model_config.enable_mm:
|
||||
err, data = self.zmq_server.receive_json_once(block)
|
||||
err, data = self.recv_request_server.receive_json_once(block)
|
||||
else:
|
||||
err, data = self.zmq_server.receive_pyobj_once(block)
|
||||
err, data = self.recv_request_server.receive_pyobj_once(block)
|
||||
if err is not None:
|
||||
llm_logger.error("Engine stops inserting zmq task into scheduler, err:{err}")
|
||||
break
|
||||
@@ -641,13 +669,27 @@ class EngineSevice:
|
||||
)
|
||||
# Since the request is not in scheduler
|
||||
# Send result by zmq directly
|
||||
self.zmq_server.send_multipart(request_id, [error_result])
|
||||
self.send_response_server.send_response(request_id, [error_result])
|
||||
except Exception as e:
|
||||
llm_logger.error(
|
||||
f"Error happend while receving new request from zmq, details={e}, "
|
||||
f"traceback={traceback.format_exc()}"
|
||||
)
|
||||
|
||||
def _decode_token(self, token_ids, req_id, is_end):
|
||||
delta_text = ""
|
||||
if envs.FD_ENABLE_RETURN_TEXT:
|
||||
delta_text, cum_tokens, _ = self.data_processor.ids2tokens(token_ids, req_id)
|
||||
if delta_text != "":
|
||||
prefix_offset = self.data_processor.decode_status[req_id][0]
|
||||
read_offset = self.data_processor.decode_status[req_id][1]
|
||||
token_ids = cum_tokens[prefix_offset:read_offset]
|
||||
else:
|
||||
token_ids = []
|
||||
if is_end:
|
||||
del self.data_processor.decode_status[req_id]
|
||||
return delta_text, token_ids
|
||||
|
||||
def _zmq_send_generated_tokens(self):
|
||||
"""
|
||||
Recieve output for zmq
|
||||
@@ -659,7 +701,24 @@ class EngineSevice:
|
||||
time.sleep(0.005)
|
||||
continue
|
||||
for request_id, contents in results.items():
|
||||
self.zmq_server.send_multipart(request_id, contents)
|
||||
new_contents = []
|
||||
for content in contents:
|
||||
delta_text, token_ids = self._decode_token(
|
||||
token_ids=content.outputs.token_ids, req_id=request_id, is_end=content.finished
|
||||
)
|
||||
|
||||
content.outputs.token_ids = token_ids
|
||||
content.outputs.text = delta_text
|
||||
new_contents.append(content)
|
||||
|
||||
if len(token_ids) == 0:
|
||||
llm_logger.warning(
|
||||
f"current tokens need to accumulate, req_id: {request_id} {content.outputs.token_ids}"
|
||||
)
|
||||
|
||||
if len(new_contents):
|
||||
llm_logger.info(f"Send response for request id: {request_id}")
|
||||
self.send_response_server.send_response(request_id, new_contents)
|
||||
|
||||
except Exception as e:
|
||||
llm_logger.error(f"Unexcepted error happend: {e}, {traceback.format_exc()!s}")
|
||||
@@ -748,6 +807,19 @@ class EngineSevice:
|
||||
def check_and_free_block_tables(self):
|
||||
self.resource_manager.check_and_free_block_tables()
|
||||
|
||||
def clear_data(self):
|
||||
try:
|
||||
llm_logger.info("Clear Data: Start")
|
||||
self.token_processor.clear_data()
|
||||
self.engine_worker_queue.clear_data()
|
||||
self.send_response_server.req_dict.clear()
|
||||
self.recv_request_server.req_dict.clear()
|
||||
llm_logger.info("Clear Data: Successfully")
|
||||
return True
|
||||
except Exception as e:
|
||||
llm_logger.error(f"Clear data error: {e}")
|
||||
return False
|
||||
|
||||
def _exit_sub_services(self):
|
||||
"""
|
||||
exit sub services
|
||||
|
@@ -37,7 +37,6 @@ from fastdeploy.engine.args_utils import EngineArgs
|
||||
from fastdeploy.engine.common_engine import EngineSevice
|
||||
from fastdeploy.engine.expert_service import start_data_parallel_service
|
||||
from fastdeploy.engine.request import Request
|
||||
from fastdeploy.input.preprocess import InputPreprocessor
|
||||
from fastdeploy.inter_communicator import EngineWorkerQueue, IPCSignal
|
||||
from fastdeploy.utils import EngineError, console_logger, envs, llm_logger
|
||||
|
||||
@@ -85,14 +84,6 @@ class LLMEngine:
|
||||
self.running = True
|
||||
self.is_started = False
|
||||
|
||||
self.input_processor = InputPreprocessor(
|
||||
cfg.tokenizer,
|
||||
cfg.reasoning_parser,
|
||||
cfg.limit_mm_per_prompt,
|
||||
cfg.mm_processor_kwargs,
|
||||
cfg.model_config.enable_mm,
|
||||
cfg.tool_parser,
|
||||
)
|
||||
self.engine = EngineSevice(cfg)
|
||||
|
||||
if self.cfg.cache_config.num_gpu_blocks_override is None:
|
||||
@@ -114,10 +105,9 @@ class LLMEngine:
|
||||
self.ipc_signal_suffix = self.cfg.engine_worker_queue_port[0]
|
||||
self._init_worker_signals()
|
||||
|
||||
self.data_processor = self.input_processor.create_processor()
|
||||
self.engine.data_processor = self.data_processor
|
||||
|
||||
self.engine.start()
|
||||
self.engine.create_data_processor()
|
||||
self.data_processor = self.engine.data_processor
|
||||
if api_server_pid is not None:
|
||||
llm_logger.info(f"Start zmq server, api_server_pid: {api_server_pid}")
|
||||
self.engine.start_zmq_service(api_server_pid)
|
||||
@@ -199,7 +189,7 @@ class LLMEngine:
|
||||
request.sampling_params = sampling_params
|
||||
request.preprocess_start_time = time.time()
|
||||
|
||||
request = self.data_processor.process_request(request, self.cfg.max_model_len, **kwargs)
|
||||
request = self.engine.data_processor.process_request(request, self.cfg.max_model_len, **kwargs)
|
||||
request.prompt_token_ids_len = len(request.prompt_token_ids)
|
||||
request.need_prefill_tokens = request.prompt_token_ids_len
|
||||
input_ids_len = request.prompt_token_ids_len
|
||||
@@ -210,9 +200,6 @@ class LLMEngine:
|
||||
request.get("max_tokens"),
|
||||
),
|
||||
)
|
||||
if request.get("reasoning_max_tokens") is None:
|
||||
default_reasoning_max_tokens = max(int(request.get("max_tokens") * 0.8), 1)
|
||||
request.set("reasoning_max_tokens", default_reasoning_max_tokens)
|
||||
min_tokens = request.get("min_tokens")
|
||||
if input_ids_len + min_tokens >= self.cfg.max_model_len:
|
||||
error_msg = (
|
||||
@@ -342,7 +329,8 @@ class LLMEngine:
|
||||
for p in self.cache_manager_processes:
|
||||
llm_logger.info(f"Killing cache manager process {p.pid}")
|
||||
try:
|
||||
os.killpg(p.pid, signal.SIGTERM)
|
||||
pgid = os.getpgid(p.pid)
|
||||
os.killpg(pgid, signal.SIGTERM)
|
||||
except Exception as e:
|
||||
console_logger.error(
|
||||
f"Error killing cache manager process {p.pid}: {e}, {str(traceback.format_exc())}"
|
||||
@@ -433,9 +421,9 @@ class LLMEngine:
|
||||
py_script = os.path.join(current_dir_path, worker_path)
|
||||
|
||||
ori_vocab_size = (
|
||||
len(self.data_processor.tokenizer.sp_model)
|
||||
if hasattr(self.data_processor.tokenizer, "sp_model")
|
||||
else len(self.data_processor.tokenizer.vocab)
|
||||
len(self.engine.data_processor.tokenizer.sp_model)
|
||||
if hasattr(self.engine.data_processor.tokenizer, "sp_model")
|
||||
else len(self.engine.data_processor.tokenizer.vocab)
|
||||
)
|
||||
|
||||
ports = ",".join(self.cfg.engine_worker_queue_port)
|
||||
@@ -454,8 +442,8 @@ class LLMEngine:
|
||||
f" --total_block_num {self.cfg.cache_config.total_block_num}"
|
||||
f" --block_size {self.cfg.cache_config.block_size}"
|
||||
f" --enc_dec_block_num {self.cfg.cache_config.enc_dec_block_num}"
|
||||
f" --eos_tokens_lens {self.data_processor.eos_token_id_len}"
|
||||
f" --pad_token_id {self.data_processor.pad_token_id}"
|
||||
f" --eos_tokens_lens {self.engine.data_processor.eos_token_id_len}"
|
||||
f" --pad_token_id {self.engine.data_processor.pad_token_id}"
|
||||
f" --engine_pid {self.cfg.engine_worker_queue_port[0]}"
|
||||
f" --max_num_batched_tokens {self.cfg.max_num_batched_tokens}"
|
||||
f" --splitwise_role {self.cfg.splitwise_role}"
|
||||
@@ -482,6 +470,7 @@ class LLMEngine:
|
||||
"dynamic_load_weight": self.cfg.load_config.dynamic_load_weight,
|
||||
"disable_any_whitespace": self.cfg.disable_any_whitespace,
|
||||
"disable_custom_all_reduce": self.cfg.parallel_config.disable_custom_all_reduce,
|
||||
"use_internode_ll_two_stage": self.cfg.parallel_config.use_internode_ll_two_stage,
|
||||
"enable_logprob": self.cfg.model_config.enable_logprob,
|
||||
"lm_head_fp32": self.cfg.model_config.lm_head_fp32,
|
||||
}
|
||||
@@ -546,7 +535,7 @@ class LLMEngine:
|
||||
for result in self._get_generated_tokens(req_id):
|
||||
is_end = result.finished
|
||||
if stream and not is_end:
|
||||
processed = self.data_processor.process_response(result)
|
||||
processed = self.engine.data_processor.process_response(result)
|
||||
if processed is None:
|
||||
continue
|
||||
output = processed.to_dict()
|
||||
@@ -554,7 +543,7 @@ class LLMEngine:
|
||||
|
||||
# Exit loop if termination condition is met
|
||||
if is_end:
|
||||
processed = self.data_processor.process_response(result)
|
||||
processed = self.engine.data_processor.process_response(result)
|
||||
output = processed.to_dict()
|
||||
llm_logger.debug(f"Generate result: {output}")
|
||||
if not stream:
|
||||
|
@@ -80,6 +80,9 @@ class ExpertService:
|
||||
|
||||
start_time = time.time()
|
||||
self.engine.start()
|
||||
|
||||
if envs.FD_ENABLE_RETURN_TEXT:
|
||||
self.engine.create_data_processor()
|
||||
if ipc_signal_suffix is not None:
|
||||
self.api_server_pid = ipc_signal_suffix
|
||||
self.engine.start_zmq_service(ipc_signal_suffix)
|
||||
|
@@ -159,8 +159,6 @@ class SamplingParams:
|
||||
def __post_init__(self):
|
||||
if self.seed is None:
|
||||
self.seed = random.randint(0, 922337203685477580)
|
||||
if self.max_tokens is not None and self.reasoning_max_tokens is None:
|
||||
self.reasoning_max_tokens = max(int(self.max_tokens * 0.8), 1)
|
||||
self._verify_args()
|
||||
|
||||
def _verify_args(self) -> None:
|
||||
|
@@ -84,11 +84,14 @@ class ResourceManagerV1(ResourceManager):
|
||||
return len(request.block_tables) * self.config.cache_config.block_size
|
||||
|
||||
def get_new_block_nums(self, request: Request, num_new_tokens: int):
|
||||
self.check_and_free_block_tables()
|
||||
return (
|
||||
block_num = (
|
||||
request.num_computed_tokens + num_new_tokens + self.config.cache_config.block_size - 1
|
||||
) // self.config.cache_config.block_size - len(request.block_tables)
|
||||
|
||||
if self.config.speculative_config.method is not None:
|
||||
block_num = min(block_num + 1, self.config.cache_config.max_block_num_per_seq)
|
||||
return block_num
|
||||
|
||||
def _prepare_prefill_task(self, request, new_token_num):
|
||||
request.prefill_start_index = request.num_computed_tokens
|
||||
request.prefill_end_index = request.num_computed_tokens + new_token_num
|
||||
@@ -119,10 +122,12 @@ class ResourceManagerV1(ResourceManager):
|
||||
preempted_req.status = RequestStatus.PREEMPTED
|
||||
preempted_req.num_computed_tokens = 0
|
||||
self._free_blocks(preempted_req)
|
||||
preempted_req.prefill_block_num = None
|
||||
preempted_req.cached_block_num = 0
|
||||
self.to_be_rescheduled_request_id_set.add(preempted_req.request_id)
|
||||
preempted_reqs.append(preempted_req)
|
||||
scheduled_reqs.append(self._prepare_preempt_task(preempted_req))
|
||||
main_process_metrics.num_requests_waiting.inc(1)
|
||||
main_process_metrics.num_requests_running.dec(1)
|
||||
if preempted_req == request:
|
||||
# No more request to preempt.
|
||||
can_schedule = False
|
||||
@@ -141,19 +146,36 @@ class ResourceManagerV1(ResourceManager):
|
||||
if not self.config.model_config.enable_mm:
|
||||
return num_new_tokens
|
||||
|
||||
request.with_image = False
|
||||
inputs = request.multimodal_inputs
|
||||
if inputs.get("patch_idx", None) is not None and inputs.get("patch_map", None) is not None:
|
||||
pre_end_idx = request.num_computed_tokens
|
||||
new_end_idx = pre_end_idx + num_new_tokens
|
||||
|
||||
prompt_token_ids_len = len(request.prompt_token_ids)
|
||||
assert prompt_token_ids_len == len(inputs["patch_idx"]), (prompt_token_ids_len, len(inputs["patch_idx"]))
|
||||
|
||||
# start
|
||||
start_patch_idx = inputs["patch_idx"][pre_end_idx]
|
||||
if pre_end_idx >= prompt_token_ids_len:
|
||||
start_patch_idx = inputs["patch_idx"][-1]
|
||||
else:
|
||||
start_patch_idx = inputs["patch_idx"][pre_end_idx]
|
||||
start_patch_map = inputs["patch_map"][start_patch_idx]
|
||||
request.image_start = start_patch_map["image_num"]
|
||||
request.video_start = start_patch_map["video_num"]
|
||||
request.audio_start = start_patch_map["audio_num"]
|
||||
|
||||
# end
|
||||
end_patch_idx = inputs["patch_idx"][new_end_idx]
|
||||
if new_end_idx >= prompt_token_ids_len:
|
||||
end_patch_idx = inputs["patch_idx"][-1]
|
||||
else:
|
||||
end_patch_idx = inputs["patch_idx"][new_end_idx]
|
||||
if request.prompt_token_ids[new_end_idx] in [
|
||||
inputs["image_end_id"],
|
||||
inputs["video_end_id"],
|
||||
inputs["audio_end_id"],
|
||||
]:
|
||||
end_patch_idx -= 1
|
||||
end_patch_map = inputs["patch_map"][end_patch_idx]
|
||||
end_modal_id = end_patch_map["modal_id"]
|
||||
if end_modal_id > 0:
|
||||
@@ -168,8 +190,6 @@ class ResourceManagerV1(ResourceManager):
|
||||
and inputs.get("image_patch_id", None) is not None
|
||||
and inputs.get("grid_thw", None) is not None
|
||||
):
|
||||
request.with_image = False
|
||||
|
||||
input_ids_lst = request.prompt_token_ids + request.output_token_ids
|
||||
input_ids = paddle.to_tensor(input_ids_lst, dtype="int64")
|
||||
input_ids = paddle.to_tensor(input_ids_lst, dtype="int64")
|
||||
@@ -265,14 +285,6 @@ class ResourceManagerV1(ResourceManager):
|
||||
if request.num_computed_tokens >= request.need_prefill_tokens: # to be decoding
|
||||
if request.num_total_tokens > request.need_prefill_tokens: # has generated tokens
|
||||
request.num_computed_tokens = request.num_total_tokens - 1
|
||||
else: # prefill finished
|
||||
if (
|
||||
self.config.cache_config.enable_prefix_caching
|
||||
and request.get("prefill_block_num", None) is None
|
||||
):
|
||||
# update prefill cache blocks for prefix caching
|
||||
request.prefill_block_num = len(request.block_tables)
|
||||
self.cache_manager.update_cache_blocks(request, self.config.cache_config.block_size)
|
||||
if (
|
||||
self.allocated_slots(request) - request.num_total_tokens
|
||||
<= self.config.cache_config.prealloc_dec_block_slot_num_threshold
|
||||
@@ -322,18 +334,33 @@ class ResourceManagerV1(ResourceManager):
|
||||
scheduled_reqs.append(self._prepare_prefill_task(request, num_new_tokens))
|
||||
token_budget -= num_new_tokens
|
||||
request.num_computed_tokens += num_new_tokens
|
||||
if self.config.cache_config.enable_prefix_caching:
|
||||
self.cache_manager.update_cache_blocks(
|
||||
request, self.config.cache_config.block_size, request.num_computed_tokens
|
||||
)
|
||||
req_index += 1
|
||||
# schedule the WAITING requests.
|
||||
if not preempted_reqs:
|
||||
while self.waiting and token_budget > 0:
|
||||
if len(self.running) == self.max_num_seqs:
|
||||
break
|
||||
if self.config.model_config.enable_mm and self.exist_prefill(scheduled_reqs):
|
||||
if (self.config.model_config.enable_mm or paddle.is_compiled_with_xpu()) and self.exist_prefill(
|
||||
scheduled_reqs
|
||||
):
|
||||
break
|
||||
request = self.waiting[0]
|
||||
if request.status == RequestStatus.WAITING:
|
||||
# Enable prefix caching
|
||||
if self.config.cache_config.enable_prefix_caching:
|
||||
if (
|
||||
self.config.cache_config.enable_hierarchical_cache
|
||||
and self.cache_manager.num_cpu_blocks > 0
|
||||
):
|
||||
if not self.cache_manager.can_allocate_gpu_blocks(
|
||||
(request.need_prefill_tokens + self.config.cache_config.block_size - 1)
|
||||
// self.config.cache_config.block_size
|
||||
): # to prevent block allocation for matching in hierarchical cache and cause dead lock
|
||||
break
|
||||
success = self.get_prefix_cached_blocks(request)
|
||||
if not success:
|
||||
self._free_blocks(request)
|
||||
@@ -352,7 +379,13 @@ class ResourceManagerV1(ResourceManager):
|
||||
request.schedule_start_time = time.time()
|
||||
token_budget -= num_new_tokens
|
||||
request.num_computed_tokens += num_new_tokens
|
||||
if self.config.cache_config.enable_prefix_caching:
|
||||
self.cache_manager.update_cache_blocks(
|
||||
request, self.config.cache_config.block_size, request.num_computed_tokens
|
||||
)
|
||||
request.status = RequestStatus.RUNNING
|
||||
main_process_metrics.num_requests_waiting.dec(1)
|
||||
main_process_metrics.num_requests_running.inc(1)
|
||||
allocated_position = self.get_available_position()
|
||||
request.idx = allocated_position
|
||||
self.tasks_list[allocated_position] = request
|
||||
@@ -367,6 +400,15 @@ class ResourceManagerV1(ResourceManager):
|
||||
request.num_total_tokens
|
||||
) # Before preempted task rescheduled, preempted task has been sent to engine, no more tokens are output, here num_total_tokens should be static and correct
|
||||
if self.config.cache_config.enable_prefix_caching:
|
||||
if (
|
||||
self.config.cache_config.enable_hierarchical_cache
|
||||
and self.cache_manager.num_cpu_blocks > 0
|
||||
):
|
||||
if not self.cache_manager.can_allocate_gpu_blocks(
|
||||
(request.need_prefill_tokens + self.config.cache_config.block_size - 1)
|
||||
// self.config.cache_config.block_size
|
||||
): # to prevent block allocation for matching in hierarchical cache and cause dead lock
|
||||
break
|
||||
success = self.get_prefix_cached_blocks(request)
|
||||
if not success:
|
||||
self._free_blocks(request)
|
||||
@@ -382,7 +424,13 @@ class ResourceManagerV1(ResourceManager):
|
||||
scheduled_reqs.append(self._prepare_prefill_task(request, num_new_tokens))
|
||||
token_budget -= num_new_tokens
|
||||
request.num_computed_tokens += num_new_tokens
|
||||
if self.config.cache_config.enable_prefix_caching:
|
||||
self.cache_manager.update_cache_blocks(
|
||||
request, self.config.cache_config.block_size, request.num_computed_tokens
|
||||
)
|
||||
request.status = RequestStatus.RUNNING
|
||||
main_process_metrics.num_requests_waiting.dec(1)
|
||||
main_process_metrics.num_requests_running.inc(1)
|
||||
else:
|
||||
if self.config.cache_config.enable_prefix_caching:
|
||||
self._free_blocks(request)
|
||||
@@ -424,7 +472,7 @@ class ResourceManagerV1(ResourceManager):
|
||||
|
||||
matched_block_num = len(common_block_ids)
|
||||
no_cache_block_num = self.cache_manager.get_required_block_num(
|
||||
request.prompt_token_ids_len - matched_token_num,
|
||||
request.need_prefill_tokens - matched_token_num,
|
||||
self.config.cache_config.block_size,
|
||||
)
|
||||
|
||||
@@ -440,7 +488,7 @@ class ResourceManagerV1(ResourceManager):
|
||||
main_process_metrics.prefix_gpu_cache_token_num.inc(request.gpu_cache_token_num)
|
||||
main_process_metrics.prefix_cpu_cache_token_num.inc(request.cpu_cache_token_num)
|
||||
|
||||
if matched_token_num == request.prompt_token_ids_len:
|
||||
if matched_token_num == request.need_prefill_tokens:
|
||||
request.num_computed_tokens = matched_token_num - self.config.cache_config.block_size
|
||||
request.skip_allocate = True
|
||||
else:
|
||||
@@ -458,16 +506,8 @@ class ResourceManagerV1(ResourceManager):
|
||||
|
||||
def _free_blocks(self, request: Request):
|
||||
if self.config.cache_config.enable_prefix_caching:
|
||||
# TODO(chengyanfu): support cache ouput blocks for prefix caching
|
||||
if request.get("prefill_block_num", None) is None:
|
||||
leaf_node = self.cache_manager.req_leaf_map[request.request_id]
|
||||
self.cache_manager.decrease_request_share_count(request.request_id)
|
||||
self.cache_manager.free_nodes_directly(leaf_node)
|
||||
self.cache_manager.recycle_gpu_blocks(request.block_tables[request.cache_info[0] :])
|
||||
|
||||
else:
|
||||
self.cache_manager.release_block_ids_async(request)
|
||||
self.cache_manager.recycle_gpu_blocks(request.block_tables[request.prefill_block_num :])
|
||||
self.cache_manager.release_block_ids(request)
|
||||
self.cache_manager.recycle_gpu_blocks(request.block_tables[request.cached_block_num :])
|
||||
else:
|
||||
self.cache_manager.recycle_gpu_blocks(request.block_tables)
|
||||
request.block_tables = []
|
||||
@@ -508,3 +548,7 @@ class ResourceManagerV1(ResourceManager):
|
||||
del self.requests[req_id]
|
||||
except Exception as e:
|
||||
llm_logger.error(f"finish_request err: {e}, {str(traceback.format_exc())}")
|
||||
|
||||
def clear_data(self):
|
||||
self.waiting: deque[Request] = deque()
|
||||
self.to_be_rescheduled_request_id_set = set()
|
||||
|
@@ -14,6 +14,7 @@
|
||||
# limitations under the License.
|
||||
"""
|
||||
|
||||
import inspect
|
||||
import os
|
||||
import time
|
||||
import traceback
|
||||
@@ -26,7 +27,7 @@ from fastdeploy.config import ModelConfig
|
||||
from fastdeploy.entrypoints.openai.utils import DealerConnectionManager
|
||||
from fastdeploy.envs import FD_SUPPORT_MAX_CONNECTIONS
|
||||
from fastdeploy.input.preprocess import InputPreprocessor
|
||||
from fastdeploy.inter_communicator import IPCSignal, ZmqClient
|
||||
from fastdeploy.inter_communicator import IPCSignal, ZmqIpcClient
|
||||
from fastdeploy.metrics.work_metrics import work_process_metrics
|
||||
from fastdeploy.multimodal.registry import MultimodalRegistry
|
||||
from fastdeploy.platforms import current_platform
|
||||
@@ -109,10 +110,10 @@ class EngineClient:
|
||||
"""
|
||||
Create a ZMQ client.
|
||||
"""
|
||||
self.zmq_client = ZmqClient(model, mode)
|
||||
self.zmq_client = ZmqIpcClient(model, mode)
|
||||
self.zmq_client.connect()
|
||||
|
||||
def format_and_add_data(self, prompts: dict):
|
||||
async def format_and_add_data(self, prompts: dict):
|
||||
"""
|
||||
Format the request data and send the request to the server.
|
||||
"""
|
||||
@@ -123,10 +124,10 @@ class EngineClient:
|
||||
if "max_tokens" not in prompts:
|
||||
prompts["max_tokens"] = self.max_model_len - 1
|
||||
|
||||
self.add_requests(prompts)
|
||||
await self.add_requests(prompts)
|
||||
return prompts["prompt_token_ids"]
|
||||
|
||||
def add_requests(self, task):
|
||||
async def add_requests(self, task):
|
||||
"""
|
||||
Add a new request to the queue.
|
||||
|
||||
@@ -140,13 +141,14 @@ class EngineClient:
|
||||
|
||||
task["preprocess_start_time"] = time.time()
|
||||
try:
|
||||
self.data_processor.process_request_dict(task, self.max_model_len)
|
||||
if inspect.iscoroutinefunction(self.data_processor.process_request_dict):
|
||||
await self.data_processor.process_request_dict(task, self.max_model_len)
|
||||
else:
|
||||
self.data_processor.process_request_dict(task, self.max_model_len)
|
||||
|
||||
task["prompt_token_ids_len"] = len(task["prompt_token_ids"])
|
||||
input_ids_len = task["prompt_token_ids_len"]
|
||||
task["max_tokens"] = min(self.max_model_len - input_ids_len, task.get("max_tokens"))
|
||||
if task.get("reasoning_max_tokens", None) is None:
|
||||
task["reasoning_max_tokens"] = max(int(task["max_tokens"] * 0.8), 1)
|
||||
min_tokens = task.get("min_tokens", 1)
|
||||
if "messages" in task:
|
||||
del task["messages"]
|
||||
@@ -225,8 +227,8 @@ class EngineClient:
|
||||
raise ValueError(f"max_tokens can be defined [1, {self.max_model_len}).")
|
||||
|
||||
if data.get("reasoning_max_tokens") is not None:
|
||||
if data["reasoning_max_tokens"] > data["max_tokens"] or data["reasoning_max_tokens"] < 1:
|
||||
raise ValueError("reasoning_max_tokens must be between max_tokens and 1")
|
||||
if data["reasoning_max_tokens"] > data["max_tokens"] or data["reasoning_max_tokens"] < 0:
|
||||
raise ValueError("reasoning_max_tokens must be between max_tokens and 0")
|
||||
|
||||
if data.get("top_p") is not None:
|
||||
if data["top_p"] > 1 or data["top_p"] < 0:
|
||||
@@ -355,3 +357,6 @@ class EngineClient:
|
||||
return False, "clear model weight timeout"
|
||||
time.sleep(1)
|
||||
return True, ""
|
||||
|
||||
def check_model_weight_status(self):
|
||||
return self.model_weights_status_signal.value[0] < 0
|
||||
|
@@ -171,6 +171,7 @@ async def lifespan(app: FastAPI):
|
||||
workers=args.workers,
|
||||
tool_parser=args.tool_call_parser,
|
||||
)
|
||||
await engine_client.connection_manager.initialize()
|
||||
app.state.dynamic_load_weight = args.dynamic_load_weight
|
||||
model_handler = OpenAIServingModels(
|
||||
model_paths,
|
||||
@@ -477,7 +478,8 @@ def reset_scheduler():
|
||||
|
||||
if llm_engine is None:
|
||||
return Response("Engine not loaded", status_code=500)
|
||||
llm_engine.scheduler.reset()
|
||||
llm_engine.engine.scheduler.reset()
|
||||
llm_engine.engine.clear_data()
|
||||
return Response("Scheduler Reset Successfully", status_code=200)
|
||||
|
||||
|
||||
@@ -495,11 +497,14 @@ def control_scheduler(request: ControlSchedulerRequest):
|
||||
return JSONResponse(content=content.model_dump(), status_code=500)
|
||||
|
||||
if request.reset:
|
||||
llm_engine.scheduler.reset()
|
||||
llm_engine.engine.clear_data()
|
||||
llm_engine.engine.scheduler.reset()
|
||||
|
||||
if request.load_shards_num or request.reallocate_shard:
|
||||
if hasattr(llm_engine.scheduler, "update_config") and callable(llm_engine.scheduler.update_config):
|
||||
llm_engine.scheduler.update_config(
|
||||
if hasattr(llm_engine.engine.scheduler, "update_config") and callable(
|
||||
llm_engine.engine.scheduler.update_config
|
||||
):
|
||||
llm_engine.engine.scheduler.update_config(
|
||||
load_shards_num=request.load_shards_num,
|
||||
reallocate=request.reallocate_shard,
|
||||
)
|
||||
|
@@ -25,23 +25,33 @@ from fastdeploy.utils import get_logger, is_port_available
|
||||
logger = get_logger("multi_api_server", "multi_api_server.log")
|
||||
|
||||
|
||||
def start_servers(server_count, server_args, ports, metrics_ports):
|
||||
def start_servers(server_count, server_args, ports, metrics_ports, controller_ports):
|
||||
processes = []
|
||||
logger.info(f"Starting servers on ports: {ports} with args: {server_args} and metrics ports: {metrics_ports}")
|
||||
for i in range(len(server_args)):
|
||||
if server_args[i] == "--engine-worker-queue-port":
|
||||
engine_worker_queue_port = server_args[i + 1].split(",")
|
||||
break
|
||||
check_param(ports, server_count)
|
||||
check_param(metrics_ports, server_count)
|
||||
check_param(engine_worker_queue_port, server_count)
|
||||
if not check_param(ports, server_count):
|
||||
return
|
||||
if not check_param(metrics_ports, server_count):
|
||||
return
|
||||
if not check_param(engine_worker_queue_port, server_count):
|
||||
return
|
||||
if controller_ports != "-1":
|
||||
controller_ports = controller_ports.split(",")
|
||||
if not check_param(controller_ports, server_count):
|
||||
return
|
||||
else:
|
||||
controller_ports = [-1] * server_count
|
||||
# check_param(server_args, server_count)
|
||||
for i in range(server_count):
|
||||
port = int(ports[i])
|
||||
metrics_port = int(metrics_ports[i])
|
||||
controller_port = int(controller_ports[i])
|
||||
|
||||
env = os.environ.copy()
|
||||
env["FD_LOG_DIR"] = f"log_{i}"
|
||||
env["FD_LOG_DIR"] = env.get("FD_LOG_DIR", "log") + f"/log_{i}"
|
||||
cmd = [
|
||||
sys.executable,
|
||||
"-m",
|
||||
@@ -51,6 +61,8 @@ def start_servers(server_count, server_args, ports, metrics_ports):
|
||||
str(port),
|
||||
"--metrics-port",
|
||||
str(metrics_port),
|
||||
"--controller-port",
|
||||
str(controller_port),
|
||||
"--local-data-parallel-id",
|
||||
str(i),
|
||||
]
|
||||
@@ -69,7 +81,8 @@ def check_param(ports, num_servers):
|
||||
for port in ports:
|
||||
logger.info(f"check port {port}")
|
||||
if not is_port_available("0.0.0.0", int(port)):
|
||||
raise ValueError(f"Port {port} is already in use.")
|
||||
return False
|
||||
return True
|
||||
|
||||
|
||||
def main():
|
||||
@@ -77,6 +90,7 @@ def main():
|
||||
parser.add_argument("--ports", default="8000,8002", type=str, help="ports to the http server")
|
||||
parser.add_argument("--num-servers", default=2, type=int, help="number of workers")
|
||||
parser.add_argument("--metrics-ports", default="8800,8802", type=str, help="ports for metrics server")
|
||||
parser.add_argument("--controller-ports", default="-1", type=str, help="ports for controller server port")
|
||||
parser.add_argument("--args", nargs=argparse.REMAINDER, help="remaining arguments are passed to api_server.py")
|
||||
args = parser.parse_args()
|
||||
|
||||
@@ -90,6 +104,7 @@ def main():
|
||||
server_args=args.args,
|
||||
ports=args.ports.split(","),
|
||||
metrics_ports=args.metrics_ports.split(","),
|
||||
controller_ports=args.controller_ports,
|
||||
)
|
||||
|
||||
try:
|
||||
|
@@ -74,12 +74,6 @@ class OpenAIServingChat:
|
||||
self.master_ip = "0.0.0.0"
|
||||
api_server_logger.info(f"master ip: {self.master_ip}")
|
||||
|
||||
async def _ensure_connection_manager(self):
|
||||
"""ensure connection manager initialized"""
|
||||
if not self.engine_client.connection_initialized:
|
||||
await self.engine_client.connection_manager.initialize()
|
||||
self.engine_client.connection_initialized = True
|
||||
|
||||
def _check_master(self):
|
||||
return self.engine_client.is_master
|
||||
|
||||
@@ -119,7 +113,7 @@ class OpenAIServingChat:
|
||||
if "chat_template" not in current_req_dict:
|
||||
current_req_dict["chat_template"] = self.chat_template
|
||||
current_req_dict["arrival_time"] = time.time()
|
||||
prompt_token_ids = self.engine_client.format_and_add_data(current_req_dict)
|
||||
prompt_token_ids = await self.engine_client.format_and_add_data(current_req_dict)
|
||||
text_after_process = current_req_dict.get("text_after_process")
|
||||
if isinstance(prompt_token_ids, np.ndarray):
|
||||
prompt_token_ids = prompt_token_ids.tolist()
|
||||
@@ -182,7 +176,7 @@ class OpenAIServingChat:
|
||||
if request.max_streaming_response_tokens is not None
|
||||
else (request.metadata or {}).get("max_streaming_response_tokens", 1)
|
||||
) # dierctly passed & passed in metadata
|
||||
|
||||
max_streaming_response_tokens = max(max_streaming_response_tokens, 1)
|
||||
enable_thinking = request.chat_template_kwargs.get("enable_thinking") if request.chat_template_kwargs else None
|
||||
if enable_thinking is None:
|
||||
enable_thinking = request.metadata.get("enable_thinking") if request.metadata else None
|
||||
@@ -206,7 +200,6 @@ class OpenAIServingChat:
|
||||
api_server_logger.info(f"create chat completion request: {request_id}")
|
||||
|
||||
try:
|
||||
await self._ensure_connection_manager()
|
||||
dealer, response_queue = await self.engine_client.connection_manager.get_connection(request_id)
|
||||
dealer.write([b"", request_id.encode("utf-8")])
|
||||
choices = []
|
||||
@@ -217,6 +210,8 @@ class OpenAIServingChat:
|
||||
decoder_base_url=self.tokenizer_base_url,
|
||||
)
|
||||
while num_choices > 0:
|
||||
if self.engine_client.check_model_weight_status():
|
||||
raise ValueError("Engine is clearing model weight")
|
||||
try:
|
||||
response = await asyncio.wait_for(response_queue.get(), timeout=10)
|
||||
current_waiting_time = 0
|
||||
@@ -323,7 +318,9 @@ class OpenAIServingChat:
|
||||
continue
|
||||
delta_message.content = delta_message_output.content or ""
|
||||
delta_message.reasoning_content = delta_message_output.reasoning_content or ""
|
||||
delta_message.tool_calls = delta_message_output.tool_calls
|
||||
if delta_message_output.tool_calls:
|
||||
delta_message.tool_calls = delta_message_output.tool_calls
|
||||
tool_called = True
|
||||
|
||||
choice = ChatCompletionResponseStreamChoice(
|
||||
index=0,
|
||||
@@ -370,11 +367,6 @@ class OpenAIServingChat:
|
||||
api_server_logger.info(f"Chat Streaming response last send: {chunk.model_dump_json()}")
|
||||
choices = []
|
||||
|
||||
if choices:
|
||||
chunk.choices = choices
|
||||
yield f"data: {chunk.model_dump_json(exclude_unset=True)}\n\n"
|
||||
choices = []
|
||||
|
||||
if include_usage:
|
||||
completion_tokens = previous_num_tokens
|
||||
usage = UsageInfo(
|
||||
@@ -422,7 +414,6 @@ class OpenAIServingChat:
|
||||
|
||||
include_stop_str_in_output = request.include_stop_str_in_output
|
||||
try:
|
||||
await self._ensure_connection_manager()
|
||||
dealer, response_queue = await self.engine_client.connection_manager.get_connection(request_id)
|
||||
dealer.write([b"", request_id.encode("utf-8")])
|
||||
final_res = None
|
||||
@@ -436,6 +427,8 @@ class OpenAIServingChat:
|
||||
decoder_base_url=self.tokenizer_base_url,
|
||||
)
|
||||
while True:
|
||||
if self.engine_client.check_model_weight_status():
|
||||
return ErrorResponse(code=400, message="Model weight cleared")
|
||||
try:
|
||||
response = await asyncio.wait_for(response_queue.get(), timeout=10)
|
||||
current_waiting_time = 0
|
||||
@@ -524,6 +517,7 @@ class OpenAIServingChat:
|
||||
|
||||
if final_res.get("error_msg") is not None and "Recover" in final_res["error_msg"]:
|
||||
choice.finish_reason = "recover_stop"
|
||||
|
||||
choices.append(choice)
|
||||
|
||||
num_prompt_tokens = len(prompt_token_ids)
|
||||
|
@@ -51,12 +51,6 @@ class OpenAIServingCompletion:
|
||||
else:
|
||||
self.master_ip = "0.0.0.0"
|
||||
|
||||
async def _ensure_connection_manager(self):
|
||||
"""ensure connection manager initialized"""
|
||||
if not self.engine_client.connection_initialized:
|
||||
await self.engine_client.connection_manager.initialize()
|
||||
self.engine_client.connection_initialized = True
|
||||
|
||||
def _check_master(self):
|
||||
return self.engine_client.is_master
|
||||
|
||||
@@ -146,7 +140,7 @@ class OpenAIServingCompletion:
|
||||
request_id_idx = f"{request_id}-{idx}"
|
||||
current_req_dict = request.to_dict_for_infer(request_id_idx, prompt)
|
||||
current_req_dict["arrival_time"] = time.time()
|
||||
prompt_token_ids = self.engine_client.format_and_add_data(current_req_dict) # tokenize
|
||||
prompt_token_ids = await self.engine_client.format_and_add_data(current_req_dict) # tokenize
|
||||
if isinstance(prompt_token_ids, np.ndarray):
|
||||
prompt_token_ids = prompt_token_ids.tolist()
|
||||
text_after_process_list.append(current_req_dict.get("text_after_process"))
|
||||
@@ -208,7 +202,6 @@ class OpenAIServingCompletion:
|
||||
try:
|
||||
request_ids = [f"{request_id}-{i}" for i in range(num_choices)]
|
||||
# create dealer
|
||||
await self._ensure_connection_manager()
|
||||
dealer, response_queue = await self.engine_client.connection_manager.get_connection(
|
||||
request_id, num_choices
|
||||
)
|
||||
@@ -223,6 +216,8 @@ class OpenAIServingCompletion:
|
||||
completion_batched_token_ids = [[] for _ in range(num_choices)]
|
||||
current_waiting_time = 0
|
||||
while num_choices > 0:
|
||||
if self.engine_client.check_model_weight_status():
|
||||
return ErrorResponse(message="Model weight cleared", code=400)
|
||||
try:
|
||||
response = await asyncio.wait_for(response_queue.get(), timeout=10)
|
||||
current_waiting_time = 0
|
||||
@@ -277,7 +272,6 @@ class OpenAIServingCompletion:
|
||||
return res
|
||||
except Exception as e:
|
||||
api_server_logger.error(f"Error in completion_full_generator: {e}", exc_info=True)
|
||||
raise
|
||||
finally:
|
||||
self.engine_client.semaphore.release()
|
||||
if dealer is not None:
|
||||
@@ -314,7 +308,6 @@ class OpenAIServingCompletion:
|
||||
Process the stream completion request.
|
||||
"""
|
||||
try:
|
||||
await self._ensure_connection_manager()
|
||||
dealer, response_queue = await self.engine_client.connection_manager.get_connection(
|
||||
request_id, num_choices
|
||||
)
|
||||
@@ -331,6 +324,7 @@ class OpenAIServingCompletion:
|
||||
if request.max_streaming_response_tokens is not None
|
||||
else (request.suffix or {}).get("max_streaming_response_tokens", 1)
|
||||
) # dierctly passed & passed in suffix
|
||||
max_streaming_response_tokens = max(max_streaming_response_tokens, 1)
|
||||
choices = []
|
||||
chunk = CompletionStreamResponse(
|
||||
id=request_id,
|
||||
@@ -340,6 +334,8 @@ class OpenAIServingCompletion:
|
||||
)
|
||||
current_waiting_time = 0
|
||||
while num_choices > 0:
|
||||
if self.engine_client.check_model_weight_status():
|
||||
raise ValueError("Engine is clearing model weight")
|
||||
try:
|
||||
response = await asyncio.wait_for(response_queue.get(), timeout=10)
|
||||
current_waiting_time = 0
|
||||
@@ -417,7 +413,9 @@ class OpenAIServingCompletion:
|
||||
continue
|
||||
delta_message.text = delta_message_output.content or ""
|
||||
delta_message.reasoning_content = delta_message_output.reasoning_content or ""
|
||||
delta_message.tool_calls = delta_message_output.tool_calls
|
||||
if delta_message_output.tool_calls:
|
||||
delta_message.tool_calls = delta_message_output.tool_calls
|
||||
tool_called[idx] = True
|
||||
|
||||
choices.append(delta_message)
|
||||
|
||||
@@ -461,10 +459,6 @@ class OpenAIServingCompletion:
|
||||
)
|
||||
yield f"data: {usage_chunk.model_dump_json(exclude_unset=True)}\n\n"
|
||||
api_server_logger.info(f"Completion Streaming response last send: {chunk.model_dump_json()}")
|
||||
if choices:
|
||||
chunk.choices = choices
|
||||
yield f"data: {chunk.model_dump_json(exclude_unset=True)}\n\n"
|
||||
choices = []
|
||||
|
||||
except Exception as e:
|
||||
api_server_logger.error(f"Error in completion_stream_generator: {e}, {str(traceback.format_exc())}")
|
||||
|
@@ -81,7 +81,7 @@ environment_variables: dict[str, Callable[[], Any]] = {
|
||||
# set traec exporter_otlp_headers.
|
||||
"EXPORTER_OTLP_HEADERS": lambda: os.getenv("EXPORTER_OTLP_HEADERS"),
|
||||
# enable kv cache block scheduler v1 (no need for kv_cache_ratio)
|
||||
"ENABLE_V1_KVCACHE_SCHEDULER": lambda: int(os.getenv("ENABLE_V1_KVCACHE_SCHEDULER", "0")),
|
||||
"ENABLE_V1_KVCACHE_SCHEDULER": lambda: int(os.getenv("ENABLE_V1_KVCACHE_SCHEDULER", "1")),
|
||||
# Whether to use PLUGINS.
|
||||
"FD_PLUGINS": lambda: None if "FD_PLUGINS" not in os.environ else os.environ["FD_PLUGINS"].split(","),
|
||||
# set trace attribute job_id.
|
||||
@@ -93,8 +93,24 @@ environment_variables: dict[str, Callable[[], Any]] = {
|
||||
# enable multi api server
|
||||
"FD_ENABLE_MULTI_API_SERVER": lambda: bool(int(os.getenv("FD_ENABLE_MULTI_API_SERVER", "0"))),
|
||||
"FD_FOR_TORCH_MODEL_FORMAT": lambda: bool(int(os.getenv("FD_FOR_TORCH_MODEL_FORMAT", "0"))),
|
||||
# force enable chunked prefill
|
||||
"FD_FORCE_CHUNKED_PREFILL": lambda: bool(int(os.getenv("FD_FORCE_CHUNKED_PREFILL", "0"))),
|
||||
# force disable default chunked prefill
|
||||
"FD_DISABLE_CHUNKED_PREFILL": lambda: bool(int(os.getenv("FD_DISABLE_CHUNKED_PREFILL", "0"))),
|
||||
# For separate setting of sampling parameters for speculative decoding
|
||||
"FD_SPECULATE_SAMPLING_TOP_P": lambda: (
|
||||
None if "FD_SPECULATE_SAMPLING_TOP_P" not in os.environ else float(os.environ["FD_SPECULATE_SAMPLING_TOP_P"])
|
||||
),
|
||||
"FD_SPECULATE_SAMPLING_TOP_K": lambda: (
|
||||
None if "FD_SPECULATE_SAMPLING_TOP_K" not in os.environ else float(os.environ["FD_SPECULATE_SAMPLING_TOP_K"])
|
||||
),
|
||||
"FD_ENABLE_INTERNAL_ADAPTER": lambda: int(os.getenv("FD_ENABLE_INTERNAL_ADAPTER", "0")),
|
||||
# LLMEngine recieve requests port, used when FD_ENABLE_INTERNAL_ADAPTER=1
|
||||
"FD_ZMQ_RECV_REQUEST_SERVER_PORT": lambda: os.getenv("FD_ZMQ_RECV_REQUEST_SERVER_PORT", "8200"),
|
||||
# LLMEngine send response port, used when FD_ENABLE_INTERNAL_ADAPTER=1
|
||||
"FD_ZMQ_SEND_RESPONSE_SERVER_PORT": lambda: os.getenv("FD_ZMQ_SEND_RESPONSE_SERVER_PORT", "8201"),
|
||||
# LLMEngine recieve control command port, used when FD_ENABLE_INTERNAL_ADAPTER=1
|
||||
"FD_ZMQ_CONTROL_CMD_SERVER_PORTS": lambda: os.getenv("FD_ZMQ_CONTROL_CMD_SERVER_PORTS", "8202"),
|
||||
# enable return text, used when FD_ENABLE_INTERNAL_ADAPTER=1
|
||||
"FD_ENABLE_RETURN_TEXT": lambda: bool(int(os.getenv("FD_ENABLE_RETURN_TEXT", "0"))),
|
||||
}
|
||||
|
||||
|
||||
@@ -105,5 +121,10 @@ def __getattr__(name: str):
|
||||
raise AttributeError(f"module {__name__!r} has no attribute {name!r}")
|
||||
|
||||
|
||||
def __setattr__(name: str, value: Any):
|
||||
assert name in environment_variables
|
||||
environment_variables[name] = lambda: value
|
||||
|
||||
|
||||
def __dir__():
|
||||
return list(environment_variables.keys())
|
||||
|
@@ -255,6 +255,10 @@ class Ernie4_5_VLProcessor(Ernie4_5Processor):
|
||||
request["prompt_token_ids"] = request["prompt_token_ids"][: max_model_len - 1]
|
||||
if request.get("max_tokens") is None:
|
||||
request["max_tokens"] = max(1, max_model_len - len(request["prompt_token_ids"]))
|
||||
else:
|
||||
request["max_tokens"] = min(max_model_len - len(request["prompt_token_ids"]), request["max_tokens"])
|
||||
if request.get("reasoning_max_tokens") is None:
|
||||
request["reasoning_max_tokens"] = max(int(request["max_tokens"] * 0.8), 1)
|
||||
data_processor_logger.info(f"Processed request {request}")
|
||||
|
||||
return request
|
||||
|
@@ -71,6 +71,7 @@ class InputPreprocessor:
|
||||
"""
|
||||
reasoning_parser_obj = None
|
||||
tool_parser_obj = None
|
||||
|
||||
if self.reasoning_parser:
|
||||
reasoning_parser_obj = ReasoningParserManager.get_reasoning_parser(self.reasoning_parser)
|
||||
if self.tool_parser:
|
||||
@@ -85,6 +86,8 @@ class InputPreprocessor:
|
||||
Processor = load_input_processor_plugins()
|
||||
self.processor = Processor(
|
||||
model_name_or_path=self.model_name_or_path,
|
||||
reasoning_parser_obj=reasoning_parser_obj,
|
||||
tool_parser_obj=tool_parser_obj,
|
||||
)
|
||||
except:
|
||||
if not self.enable_mm:
|
||||
|
@@ -69,7 +69,7 @@ class QwenVLProcessor(TextProcessor):
|
||||
tokenizer=self.tokenizer,
|
||||
**processor_kwargs,
|
||||
)
|
||||
|
||||
self.image_patch_id = self.processor.image_token_id
|
||||
self.limit_mm_per_prompt = self._parse_limits(limit_mm_per_prompt)
|
||||
|
||||
def process_request(self, request, max_model_len=None, **kwargs):
|
||||
@@ -231,6 +231,15 @@ class QwenVLProcessor(TextProcessor):
|
||||
elif request.get("messages"):
|
||||
messages = request["messages"]
|
||||
self._check_mm_limits(messages)
|
||||
chat_template_kwargs = request.get("chat_template_kwargs")
|
||||
if chat_template_kwargs:
|
||||
if isinstance(chat_template_kwargs, dict):
|
||||
for k, v in chat_template_kwargs.items():
|
||||
if k not in request:
|
||||
request[k] = v
|
||||
else:
|
||||
raise ValueError("Invalid input: chat_template_kwargs must be a dict")
|
||||
request.setdefault("enable_thinking", True)
|
||||
outputs = self.processor.request2ids(request)
|
||||
|
||||
else:
|
||||
@@ -240,6 +249,16 @@ class QwenVLProcessor(TextProcessor):
|
||||
# Handle continuation of previous generation by appending existing tokens
|
||||
if metadata and metadata.get("generated_token_ids"):
|
||||
self.append_generated_tokens(outputs, metadata["generated_token_ids"])
|
||||
|
||||
enable_thinking = False
|
||||
if metadata:
|
||||
enable_thinking = metadata.get("enable_thinking", False)
|
||||
|
||||
if request.get("chat_template_kwargs"):
|
||||
chat_template_kwargs = request.get("chat_template_kwargs")
|
||||
enable_thinking = chat_template_kwargs.get("enable_thinking", False)
|
||||
request["enable_thinking"] = enable_thinking
|
||||
|
||||
outputs = self.pack_outputs(outputs)
|
||||
|
||||
request["prompt_token_ids"] = outputs["input_ids"].tolist()
|
||||
|
@@ -17,6 +17,15 @@
|
||||
from .engine_cache_queue import EngineCacheQueue
|
||||
from .engine_worker_queue import EngineWorkerQueue
|
||||
from .ipc_signal import IPCSignal, shared_memory_exists
|
||||
from .zmq_client import ZmqClient
|
||||
from .zmq_client import ZmqIpcClient
|
||||
from .zmq_server import ZmqIpcServer, ZmqTcpServer
|
||||
|
||||
__all__ = ["ZmqClient", "IPCSignal", "EngineWorkerQueue", "EngineCacheQueue", "shared_memory_exists"]
|
||||
__all__ = [
|
||||
"ZmqIpcClient",
|
||||
"IPCSignal",
|
||||
"EngineWorkerQueue",
|
||||
"EngineCacheQueue",
|
||||
"ZmqTcpServer",
|
||||
"ZmqIpcServer",
|
||||
"shared_memory_exists",
|
||||
]
|
||||
|
@@ -392,6 +392,13 @@ class EngineWorkerQueue:
|
||||
llm_logger.debug("get tasks from queue success")
|
||||
return item
|
||||
|
||||
def clear_data(self):
|
||||
self.lock.acquire()
|
||||
self.tasks[:] = list()
|
||||
self.client_read_flag[:] = [1] * self.num_client
|
||||
self.lock.release()
|
||||
llm_logger.info("clear data for engine worker queue")
|
||||
|
||||
def cleanup(self):
|
||||
"""
|
||||
Exit the worker queue gracefully.
|
||||
|
@@ -14,209 +14,78 @@
|
||||
# limitations under the License.
|
||||
"""
|
||||
|
||||
import os
|
||||
import threading
|
||||
import time
|
||||
import traceback
|
||||
from abc import ABC, abstractmethod
|
||||
|
||||
import msgpack
|
||||
import zmq
|
||||
|
||||
from fastdeploy import envs
|
||||
from fastdeploy.utils import zmq_client_logger
|
||||
|
||||
|
||||
class ZmqClient:
|
||||
class ZmqClientBase(ABC):
|
||||
"""
|
||||
ZmqClient is a class that provides a client-side interface for sending and receiving messages using ZeroMQ.
|
||||
ZmqClientBase is a base class that provides a client-side interface for sending and receiving messages using ZeroMQ.
|
||||
"""
|
||||
|
||||
def __init__(self, name, mode):
|
||||
self.context = zmq.Context(4)
|
||||
self.socket = self.context.socket(mode)
|
||||
self.file_name = f"/dev/shm/{name}.socket"
|
||||
self.router_path = f"/dev/shm/router_{name}.ipc"
|
||||
def __init__(self):
|
||||
pass
|
||||
|
||||
self.ZMQ_SNDHWM = int(envs.FD_ZMQ_SNDHWM)
|
||||
self.aggregate_send = envs.FD_USE_AGGREGATE_SEND
|
||||
@abstractmethod
|
||||
def _create_socket(self):
|
||||
"""Abstract method to create and return a ZeroMQ socket."""
|
||||
pass
|
||||
|
||||
self.mutex = threading.Lock()
|
||||
self.req_dict = dict()
|
||||
self.router = None
|
||||
self.poller = None
|
||||
self.running = True
|
||||
def _ensure_socket(self):
|
||||
"""Ensure the socket is created before use."""
|
||||
if self.socket is None:
|
||||
self.socket = self._create_socket()
|
||||
|
||||
@abstractmethod
|
||||
def connect(self):
|
||||
"""
|
||||
Connect to the server using the file name specified in the constructor.
|
||||
"""
|
||||
self.socket.connect(f"ipc://{self.file_name}")
|
||||
|
||||
def start_server(self):
|
||||
"""
|
||||
Start the server using the file name specified in the constructor.
|
||||
"""
|
||||
self.socket.setsockopt(zmq.SNDHWM, self.ZMQ_SNDHWM)
|
||||
self.socket.setsockopt(zmq.SNDTIMEO, -1)
|
||||
self.socket.bind(f"ipc://{self.file_name}")
|
||||
self.poller = zmq.Poller()
|
||||
self.poller.register(self.socket, zmq.POLLIN)
|
||||
|
||||
def create_router(self):
|
||||
"""
|
||||
Create a ROUTER socket and bind it to the specified router path.
|
||||
"""
|
||||
self.router = self.context.socket(zmq.ROUTER)
|
||||
self.router.setsockopt(zmq.SNDHWM, self.ZMQ_SNDHWM)
|
||||
self.router.setsockopt(zmq.ROUTER_MANDATORY, 1)
|
||||
self.router.setsockopt(zmq.SNDTIMEO, -1)
|
||||
self.router.bind(f"ipc://{self.router_path}")
|
||||
zmq_client_logger.info(f"router path: {self.router_path}")
|
||||
pass
|
||||
|
||||
def send_json(self, data):
|
||||
"""
|
||||
Send a JSON-serializable object over the socket.
|
||||
"""
|
||||
self._ensure_socket()
|
||||
self.socket.send_json(data)
|
||||
|
||||
def recv_json(self):
|
||||
"""
|
||||
Receive a JSON-serializable object from the socket.
|
||||
"""
|
||||
self._ensure_socket()
|
||||
return self.socket.recv_json()
|
||||
|
||||
def send_pyobj(self, data):
|
||||
"""
|
||||
Send a Pickle-serializable object over the socket.
|
||||
"""
|
||||
self._ensure_socket()
|
||||
self.socket.send_pyobj(data)
|
||||
|
||||
def recv_pyobj(self):
|
||||
"""
|
||||
Receive a Pickle-serializable object from the socket.
|
||||
"""
|
||||
self._ensure_socket()
|
||||
return self.socket.recv_pyobj()
|
||||
|
||||
def pack_aggregated_data(self, data):
|
||||
"""
|
||||
Aggregate multiple responses into one and send them to the client.
|
||||
"""
|
||||
result = data[0]
|
||||
if len(data) > 1:
|
||||
for response in data[1:]:
|
||||
result.add(response)
|
||||
result = msgpack.packb([result.to_dict()])
|
||||
return result
|
||||
|
||||
def send_multipart(self, req_id, data):
|
||||
"""
|
||||
Send a multipart message to the router socket.
|
||||
"""
|
||||
if self.router is None:
|
||||
raise RuntimeError("Router socket not created. Call create_router() first.")
|
||||
class ZmqIpcClient(ZmqClientBase):
|
||||
def __init__(self, name, mode):
|
||||
self.name = name
|
||||
self.mode = mode
|
||||
self.file_name = f"/dev/shm/{name}.socket"
|
||||
self.context = zmq.Context()
|
||||
self.socket = self.context.socket(self.mode)
|
||||
|
||||
while self.running:
|
||||
with self.mutex:
|
||||
if req_id not in self.req_dict:
|
||||
try:
|
||||
client, _, request_id = self.router.recv_multipart(flags=zmq.NOBLOCK)
|
||||
req_id_str = request_id.decode("utf-8")
|
||||
self.req_dict[req_id_str] = client
|
||||
except zmq.Again:
|
||||
time.sleep(0.001)
|
||||
continue
|
||||
else:
|
||||
break
|
||||
if self.req_dict[req_id] == -1:
|
||||
if data[-1].finished:
|
||||
with self.mutex:
|
||||
self.req_dict.pop(req_id, None)
|
||||
return
|
||||
try:
|
||||
start_send = time.time()
|
||||
if self.aggregate_send:
|
||||
result = self.pack_aggregated_data(data)
|
||||
else:
|
||||
result = msgpack.packb([response.to_dict() for response in data])
|
||||
self.router.send_multipart([self.req_dict[req_id], b"", result])
|
||||
zmq_client_logger.info(f"send_multipart result: {req_id} len {len(data)} elapse: {time.time()-start_send}")
|
||||
except zmq.ZMQError as e:
|
||||
zmq_client_logger.error(f"[{req_id}] zmq error: {e}")
|
||||
self.req_dict[req_id] = -1
|
||||
except Exception as e:
|
||||
zmq_client_logger.error(f"Send result to zmq client failed: {e}, {str(traceback.format_exc())}")
|
||||
def _create_socket(self):
|
||||
"""create and return a ZeroMQ socket."""
|
||||
self.context = zmq.Context()
|
||||
return self.context.socket(self.mode)
|
||||
|
||||
if data[-1].finished:
|
||||
with self.mutex:
|
||||
self.req_dict.pop(req_id, None)
|
||||
zmq_client_logger.info(f"send_multipart finished, req_id: {req_id}")
|
||||
|
||||
def receive_json_once(self, block=False):
|
||||
"""
|
||||
Receive a single message from the socket.
|
||||
"""
|
||||
if self.socket is None or self.socket.closed:
|
||||
return "zmp socket has closed", None
|
||||
try:
|
||||
flags = zmq.NOBLOCK if not block else 0
|
||||
return None, self.socket.recv_json(flags=flags)
|
||||
except zmq.Again:
|
||||
return None, None
|
||||
except Exception as e:
|
||||
self.close()
|
||||
zmq_client_logger.warning(f"{e}, {str(traceback.format_exc())}")
|
||||
return str(e), None
|
||||
|
||||
def receive_pyobj_once(self, block=False):
|
||||
"""
|
||||
Receive a single message from the socket.
|
||||
"""
|
||||
if self.socket is None or self.socket.closed:
|
||||
return "zmp socket has closed", None
|
||||
try:
|
||||
flags = zmq.NOBLOCK if not block else 0
|
||||
return None, self.socket.recv_pyobj(flags=flags)
|
||||
except zmq.Again:
|
||||
return None, None
|
||||
except Exception as e:
|
||||
self.close()
|
||||
zmq_client_logger.warning(f"{e}, {str(traceback.format_exc())}")
|
||||
return str(e), None
|
||||
|
||||
def _clear_ipc(self, name):
|
||||
"""
|
||||
Remove the IPC file with the given name.
|
||||
"""
|
||||
if os.path.exists(name):
|
||||
try:
|
||||
os.remove(name)
|
||||
except OSError as e:
|
||||
zmq_client_logger.warning(f"Failed to remove IPC file {name} - {e}")
|
||||
|
||||
def close(self):
|
||||
"""
|
||||
Close the socket and context, and remove the IPC files.
|
||||
"""
|
||||
if not self.running:
|
||||
return
|
||||
|
||||
self.running = False
|
||||
zmq_client_logger.info("Closing ZMQ connection...")
|
||||
try:
|
||||
if hasattr(self, "socket") and not self.socket.closed:
|
||||
self.socket.close()
|
||||
|
||||
if self.router is not None and not self.router.closed:
|
||||
self.router.close()
|
||||
|
||||
if not self.context.closed:
|
||||
self.context.term()
|
||||
|
||||
self._clear_ipc(self.file_name)
|
||||
self._clear_ipc(self.router_path)
|
||||
except Exception as e:
|
||||
zmq_client_logger.warning(f"Failed to close ZMQ connection - {e}, {str(traceback.format_exc())}")
|
||||
return
|
||||
|
||||
def __exit__(self, exc_type, exc_val, exc_tb):
|
||||
self.close()
|
||||
def connect(self):
|
||||
self._ensure_socket()
|
||||
self.socket.connect(f"ipc://{self.file_name}")
|
||||
|
304
fastdeploy/inter_communicator/zmq_server.py
Normal file
304
fastdeploy/inter_communicator/zmq_server.py
Normal file
@@ -0,0 +1,304 @@
|
||||
"""
|
||||
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License"
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""
|
||||
|
||||
import os
|
||||
import threading
|
||||
import time
|
||||
from abc import ABC, abstractmethod
|
||||
from collections import defaultdict
|
||||
|
||||
import msgpack
|
||||
import zmq
|
||||
|
||||
from fastdeploy import envs
|
||||
from fastdeploy.utils import llm_logger
|
||||
|
||||
|
||||
class ZmqServerBase(ABC):
|
||||
"""
|
||||
ZmqServerBase
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
self.cached_results = defaultdict(list)
|
||||
self.response_token_lock = threading.Lock()
|
||||
|
||||
@abstractmethod
|
||||
def _create_socket(self):
|
||||
"""Abstract method to create and return a ZeroMQ socket."""
|
||||
pass
|
||||
|
||||
def _ensure_socket(self):
|
||||
"""Ensure the socket is created before use."""
|
||||
if self.socket is None:
|
||||
self.socket = self._create_socket()
|
||||
|
||||
def pack_aggregated_data(self, data):
|
||||
"""
|
||||
Aggregate multiple responses into one and send them to the client.
|
||||
"""
|
||||
result = data[0]
|
||||
if len(data) > 1:
|
||||
for response in data[1:]:
|
||||
result.add(response)
|
||||
result = msgpack.packb([result.to_dict()])
|
||||
return result
|
||||
|
||||
def receive_json_once(self, block=False):
|
||||
"""
|
||||
Receive a single message from the socket.
|
||||
"""
|
||||
self._ensure_socket()
|
||||
if self.socket is None or self.socket.closed:
|
||||
return "zmp socket has closed", None
|
||||
try:
|
||||
flags = zmq.NOBLOCK if not block else 0
|
||||
return None, self.socket.recv_json(flags=flags)
|
||||
except zmq.Again:
|
||||
return None, None
|
||||
except Exception as e:
|
||||
self.close()
|
||||
llm_logger.warning(f"{e}")
|
||||
return str(e), None
|
||||
|
||||
def receive_pyobj_once(self, block=False):
|
||||
"""
|
||||
Receive a single message from the socket.
|
||||
"""
|
||||
self._ensure_socket()
|
||||
if self.socket is None or self.socket.closed:
|
||||
return "zmp socket has closed", None
|
||||
try:
|
||||
flags = zmq.NOBLOCK if not block else 0
|
||||
result = self.socket.recv_pyobj(flags=flags)
|
||||
llm_logger.info(f"receive one pyobj {result}")
|
||||
return None, result
|
||||
except zmq.Again:
|
||||
return None, None
|
||||
except Exception as e:
|
||||
self.close()
|
||||
llm_logger.warning(f"{e}")
|
||||
return str(e), None
|
||||
|
||||
def recv_result_handle(self):
|
||||
while True:
|
||||
try:
|
||||
with self.response_token_lock:
|
||||
client, _, request_id = self.socket.recv_multipart(flags=zmq.NOBLOCK)
|
||||
req_id_str = request_id.decode("utf-8")
|
||||
with self.mutex:
|
||||
self.req_dict[req_id_str] = client
|
||||
except zmq.Again:
|
||||
time.sleep(0.001)
|
||||
continue
|
||||
except Exception as e:
|
||||
llm_logger.error(f"recv_result_handle get unknown exception: {e}")
|
||||
continue
|
||||
|
||||
def send_response(self, req_id, data):
|
||||
"""
|
||||
Send generated token result to client.
|
||||
"""
|
||||
self._ensure_socket()
|
||||
if self.socket is None:
|
||||
raise RuntimeError("Router socket not created. Call create_router() first.")
|
||||
new_data = []
|
||||
has_result_handle = False
|
||||
with self.mutex:
|
||||
if req_id not in self.req_dict:
|
||||
self.cached_results[req_id].append(data)
|
||||
else:
|
||||
has_result_handle = True
|
||||
if req_id in self.cached_results:
|
||||
for history_data in self.cached_results[req_id]:
|
||||
new_data.extend(history_data)
|
||||
llm_logger.info(
|
||||
f"get request {req_id} result handle after cached result, total cached length {len(self.cached_results[req_id])}"
|
||||
)
|
||||
del self.cached_results[req_id]
|
||||
if has_result_handle:
|
||||
try:
|
||||
new_data.extend(data)
|
||||
start_send = time.time()
|
||||
if self.aggregate_send:
|
||||
result = self.pack_aggregated_data(new_data)
|
||||
else:
|
||||
result = msgpack.packb([response.to_dict() for response in new_data])
|
||||
with self.response_token_lock:
|
||||
self.socket.send_multipart([self.req_dict[req_id], b"", result])
|
||||
llm_logger.debug(
|
||||
f"send_multipart result: {req_id} len {len(new_data)} elapse: {time.time()-start_send}"
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
llm_logger.error(f"Send result to zmq client failed: {e}")
|
||||
|
||||
if data[-1].finished:
|
||||
with self.mutex:
|
||||
if req_id not in self.req_dict:
|
||||
llm_logger.warning(f"req_id {req_id} finished but no result handle, drop it")
|
||||
if req_id in self.cached_results:
|
||||
del self.cached_results[req_id]
|
||||
else:
|
||||
llm_logger.info(f"send_multipart finished, req_id: {req_id}")
|
||||
self.req_dict.pop(req_id, None)
|
||||
|
||||
@abstractmethod
|
||||
def close(self):
|
||||
pass
|
||||
|
||||
def __exit__(self, exc_type, exc_val, exc_tb):
|
||||
self.close()
|
||||
|
||||
|
||||
class ZmqIpcServer(ZmqServerBase):
|
||||
"""
|
||||
ZmqIpcServer, used when FD_ENABLE_INTERNAL_ADAPTER=0
|
||||
"""
|
||||
|
||||
def __init__(self, name, mode):
|
||||
self.name = name
|
||||
self.mode = mode
|
||||
self.cached_results = defaultdict(list)
|
||||
if mode == zmq.PULL:
|
||||
self.file_name = f"/dev/shm/{name}.socket"
|
||||
elif mode == zmq.ROUTER:
|
||||
self.file_name = f"/dev/shm/router_{name}.ipc"
|
||||
self.ZMQ_SNDHWM = int(envs.FD_ZMQ_SNDHWM)
|
||||
self.aggregate_send = envs.FD_USE_AGGREGATE_SEND
|
||||
self.mutex = threading.Lock()
|
||||
self.response_token_lock = threading.Lock()
|
||||
self.req_dict = dict()
|
||||
self.running = True
|
||||
self.context = zmq.Context()
|
||||
self._create_socket()
|
||||
|
||||
def _create_socket(self):
|
||||
"""create and return a ZeroMQ socket."""
|
||||
self.socket = self.context.socket(self.mode)
|
||||
self.socket.setsockopt(zmq.SNDHWM, self.ZMQ_SNDHWM)
|
||||
self.socket.setsockopt(zmq.SNDTIMEO, -1)
|
||||
self.socket.bind(f"ipc://{self.file_name}")
|
||||
return self.socket
|
||||
|
||||
def _clear_ipc(self, name):
|
||||
"""
|
||||
Remove the IPC file with the given name.
|
||||
"""
|
||||
if os.path.exists(name):
|
||||
try:
|
||||
os.remove(name)
|
||||
except OSError as e:
|
||||
llm_logger.warning(f"Failed to remove IPC file {name} - {e}")
|
||||
|
||||
def close(self):
|
||||
"""
|
||||
Close the socket and context, and remove the IPC files.
|
||||
"""
|
||||
if not self.running:
|
||||
return
|
||||
|
||||
self.running = False
|
||||
llm_logger.info("Closing ZMQ connection...")
|
||||
try:
|
||||
if self.socket is not None and not self.socket.closed:
|
||||
self.socket.close()
|
||||
if not self.context.closed:
|
||||
self.context.term()
|
||||
self._clear_ipc(self.file_name)
|
||||
except Exception as e:
|
||||
llm_logger.warning(f"Failed to close ZMQ connection - {e}")
|
||||
return
|
||||
|
||||
|
||||
class ZmqTcpServer(ZmqServerBase):
|
||||
"""
|
||||
ZmqTcpServer, used when FD_ENABLE_INTERNAL_ADAPTER=1
|
||||
"""
|
||||
|
||||
def __init__(self, port, mode):
|
||||
self.mode = mode
|
||||
self.port = port
|
||||
self.cached_results = defaultdict(list)
|
||||
self.ZMQ_SNDHWM = int(envs.FD_ZMQ_SNDHWM)
|
||||
self.aggregate_send = envs.FD_USE_AGGREGATE_SEND
|
||||
|
||||
self.mutex = threading.Lock()
|
||||
self.req_dict = dict()
|
||||
self.running = True
|
||||
self.context = zmq.Context()
|
||||
self._create_socket()
|
||||
self.response_token_lock = threading.Lock()
|
||||
|
||||
def _create_socket(self):
|
||||
"""create and return a ZeroMQ socket."""
|
||||
self.socket = self.context.socket(self.mode)
|
||||
self.socket.setsockopt(zmq.SNDHWM, self.ZMQ_SNDHWM)
|
||||
self.socket.setsockopt(zmq.SNDTIMEO, -1)
|
||||
self.socket.bind(f"tcp://*:{self.port}")
|
||||
return self.socket
|
||||
|
||||
def recv_control_cmd(self):
|
||||
"""
|
||||
Recieve control command from client
|
||||
"""
|
||||
self._ensure_socket()
|
||||
try:
|
||||
client, _, task_data = self.socket.recv_multipart(flags=zmq.NOBLOCK)
|
||||
task = msgpack.unpackb(task_data)
|
||||
task_id_str = task["task_id"]
|
||||
except zmq.Again:
|
||||
return None
|
||||
with self.mutex:
|
||||
self.req_dict[task_id_str] = client
|
||||
return task
|
||||
|
||||
def response_for_control_cmd(self, task_id, result):
|
||||
"""
|
||||
Send command result back to client.
|
||||
"""
|
||||
self._ensure_socket()
|
||||
if self.socket is None:
|
||||
raise RuntimeError("Router socket not created.")
|
||||
try:
|
||||
result = msgpack.packb(result)
|
||||
self.socket.send_multipart([self.req_dict[task_id], b"", result])
|
||||
|
||||
except Exception as e:
|
||||
llm_logger.error(f"Send result to zmq client failed: {e}")
|
||||
|
||||
with self.mutex:
|
||||
self.req_dict.pop(task_id, None)
|
||||
llm_logger.debug(f"response control cmd finished, task_id: {task_id}")
|
||||
|
||||
def close(self):
|
||||
"""
|
||||
Close the socket and context.
|
||||
"""
|
||||
if not self.running:
|
||||
return
|
||||
|
||||
self.running = False
|
||||
llm_logger.info("Closing ZMQ connection...")
|
||||
try:
|
||||
if self.socket is not None and not self.socket.closed:
|
||||
self.socket.close()
|
||||
if not self.context.closed:
|
||||
self.context.term()
|
||||
|
||||
except Exception as e:
|
||||
llm_logger.warning(f"Failed to close ZMQ connection - {e}")
|
||||
return
|
@@ -19,6 +19,7 @@ metrics
|
||||
"""
|
||||
import os
|
||||
import shutil
|
||||
import uuid
|
||||
from typing import Set
|
||||
|
||||
from prometheus_client import (
|
||||
@@ -35,24 +36,20 @@ from fastdeploy.metrics import build_1_2_5_buckets
|
||||
from fastdeploy.metrics.work_metrics import work_process_metrics
|
||||
|
||||
|
||||
def cleanup_prometheus_files(is_main):
|
||||
def cleanup_prometheus_files(is_main: bool, instance_id: str = None):
|
||||
"""
|
||||
Cleans and recreates the Prometheus multiprocess directory.
|
||||
|
||||
Depending on whether it's the main process or a worker, this function removes the corresponding
|
||||
Prometheus multiprocess directory (/tmp/prom_main or /tmp/prom_worker) and recreates it as an empty directory.
|
||||
|
||||
Args:
|
||||
is_main (bool): Indicates whether the current process is the main process.
|
||||
|
||||
Returns:
|
||||
str: The path to the newly created Prometheus multiprocess directory.
|
||||
"""
|
||||
PROM_DIR = "/tmp/prom_main" if is_main else "/tmp/prom_worker"
|
||||
if os.path.exists(PROM_DIR):
|
||||
shutil.rmtree(PROM_DIR)
|
||||
os.makedirs(PROM_DIR, exist_ok=True)
|
||||
return PROM_DIR
|
||||
base_dir = "/tmp/prom_main" if is_main else "/tmp/prom_worker"
|
||||
if instance_id is None:
|
||||
instance_id = str(uuid.uuid4())
|
||||
prom_dir = f"{base_dir}_{instance_id}"
|
||||
|
||||
if os.path.exists(prom_dir):
|
||||
shutil.rmtree(prom_dir, ignore_errors=True)
|
||||
os.makedirs(prom_dir, exist_ok=True)
|
||||
|
||||
return prom_dir
|
||||
|
||||
|
||||
class SimpleCollector(Collector):
|
||||
|
@@ -15,13 +15,15 @@
|
||||
"""
|
||||
|
||||
from contextlib import contextmanager
|
||||
from dataclasses import dataclass
|
||||
from typing import Callable, Dict, Optional
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Callable, Dict, List, Optional
|
||||
|
||||
import paddle.jit.dy2static.utils as jit_utils
|
||||
import paddle.nn.layer
|
||||
from paddle.base.core import CUDAGraph
|
||||
from paddle.device.cuda import graphs
|
||||
from paddle.jit.dy2static.utils import CUDAGraphState
|
||||
|
||||
from fastdeploy import envs
|
||||
from fastdeploy.config import FDConfig
|
||||
from fastdeploy.distributed.communication import capture_custom_allreduce
|
||||
from fastdeploy.utils import get_logger
|
||||
@@ -46,28 +48,30 @@ class ConcreteSizeEntry:
|
||||
num_finished_warmup: int = 0
|
||||
# Captured cuda graph object corresponding to the current real shape
|
||||
cuda_graph: Optional[graphs.CUDAGraph] = None
|
||||
# Output buffer of cudagraph
|
||||
output_buffer: Optional[paddle.Tensor] = None
|
||||
# Output buffers of cudagraph
|
||||
output_buffers: List[Optional[paddle.Tensor]] = field(default_factory=list)
|
||||
|
||||
|
||||
class Dy2StCudaGraphManager:
|
||||
def __init__(self):
|
||||
self.state = CUDAGraphState.DISABLE
|
||||
|
||||
self.state = jit_utils.CUDAGraphState.DISABLE
|
||||
self.captured_batch_size = set()
|
||||
self.batch_size = -1
|
||||
|
||||
def run_impl(self, original_run_impl, inputs, parameters, attrs):
|
||||
|
||||
run_state = self.state
|
||||
prog_attrs, cuda_graph_attrs = attrs
|
||||
if run_state == CUDAGraphState.REPLAY:
|
||||
if run_state == jit_utils.CUDAGraphState.REPLAY:
|
||||
if self.batch_size not in self.captured_batch_size:
|
||||
run_state = CUDAGraphState.DISABLE
|
||||
elif run_state == CUDAGraphState.CAPTURE:
|
||||
run_state = jit_utils.CUDAGraphState.DISABLE
|
||||
elif run_state == jit_utils.CUDAGraphState.CAPTURE:
|
||||
self.captured_batch_size.add(self.batch_size)
|
||||
|
||||
cuda_graph_attrs |= {
|
||||
"cuda_graph_state": run_state,
|
||||
"cuda_graph_dispatch_key": self.batch_size if run_state != CUDAGraphState.DISABLE else 0,
|
||||
"cuda_graph_dispatch_key": self.batch_size if run_state != jit_utils.CUDAGraphState.DISABLE else 0,
|
||||
}
|
||||
return original_run_impl(inputs, parameters, (prog_attrs, cuda_graph_attrs))
|
||||
|
||||
@@ -82,17 +86,14 @@ class Dy2StCudaGraphManager:
|
||||
class CudaGraphPiecewiseBackend:
|
||||
"""Manage the capture and replay of CUDA graphs at the subgraph level."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
fd_config: FDConfig,
|
||||
runnable: Callable,
|
||||
):
|
||||
def __init__(self, fd_config: FDConfig, runnable: Callable):
|
||||
self.fd_config = fd_config
|
||||
self.runnable = runnable
|
||||
self.cudagraph_capture_sizes = fd_config.graph_opt_config.cudagraph_capture_sizes
|
||||
self.warm_up_size = fd_config.graph_opt_config.cudagraph_num_of_warmups
|
||||
self.real_shape_to_captured_size = fd_config.graph_opt_config.real_shape_to_captured_size
|
||||
|
||||
if self.fd_config.graph_opt_config.use_unique_memory_pool:
|
||||
self.unique_memory_pool_id = CUDAGraph.gen_new_memory_pool_id()
|
||||
self._create_entry_dict()
|
||||
|
||||
self.cuda_graph_manager = None
|
||||
@@ -100,6 +101,7 @@ class CudaGraphPiecewiseBackend:
|
||||
self.cuda_graph_manager = Dy2StCudaGraphManager()
|
||||
|
||||
def run_static_model(self, entry: ConcreteSizeEntry, **kwargs):
|
||||
|
||||
if not entry.captured:
|
||||
# Warmup the model
|
||||
for n in range(entry.num_finished_warmup, self.warm_up_size):
|
||||
@@ -115,21 +117,21 @@ class CudaGraphPiecewiseBackend:
|
||||
entry.input_addresses = input_addresses
|
||||
|
||||
# Capture
|
||||
self.cuda_graph_manager.state = CUDAGraphState.CAPTURE
|
||||
self.cuda_graph_manager.state = jit_utils.CUDAGraphState.CAPTURE
|
||||
self.cuda_graph_manager.batch_size = entry.real_shape
|
||||
entry.captured = True
|
||||
with self.cuda_graph_manager.run_impl_guard():
|
||||
entry.runnable(**kwargs)
|
||||
|
||||
# Replay
|
||||
self.cuda_graph_manager.state = CUDAGraphState.REPLAY
|
||||
self.cuda_graph_manager.state = jit_utils.CUDAGraphState.REPLAY
|
||||
self.cuda_graph_manager.batch_size = entry.real_shape
|
||||
with self.cuda_graph_manager.run_impl_guard():
|
||||
return entry.runnable(**kwargs)
|
||||
|
||||
def __call__(self, **kwargs):
|
||||
# Get real shape(all num tokens)
|
||||
ids_remove_padding: paddle.Tensor = kwargs["ids_remove_padding"]
|
||||
ids_remove_padding: paddle.Tensor = kwargs["forward_meta"].ids_remove_padding
|
||||
real_shape = ids_remove_padding.shape[0]
|
||||
padding_real_shape = self.real_shape_to_captured_size[real_shape]
|
||||
logger.debug(
|
||||
@@ -164,20 +166,32 @@ class CudaGraphPiecewiseBackend:
|
||||
input_addresses = [x.data_ptr() for (_, x) in kwargs.items() if isinstance(x, paddle.Tensor)]
|
||||
entry.input_addresses = input_addresses
|
||||
|
||||
new_grpah = graphs.CUDAGraph()
|
||||
new_grpah = (
|
||||
graphs.CUDAGraph(pool_id=self.unique_memory_pool_id)
|
||||
if self.fd_config.graph_opt_config.use_unique_memory_pool
|
||||
else graphs.CUDAGraph()
|
||||
)
|
||||
paddle.device.synchronize()
|
||||
|
||||
# Capture
|
||||
with capture_custom_allreduce():
|
||||
new_grpah.capture_begin()
|
||||
output = entry.runnable(**kwargs)
|
||||
outputs = entry.runnable(**kwargs)
|
||||
if isinstance(outputs, paddle.Tensor):
|
||||
assert outputs is not None
|
||||
outputs = [outputs]
|
||||
new_grpah.capture_end()
|
||||
|
||||
# Store output buffer
|
||||
entry.cuda_graph = new_grpah
|
||||
entry.output_buffer = paddle.zeros_like(output)
|
||||
output._share_buffer_to(entry.output_buffer)
|
||||
output._clear
|
||||
for output in outputs:
|
||||
if output is not None:
|
||||
output_buffer = paddle.zeros_like(output)
|
||||
output._share_buffer_to(output_buffer)
|
||||
output._clear
|
||||
entry.output_buffers.append(output_buffer)
|
||||
else:
|
||||
entry.output_buffers.append(None)
|
||||
|
||||
paddle.device.synchronize()
|
||||
|
||||
@@ -188,7 +202,9 @@ class CudaGraphPiecewiseBackend:
|
||||
# Replay
|
||||
entry.cuda_graph.replay()
|
||||
logger.debug(f"[CUDA GRAPH] CUDAGraph replayed for real shape {padding_real_shape}")
|
||||
return entry.output_buffer
|
||||
if len(entry.output_buffers) == 1:
|
||||
return entry.output_buffers[0]
|
||||
return entry.output_buffers
|
||||
|
||||
def _create_entry_dict(self):
|
||||
""" """
|
||||
@@ -218,8 +234,9 @@ class CudaGraphPiecewiseBackend:
|
||||
|
||||
def _save_cudagrpah_dot_files(self, entry):
|
||||
"""Print CUDAGrpah to dot files"""
|
||||
log_dir = envs.FD_LOG_DIR
|
||||
if entry.cuda_graph:
|
||||
entry.cuda_graph.print_to_dot_files(
|
||||
f"./log/GraphDotFiles/backend{id(self)}_shape{entry.real_shape}",
|
||||
f"./{log_dir}/GraphDotFiles/backend{id(self)}_shape{entry.real_shape}",
|
||||
1 << 0,
|
||||
)
|
||||
|
@@ -100,6 +100,8 @@ class AppendAttentionBackend(AttentionBackend):
|
||||
self.rope_3d: bool = getattr(fd_config.model_config, "rope_3d", False) or getattr(
|
||||
fd_config.model_config, "use_3d_rope", False
|
||||
)
|
||||
if fd_config.speculative_config.model_type != "main":
|
||||
self.rope_3d = False
|
||||
self.causal: bool = getattr(fd_config.model_config, "causal", True)
|
||||
self.speculative_method: str = fd_config.speculative_config.method
|
||||
self.use_speculate: bool = self.speculative_method is not None
|
||||
@@ -231,6 +233,17 @@ class AppendAttentionBackend(AttentionBackend):
|
||||
metadata.kv_signal_metadata,
|
||||
layer.layer_id + self.start_layer_index,
|
||||
)
|
||||
cache_quant_type_str = getattr(layer, "cache_quant_type_str", "none")
|
||||
if cache_quant_type_str == "block_wise_fp8":
|
||||
cache_k = forward_meta.caches[4 * layer.layer_id]
|
||||
cache_v = forward_meta.caches[4 * layer.layer_id + 1]
|
||||
cache_k_scales = forward_meta.caches[4 * layer.layer_id + 2]
|
||||
cache_v_scales = forward_meta.caches[4 * layer.layer_id + 3]
|
||||
else:
|
||||
cache_k = forward_meta.caches[2 * layer.layer_id]
|
||||
cache_v = forward_meta.caches[2 * layer.layer_id + 1]
|
||||
cache_k_scales = getattr(layer, "cache_k_scale", None)
|
||||
cache_v_scales = getattr(layer, "cache_v_scale", None)
|
||||
|
||||
if self.use_output:
|
||||
quant_max_bound = getattr(layer, "quant_max_bound", 0.0)
|
||||
@@ -269,8 +282,8 @@ class AppendAttentionBackend(AttentionBackend):
|
||||
|
||||
append_attention_with_output(
|
||||
qkv,
|
||||
forward_meta.caches[2 * layer.layer_id],
|
||||
forward_meta.caches[2 * layer.layer_id + 1],
|
||||
cache_k,
|
||||
cache_v,
|
||||
forward_meta.seq_lens_encoder,
|
||||
forward_meta.seq_lens_decoder,
|
||||
forward_meta.seq_lens_this_time,
|
||||
@@ -293,8 +306,8 @@ class AppendAttentionBackend(AttentionBackend):
|
||||
metadata.attn_mask,
|
||||
layer.qkv_bias,
|
||||
layer.qkv_scale,
|
||||
getattr(layer, "cache_k_scale", None),
|
||||
getattr(layer, "cache_v_scale", None),
|
||||
cache_k_scales,
|
||||
cache_v_scales,
|
||||
getattr(layer, "cache_k_out_scale", None),
|
||||
getattr(layer, "cache_v_out_scale", None),
|
||||
getattr(layer, "cache_k_zp", None),
|
||||
@@ -325,8 +338,8 @@ class AppendAttentionBackend(AttentionBackend):
|
||||
else:
|
||||
res = append_attention(
|
||||
qkv,
|
||||
forward_meta.caches[2 * layer.layer_id],
|
||||
forward_meta.caches[2 * layer.layer_id + 1],
|
||||
cache_k,
|
||||
cache_v,
|
||||
forward_meta.seq_lens_encoder,
|
||||
forward_meta.seq_lens_decoder,
|
||||
forward_meta.seq_lens_this_time,
|
||||
@@ -348,15 +361,15 @@ class AppendAttentionBackend(AttentionBackend):
|
||||
metadata.attn_mask,
|
||||
layer.qkv_bias,
|
||||
layer.qkv_scale,
|
||||
getattr(layer, "cache_k_scale", None),
|
||||
getattr(layer, "cache_v_scale", None),
|
||||
cache_k_scales,
|
||||
cache_v_scales,
|
||||
getattr(layer, "cache_k_out_scale", None),
|
||||
getattr(layer, "cache_v_out_scale", None),
|
||||
getattr(layer, "cache_k_zp", None),
|
||||
getattr(layer, "cache_v_zp", None),
|
||||
layer.linear_shift,
|
||||
layer.linear_smooth,
|
||||
forward_meta.attn_mask_offsets,
|
||||
None if self.use_speculate else forward_meta.attn_mask_offsets,
|
||||
metadata.kv_signal_data_list[layer.layer_id],
|
||||
getattr(layer, "q_norm_weight", None),
|
||||
getattr(layer, "k_norm_weight", None),
|
||||
@@ -374,7 +387,7 @@ class AppendAttentionBackend(AttentionBackend):
|
||||
metadata.max_partition_size,
|
||||
metadata.encoder_max_partition_size,
|
||||
self.speculate_max_draft_token_num + 1,
|
||||
self.causal,
|
||||
self.causal or self.use_speculate,
|
||||
self.speculative_method is not None,
|
||||
)
|
||||
return res
|
||||
|
@@ -24,6 +24,9 @@ from paddle import nn
|
||||
from paddleformers.utils.log import logger
|
||||
|
||||
from fastdeploy.config import FDConfig
|
||||
from fastdeploy.model_executor.layers.quantization.kv_cache import (
|
||||
KvCacheQuantzationTypes,
|
||||
)
|
||||
from fastdeploy.model_executor.layers.quantization.quant_base import QuantMethodBase
|
||||
|
||||
if TYPE_CHECKING:
|
||||
@@ -102,6 +105,12 @@ class Attention(nn.Layer):
|
||||
|
||||
if fd_config.quant_config and hasattr(fd_config.quant_config, "kv_cache_quant_type"):
|
||||
self.kvcache_quant_method: QuantMethodBase = fd_config.quant_config.get_quant_method(self)
|
||||
|
||||
# set for RL model, as RL do not need load state dict
|
||||
if fd_config.quant_config.kv_cache_quant_type == KvCacheQuantzationTypes.BLOCK_WISE_FP8:
|
||||
self.cache_quant_type_str = "block_wise_fp8"
|
||||
self.quant_max_bound = 448.0
|
||||
self.quant_min_bound = -448.0
|
||||
else:
|
||||
self.kvcache_quant_method = None
|
||||
|
||||
|
@@ -359,6 +359,7 @@ class FlashAttentionBackend(AttentionBackend):
|
||||
getattr(layer, "cache_v_zp", None),
|
||||
layer.linear_shift,
|
||||
layer.linear_smooth,
|
||||
forward_meta.attn_mask_offsets,
|
||||
metadata.kv_signal_data_list[layer.layer_id],
|
||||
getattr(layer, "q_norm_weight", None),
|
||||
getattr(layer, "k_norm_weight", None),
|
||||
|
@@ -77,6 +77,11 @@ class VocabParallelEmbedding(nn.Layer):
|
||||
)
|
||||
if self.world_size > 1:
|
||||
set_weight_attrs(self.embeddings.weight, {"output_dim": False})
|
||||
set_weight_attrs(
|
||||
self.embeddings.weight,
|
||||
{"rl_need_attr": {"rl_tp_degree": fd_config.parallel_config.tensor_parallel_size}},
|
||||
)
|
||||
|
||||
else:
|
||||
# column cut embedding
|
||||
self.embeddings = nn.Embedding(
|
||||
|
@@ -356,11 +356,21 @@ class ColumnParallelLinear(LinearBase):
|
||||
)
|
||||
|
||||
if self.nranks > 0:
|
||||
_set_var_distributed(self.weight, split_axis=-1)
|
||||
if self.with_bias:
|
||||
# col parallel
|
||||
_set_var_distributed(self.bias, split_axis=1)
|
||||
_set_var_distributed(self.bias, split_axis=0)
|
||||
set_weight_attrs(self.bias, {"output_dim": True})
|
||||
|
||||
# set_rl_tp_degree
|
||||
set_weight_attrs(
|
||||
self.weight, {"rl_need_attr": {"rl_tp_degree": fd_config.parallel_config.tensor_parallel_size}}
|
||||
)
|
||||
if self.with_bias:
|
||||
set_weight_attrs(
|
||||
self.bias, {"rl_need_attr": {"rl_tp_degree": fd_config.parallel_config.tensor_parallel_size}}
|
||||
)
|
||||
|
||||
|
||||
class MergedColumnParallelLinear(ColumnParallelLinear):
|
||||
"""
|
||||
@@ -415,6 +425,7 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
|
||||
def weight_loader(self, param, loaded_weight, loaded_shard_id: Optional[str] = None):
|
||||
model_format = getattr(param, "model_format", "")
|
||||
if model_format == "torch":
|
||||
loaded_weight = get_tensor(loaded_weight)
|
||||
loaded_weight = loaded_weight.transpose([1, 0])
|
||||
output_dim = getattr(param, "output_dim", None)
|
||||
assert output_dim is not None
|
||||
@@ -446,7 +457,7 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
|
||||
shard_offset = self.local_rank * block_size
|
||||
shard_size = (self.local_rank + 1) * block_size
|
||||
loaded_weight = slice_fn(loaded_weight, output_dim, start=shard_offset, end=shard_size)
|
||||
|
||||
loaded_weight = get_tensor(loaded_weight)
|
||||
if not param._is_initialized():
|
||||
param.initialize()
|
||||
param_shard_size = output_size // 2
|
||||
@@ -548,6 +559,7 @@ class QKVParallelLinear(ColumnParallelLinear):
|
||||
head_dim = param.shape[dim] // (self.num_heads_per_rank + 2 * self.kv_num_heads_per_rank)
|
||||
model_format = getattr(param, "model_format", "")
|
||||
if model_format == "torch":
|
||||
loaded_weight = get_tensor(loaded_weight)
|
||||
loaded_weight = loaded_weight.transpose([1, 0])
|
||||
if loaded_shard_id is None:
|
||||
# Loaded weight is already fused on disk
|
||||
@@ -568,12 +580,13 @@ class QKVParallelLinear(ColumnParallelLinear):
|
||||
# Tensor parallelism splits the weight along the output_dim
|
||||
if self.nranks != 1:
|
||||
block_size = self._get_shard_size_mapping(loaded_shard_id)
|
||||
dim = -1 if output_dim else 0
|
||||
shard_id = self.local_rank if loaded_shard_id == "q" else self.local_rank // self.num_kv_head_replicas
|
||||
shard_offset = shard_id * block_size
|
||||
shard_size = (shard_id + 1) * block_size
|
||||
loaded_weight = slice_fn(loaded_weight, output_dim, start=shard_offset, end=shard_size)
|
||||
|
||||
loaded_weight = get_tensor(loaded_weight)
|
||||
|
||||
if not param._is_initialized():
|
||||
param.initialize()
|
||||
|
||||
@@ -740,6 +753,7 @@ class RowParallelLinear(LinearBase):
|
||||
model_format=fd_config.model_config.model_format,
|
||||
)
|
||||
if self.nranks > 0:
|
||||
_set_var_distributed(self.weight, split_axis=0)
|
||||
if self.with_bias:
|
||||
# col parallel
|
||||
_set_var_distributed(self.bias, split_axis=0)
|
||||
@@ -752,6 +766,11 @@ class RowParallelLinear(LinearBase):
|
||||
|
||||
self.reduce_results = reduce_results
|
||||
|
||||
# set_rl_tp_degree
|
||||
set_weight_attrs(
|
||||
self.weight, {"rl_need_attr": {"rl_tp_degree": fd_config.parallel_config.tensor_parallel_size}}
|
||||
)
|
||||
|
||||
def forward_cuda(self, x: paddle.Tensor) -> paddle.Tensor:
|
||||
if self.fd_config.quant_config:
|
||||
out = self.quant_method.apply(self, x)
@@ -94,6 +94,12 @@ class ParallelLMHead(nn.Layer):
|
||||
"model_format": self.fd_config.model_config.model_format,
|
||||
},
|
||||
)
|
||||
if self.bias_key is not None:
|
||||
set_weight_attrs(
|
||||
self.linear.bias,
|
||||
{"rl_need_attr": {"rl_tp_degree": fd_config.parallel_config.tensor_parallel_size}},
|
||||
)
|
||||
|
||||
if self.nranks > 1:
|
||||
set_weight_attrs(self.linear.weight, {"output_dim": True})
|
||||
else:
|
||||
@@ -116,6 +122,9 @@ class ParallelLMHead(nn.Layer):
|
||||
|
||||
if self.nranks > 1:
|
||||
set_weight_attrs(self.linear.weight, {"output_dim": False})
|
||||
set_weight_attrs(
|
||||
self.linear.weight, {"rl_need_attr": {"rl_tp_degree": fd_config.parallel_config.tensor_parallel_size}}
|
||||
)
|
||||
|
||||
def load_state_dict(self, state_dict: Dict[str, paddle.Tensor | np.ndarray]):
|
||||
"""
@@ -18,7 +18,6 @@ from abc import abstractmethod
|
||||
|
||||
import paddle
|
||||
from paddle import nn
|
||||
from paddle.base.core import Config
|
||||
from paddleformers.utils.log import logger
|
||||
|
||||
try:
|
||||
@@ -26,138 +25,152 @@ try:
|
||||
except:
|
||||
logger.warning("import deep_ep Failed!")
|
||||
|
||||
from typing import Optional
|
||||
|
||||
import fastdeploy
|
||||
from fastdeploy.config import MoEPhase
|
||||
from fastdeploy.utils import singleton
|
||||
|
||||
try:
|
||||
from fastdeploy.model_executor.ops.gpu import noaux_tc
|
||||
except:
|
||||
logger.warning("import noaux_tc Failed!")
|
||||
|
||||
class DeepEPBufferManager:
    _engine: Optional["DeepEPEngine"] = None

    @classmethod
    def set_engine(cls, engine: "DeepEPEngine"):
        cls._engine = engine

    @classmethod
    def clear_buffer(cls):
        if cls._engine:
            cls._engine.clear_deep_ep_buffer()

    @classmethod
    def recreate_buffer(cls):
        if cls._engine:
            cls._engine.create_deep_ep_buffer()

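To make the intended call pattern concrete, here is a minimal sketch of how an external step (for example an RL weight reload) might drop and recreate the communication buffer through a manager like the one above. `MiniBufferManager` and `FakeEngine` are stand-ins invented for this illustration, not part of the diff.

```python
class MiniBufferManager:
    """Mirror of DeepEPBufferManager above, reduced to what this sketch needs."""
    _engine = None

    @classmethod
    def set_engine(cls, engine):
        cls._engine = engine

    @classmethod
    def clear_buffer(cls):
        if cls._engine:
            cls._engine.clear_deep_ep_buffer()

    @classmethod
    def recreate_buffer(cls):
        if cls._engine:
            cls._engine.create_deep_ep_buffer()


class FakeEngine:
    """Stand-in engine exposing only the two hooks the manager calls."""
    def __init__(self):
        self.buffer_alive = True

    def clear_deep_ep_buffer(self):
        self.buffer_alive = False

    def create_deep_ep_buffer(self):
        self.buffer_alive = True


engine = FakeEngine()
MiniBufferManager.set_engine(engine)
MiniBufferManager.clear_buffer()      # free NVL/RDMA memory before a heavy step
# ... e.g. reload model weights here ...
MiniBufferManager.recreate_buffer()   # bring the communication buffer back
assert engine.buffer_alive
```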
def get_moe_scores(
|
||||
gating_output: paddle.Tensor,
|
||||
n_group,
|
||||
topk_group,
|
||||
top_k,
|
||||
routed_scaling_factor,
|
||||
e_score_correction_bias,
|
||||
) -> paddle.Tensor:
|
||||
class DeepEPBuffer:
|
||||
"""
|
||||
compute moe scores using e_score_correction_bias.
|
||||
"""
|
||||
scores = paddle.nn.functional.sigmoid(gating_output)
|
||||
assert e_score_correction_bias is not None, "e_score_correction_bias is none!"
|
||||
scores_with_bias = scores + e_score_correction_bias
|
||||
scores, topk_values, topk_idx = noaux_tc(
|
||||
scores,
|
||||
scores_with_bias,
|
||||
n_group if n_group > 0 else 1,
|
||||
topk_group if topk_group > 0 else 1,
|
||||
top_k,
|
||||
routed_scaling_factor,
|
||||
)
|
||||
return scores, topk_values, topk_idx
|
||||
|
||||
|
||||
@singleton
|
||||
class DeepEPEngine:
|
||||
"""
|
||||
A wrapper class for DeepEP engine.
|
||||
Encapsulates DeepEP buffer creation, management and cleanup.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
num_max_dispatch_tokens_per_rank: int,
|
||||
hidden: int,
|
||||
group,
|
||||
hidden_size: int,
|
||||
num_experts: int,
|
||||
ep_size: int,
|
||||
ep_rank: int,
|
||||
num_max_dispatch_tokens_per_rank: int,
|
||||
splitwise_role: str,
|
||||
moe_phase: MoEPhase,
|
||||
async_finish: bool = False,
|
||||
use_internode_ll_two_stage: bool = False,
|
||||
top_k: int = 8,
|
||||
):
|
||||
"""
|
||||
Initialize the DeepEP engine.
|
||||
Args:
|
||||
group: The MPI group object.
|
||||
ep_size: The number of ranks.
|
||||
rank_id: The rank id.
|
||||
num_max_dispatch_tokens_per_rank: The maximum number of tokens per rank to dispatch.
|
||||
hidden: The hidden dimension of the model.
|
||||
num_experts: The number of experts.
|
||||
"""
|
||||
# TODO(@wufeisheng): Support configurable EP size
|
||||
self.group = paddle.distributed.new_group(range(ep_size))
|
||||
self.ep_size = ep_size
|
||||
self.rank_id = ep_rank
|
||||
self.hidden = hidden
|
||||
self.group = group
|
||||
self.hidden_size = hidden_size
|
||||
self.num_experts = num_experts
|
||||
self.num_local_experts = num_experts // ep_size
|
||||
self.async_finish = async_finish
|
||||
|
||||
self.deepep_engine = None
|
||||
|
||||
self.ep_config = Config(24, 6, 256)
|
||||
self.ep_size = ep_size
|
||||
self.num_max_dispatch_tokens_per_rank = num_max_dispatch_tokens_per_rank
|
||||
self.splitwise_role = splitwise_role
|
||||
self.moe_phase = moe_phase
|
||||
self.use_internode_ll_two_stage = use_internode_ll_two_stage
|
||||
self.top_k = top_k
|
||||
|
||||
# In mixed EP mode on a single node, we dynamically switch between
|
||||
# high throughput and low latency modes.
|
||||
self.deepep_buffer = None
|
||||
self.num_nvl_bytes = 0
|
||||
self.num_rdma_bytes = 0
|
||||
|
||||
if splitwise_role == "mixed":
|
||||
self.deepep_engine = deep_ep.Buffer(
|
||||
# Precompute buffer sizes
|
||||
self._compute_buffer_sizes()
|
||||
|
||||
def _compute_buffer_sizes(self, param_bytes: int = 2):
|
||||
hidden_bytes = self.hidden_size * param_bytes # bf16 or fp16
|
||||
|
||||
for config in (
|
||||
deep_ep.Buffer.get_dispatch_config(self.group.world_size),
|
||||
deep_ep.Buffer.get_combine_config(self.group.world_size),
|
||||
):
|
||||
self.num_nvl_bytes = max(
|
||||
config.get_nvl_buffer_size_hint(hidden_bytes, self.group.world_size), self.num_nvl_bytes
|
||||
)
|
||||
self.num_rdma_bytes = max(
|
||||
config.get_rdma_buffer_size_hint(hidden_bytes, self.group.world_size), self.num_rdma_bytes
|
||||
)
|
||||
|
||||
if self.splitwise_role == "mixed" or self.moe_phase.phase == "decode":
|
||||
if not self.use_internode_ll_two_stage:
|
||||
num_rdma_bytes = deep_ep.Buffer.get_low_latency_rdma_size_hint(
|
||||
self.num_max_dispatch_tokens_per_rank,
|
||||
self.hidden_size,
|
||||
self.ep_size,
|
||||
self.num_experts,
|
||||
)
|
||||
else:
|
||||
num_rdma_bytes = deep_ep.Buffer.get_low_latency_rdma_size_hint_two_stage(
|
||||
self.num_max_dispatch_tokens_per_rank, self.hidden_size, self.ep_size, self.num_experts, self.top_k
|
||||
)
|
||||
num_nvl_bytes = deep_ep.Buffer.get_low_latency_nvl_size_hint_two_stage(
|
||||
self.num_max_dispatch_tokens_per_rank,
|
||||
self.hidden_size,
|
||||
self.ep_size,
|
||||
self.num_experts,
|
||||
self.top_k,
|
||||
True,  # only dispatch_use_fp8 = True is supported for now
|
||||
)
|
||||
self.num_nvl_bytes = max(self.num_nvl_bytes, num_nvl_bytes)
|
||||
self.num_rdma_bytes = max(self.num_rdma_bytes, num_rdma_bytes)
|
||||
|
||||
logger.info(f"DeepEP num nvl bytes : {self.num_nvl_bytes}, num rdma bytes : {self.num_rdma_bytes}")
|
||||
|
||||
def create_buffer(self):
|
||||
"""Create or recreate buffer based on role and phase."""
|
||||
if self.deepep_buffer is not None:
|
||||
self.clear_buffer()
|
||||
|
||||
if self.splitwise_role == "mixed":
|
||||
logger.info("Initializing mixed mode buffer (low latency).")
|
||||
self.deepep_buffer = deep_ep.Buffer(
|
||||
self.group,
|
||||
int(2e9),
|
||||
int(6e9),
|
||||
self.num_nvl_bytes,
|
||||
self.num_rdma_bytes,
|
||||
low_latency_mode=True,
|
||||
num_qps_per_rank=24,
|
||||
)
|
||||
# In disaggregated mode on multiple nodes, we either use
|
||||
# high throughput mode or low latency mode.
|
||||
self.deepep_buffer.set_num_sms(14) # TODO: tune in future
|
||||
else:
|
||||
if moe_phase.phase == "decode":
|
||||
logger.info("Initializing Low Latency Buffer")
|
||||
self.get_low_latency_buffer()
|
||||
elif moe_phase.phase == "prefill":
|
||||
self.deepep_engine = deep_ep.Buffer(
|
||||
if self.moe_phase.phase == "decode":
|
||||
self._create_low_latency_buffer()
|
||||
elif self.moe_phase.phase == "prefill":
|
||||
logger.info("Initializing High Throughput Buffer for prefill phase.")
|
||||
self.deepep_buffer = deep_ep.Buffer(
|
||||
self.group,
|
||||
int(5e8),
|
||||
self.num_nvl_bytes,
|
||||
0,
|
||||
low_latency_mode=False,
|
||||
num_qps_per_rank=1,
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"Unknown generation phase {moe_phase}")
|
||||
raise ValueError(f"Unknown generation phase: {self.moe_phase.phase}")
|
||||
|
||||
def get_low_latency_buffer(self):
|
||||
"""
|
||||
Get the DeepEP buffer.
|
||||
Args:
|
||||
group: The MPI group object.
|
||||
num_max_dispatch_tokens_per_rank: The maximum number of tokens per rank to dispatch.
|
||||
hidden: The hidden dimension of the model.
|
||||
"""
|
||||
# NOTES: the low-latency mode will consume much more space than the normal mode
|
||||
# So we recommend that `num_max_dispatch_tokens_per_rank`
|
||||
# (the actual batch size in the decoding engine) should be less than 256
|
||||
logger.info("DeepEP buffer created successfully.")
|
||||
|
||||
def _create_low_latency_buffer(self):
|
||||
num_rdma_bytes = deep_ep.Buffer.get_low_latency_rdma_size_hint(
|
||||
self.num_max_dispatch_tokens_per_rank,
|
||||
self.hidden,
|
||||
self.hidden_size,
|
||||
self.ep_size,
|
||||
self.num_experts,
|
||||
)
|
||||
# Allocate a buffer if none exists or the existing buffer is too small
|
||||
|
||||
if (
|
||||
self.deepep_engine is None
|
||||
or self.deepep_engine.group != self.group
|
||||
or not self.deepep_engine.low_latency_mode
|
||||
or self.deepep_engine.num_rdma_bytes < num_rdma_bytes
|
||||
self.deepep_buffer is None
|
||||
or self.deepep_buffer.group != self.group
|
||||
or not self.deepep_buffer.low_latency_mode
|
||||
or self.deepep_buffer.num_rdma_bytes < num_rdma_bytes
|
||||
):
|
||||
# NOTES: for best performance, the QP number **must** be equal to the number of local experts
|
||||
assert self.num_experts % self.ep_size == 0
|
||||
self.deepep_engine = deep_ep.Buffer(
|
||||
self.deepep_buffer = deep_ep.Buffer(
|
||||
self.group,
|
||||
0,
|
||||
num_rdma_bytes,
|
||||
@@ -165,6 +178,105 @@ class DeepEPEngine:
|
||||
num_qps_per_rank=self.num_experts // self.ep_size,
|
||||
)
|
||||
|
||||
    def clear_buffer(self):
        """Clear buffer and free memory."""
        if self.deepep_buffer is not None:
            del self.deepep_buffer
            self.deepep_buffer = None
            logger.info("DeepEP buffer cleared.")

    def get_buffer(self):
        return self.deepep_buffer

    def clean_low_latency_buffer(self):
        if self.deepep_buffer is not None:
            if not self.use_internode_ll_two_stage:
                self.deepep_buffer.clean_low_latency_buffer(
                    self.num_max_dispatch_tokens_per_rank,
                    self.hidden_size,
                    self.num_experts,
                )
            else:
                self.deepep_buffer.clean_low_latency_two_stage_buffer(
                    self.num_max_dispatch_tokens_per_rank,
                    self.hidden_size,
                    self.num_experts,
                    self.top_k,
                    self.ep_size,
                    True,  # only dispatch_use_fp8 = True is supported for now
                )

    def barrier_all(self):
        if self.deepep_buffer is not None:
            self.deepep_buffer.barrier_all()

@singleton
|
||||
class DeepEPEngine:
|
||||
"""
|
||||
A wrapper class for DeepEP engine.
|
||||
Manages buffer lifecycle based on role and phase.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
num_max_dispatch_tokens_per_rank: int,
|
||||
hidden_size: int,
|
||||
num_experts: int,
|
||||
ep_size: int,
|
||||
ep_rank: int,
|
||||
splitwise_role: str,
|
||||
moe_phase: MoEPhase,
|
||||
async_finish: bool = False,
|
||||
group=None,
|
||||
use_internode_ll_two_stage: bool = False,
|
||||
top_k: int = 8,
|
||||
):
|
||||
if group is None:
|
||||
group = paddle.distributed.new_group(range(ep_size))
|
||||
self.group = group
|
||||
self.ep_size = ep_size
|
||||
self.rank_id = ep_rank
|
||||
self.hidden_size = hidden_size
|
||||
self.num_experts = num_experts
|
||||
self.num_local_experts = num_experts // ep_size
|
||||
self.top_k = top_k
|
||||
self.async_finish = async_finish
|
||||
|
||||
self.ep_config = None
|
||||
|
||||
# Store phase and role for buffer management
|
||||
self._splitwise_role = splitwise_role
|
||||
self._moe_phase = moe_phase
|
||||
|
||||
# Initialize buffer manager
|
||||
self.buffer = DeepEPBuffer(
|
||||
group=self.group,
|
||||
hidden_size=hidden_size,
|
||||
num_experts=num_experts,
|
||||
ep_size=ep_size,
|
||||
num_max_dispatch_tokens_per_rank=num_max_dispatch_tokens_per_rank,
|
||||
splitwise_role=splitwise_role,
|
||||
moe_phase=moe_phase,
|
||||
use_internode_ll_two_stage=use_internode_ll_two_stage,
|
||||
top_k=self.top_k,
|
||||
)
|
||||
self.buffer.create_buffer()
|
||||
|
||||
# Register for global buffer management
|
||||
DeepEPBufferManager.set_engine(self)
|
||||
|
||||
@property
|
||||
def deepep_engine(self):
|
||||
"""Backward compatibility alias."""
|
||||
return self.buffer.get_buffer()
|
||||
|
||||
def clear_deep_ep_buffer(self):
|
||||
self.buffer.clear_buffer()
|
||||
|
||||
def create_deep_ep_buffer(self):
|
||||
self.buffer.create_buffer()
|
||||
|
||||
def low_latency_dispatch(
|
||||
self,
|
||||
hidden_states: paddle.Tensor,
|
||||
@@ -172,22 +284,9 @@ class DeepEPEngine:
|
||||
expertwise_scale,
|
||||
use_fp8: bool = False,
|
||||
):
|
||||
"""
|
||||
Args:
|
||||
hidden_states: [token_num, hidden] 'bfloat16/int8'
|
||||
topk_idx: [token_num, num_topk] 'int64'
|
||||
if self.deepep_engine is None:
|
||||
raise RuntimeError("DeepEP buffer not initialized!")
|
||||
|
||||
Returns:
|
||||
recv_hidden_states: [num_local_experts,
|
||||
num_max_dispatch_tokens_per_rank * ep_size, hidden]
|
||||
ep_size * num_local_experts = num_experts
|
||||
recv_count: [num_local_experts]
|
||||
recv_count: a tensor shaped `[num_local_experts]` with type `torch.int`, indicating how many tokens each
expert receives. As mentioned before, not all tokens in `recv_x` are valid.
|
||||
handle: the communication handle to be used in the `low_latency_combine` function.
|
||||
event: the event after executing the kernel (valid only if `async_finish` is set).
|
||||
hook: the receiving hook function (valid only if `return_recv_hook` is set).
|
||||
"""
|
||||
(
|
||||
packed_recv_x,
|
||||
recv_expert_count,
|
||||
@@ -198,7 +297,7 @@ class DeepEPEngine:
|
||||
hidden_states,
|
||||
topk_idx,
|
||||
expertwise_scale,
|
||||
self.num_max_dispatch_tokens_per_rank,
|
||||
self.buffer.num_max_dispatch_tokens_per_rank,
|
||||
self.num_experts,
|
||||
use_fp8=use_fp8,
|
||||
async_finish=False,
|
||||
@@ -207,6 +306,37 @@ class DeepEPEngine:
|
||||
|
||||
return packed_recv_x, recv_expert_count, handle, dispatch_hook
|
||||
|
||||
def low_latency_dispatch_two_stage(
|
||||
self,
|
||||
hidden_states: paddle.Tensor,
|
||||
topk_idx: paddle.Tensor,
|
||||
topk_weights: paddle.Tensor,
|
||||
expertwise_scale,
|
||||
use_fp8: bool = False,
|
||||
):
|
||||
if self.deepep_engine is None:
|
||||
raise RuntimeError("DeepEP buffer not initialized!")
|
||||
|
||||
(
|
||||
packed_recv_x,
|
||||
packed_recv_count,
|
||||
_,
|
||||
handle,
|
||||
_,
|
||||
dispatch_hook,
|
||||
) = self.deepep_engine.low_latency_dispatch_two_stage(
|
||||
hidden_states,
|
||||
topk_idx,
|
||||
topk_weights,
|
||||
self.buffer.num_max_dispatch_tokens_per_rank,
|
||||
self.num_experts,
|
||||
use_fp8=use_fp8,
|
||||
async_finish=False,
|
||||
return_recv_hook=True,
|
||||
)
|
||||
|
||||
return packed_recv_x, packed_recv_count, handle, dispatch_hook
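These low-latency calls return a hook instead of blocking, so the caller drives completion explicitly. The sketch below shows the overall call order with stub callables; none of the stubs are real DeepEP APIs, they only mimic the return shapes used here.

```python
def run_moe_layer(dispatch, expert_ffn, combine, x, topk_idx, topk_weights):
    """Dispatch tokens, wait via the hook, run experts, then combine."""
    recv_x, recv_count, handle, dispatch_hook = dispatch(x, topk_idx, topk_weights)
    if dispatch_hook is not None:
        dispatch_hook()          # wait for the receive side to finish

    ffn_out = expert_ffn(recv_x, recv_count)

    combined, combine_hook = combine(ffn_out, topk_idx, topk_weights, handle)
    if combine_hook is not None:
        combine_hook()           # wait for the combined activations
    return combined


# Trivial stubs so the sketch runs end to end.
dispatch = lambda x, i, w: (x, len(x), "handle", None)
expert_ffn = lambda recv_x, recv_count: [v * 2 for v in recv_x]
combine = lambda y, i, w, h: (y, None)
print(run_moe_layer(dispatch, expert_ffn, combine, [1, 2, 3], None, None))
```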
|
||||
|
||||
def low_latency_combine(
|
||||
self,
|
||||
hidden_states: paddle.Tensor,
|
||||
@@ -214,27 +344,14 @@ class DeepEPEngine:
|
||||
topk_weights: paddle.Tensor,
|
||||
handle,
|
||||
):
|
||||
"""
|
||||
|
||||
Return:
|
||||
combined_hidden_states: [num_tokens, hidden]
|
||||
"""
|
||||
if paddle.__version__ != "0.0.0" and paddle.__version__ <= "3.1.0": # not develop version of PaddlePaddle
|
||||
if paddle.__version__ != "0.0.0" and paddle.__version__ <= "3.1.0":
|
||||
# TODO(@wanglongzhi): Delete them when deepep in PaddlePaddle is fixed
|
||||
# and when the default recommended version of PaddlePaddle is greater than 3.1.0
|
||||
(
|
||||
src_info,
|
||||
layout_range,
|
||||
num_max_dispatch_tokens_per_rank,
|
||||
num_experts,
|
||||
) = handle
|
||||
handle = (
|
||||
src_info,
|
||||
layout_range,
|
||||
num_max_dispatch_tokens_per_rank,
|
||||
None,
|
||||
num_experts,
|
||||
)
|
||||
src_info, layout_range, num_max_dispatch_tokens_per_rank, num_experts = handle
|
||||
handle = (src_info, layout_range, num_max_dispatch_tokens_per_rank, None, num_experts)
|
||||
|
||||
if self.deepep_engine is None:
|
||||
raise RuntimeError("DeepEP buffer not initialized!")
|
||||
|
||||
combined_hidden_states, _, combine_hook = self.deepep_engine.low_latency_combine(
|
||||
hidden_states,
|
||||
@@ -246,19 +363,33 @@ class DeepEPEngine:
|
||||
)
|
||||
return combined_hidden_states, combine_hook
|
||||
|
||||
def clean_low_latency_buffer(self):
|
||||
"""
|
||||
clean_low_latency_buffer
|
||||
"""
|
||||
self.deepep_engine.clean_low_latency_buffer(
|
||||
self.num_max_dispatch_tokens_per_rank, self.hidden, self.num_experts
|
||||
def low_latency_combine_two_stage(
|
||||
self,
|
||||
hidden_states: paddle.Tensor,
|
||||
topk_idx: paddle.Tensor,
|
||||
topk_weights: paddle.Tensor,
|
||||
dispatch_use_fp8: bool,
|
||||
handle,
|
||||
):
|
||||
if self.deepep_engine is None:
|
||||
raise RuntimeError("DeepEP buffer not initialized!")
|
||||
|
||||
combined_hidden_states, _, combine_hook = self.deepep_engine.low_latency_combine_two_stage(
|
||||
hidden_states,
|
||||
topk_idx,
|
||||
topk_weights,
|
||||
handle,
|
||||
async_finish=False,
|
||||
dispatch_use_fp8=dispatch_use_fp8,
|
||||
return_recv_hook=True,
|
||||
)
|
||||
return combined_hidden_states, combine_hook
|
||||
|
||||
def clean_low_latency_buffer(self):
|
||||
self.buffer.clean_low_latency_buffer()
|
||||
|
||||
def barrier_all(self):
|
||||
"""
|
||||
barrier_all
|
||||
"""
|
||||
self.deepep_engine.barrier_all()
|
||||
self.buffer.barrier_all()
|
||||
|
||||
|
||||
class EPRunner:
|
||||
@@ -269,7 +400,7 @@ class EPRunner:
|
||||
def __init__(
|
||||
self,
|
||||
top_k: int,
|
||||
hidden: int,
|
||||
hidden_size: int,
|
||||
num_experts: int,
|
||||
splitwise_role: str,
|
||||
moe_phase: MoEPhase,
|
||||
@@ -277,24 +408,27 @@ class EPRunner:
|
||||
ep_size: int = 1,
|
||||
ep_rank: int = 0,
|
||||
redundant_experts_num: int = 0,
|
||||
ep_group=None,
|
||||
use_internode_ll_two_stage: bool = False,
|
||||
):
|
||||
self.top_k = top_k
|
||||
self.num_experts = num_experts
|
||||
self.redundant_experts_num = redundant_experts_num
|
||||
self.use_internode_ll_two_stage = use_internode_ll_two_stage
|
||||
self.ep_engine = DeepEPEngine(
|
||||
num_max_dispatch_tokens_per_rank=num_max_dispatch_tokens_per_rank,
|
||||
hidden=hidden,
|
||||
hidden_size=hidden_size,
|
||||
num_experts=num_experts + redundant_experts_num,
|
||||
ep_size=ep_size,
|
||||
ep_rank=ep_rank,
|
||||
splitwise_role=splitwise_role,
|
||||
moe_phase=moe_phase,
|
||||
group=ep_group,
|
||||
use_internode_ll_two_stage=self.use_internode_ll_two_stage,
|
||||
top_k=self.top_k,
|
||||
)
|
||||
|
||||
def moe_select(self, layer: nn.Layer, gate_out: paddle.Tensor):
|
||||
"""
|
||||
moe_select
|
||||
"""
|
||||
if layer.redundant_table_manger is not None:
|
||||
(
|
||||
ep_rank_to_expert_id_list,
|
||||
@@ -310,12 +444,14 @@ class EPRunner:
|
||||
tokens_per_expert_stats_list=tokens_per_expert_stats_list,
|
||||
bias=layer.gate_correction_bias,
|
||||
moe_topk=self.top_k,
|
||||
apply_norm_weight=True, # apply_norm_weight
|
||||
apply_norm_weight=True,
|
||||
enable_softmax_top_k_fused=False,
|
||||
redundant_ep_rank_num_plus_one=layer.fd_config.model_config.redundant_experts_num + 1,
|
||||
)
|
||||
else:
|
||||
if layer.topk_method == "noaux_tc":
|
||||
from .moe import get_moe_scores
|
||||
|
||||
score, topk_weights, topk_idx = get_moe_scores(
|
||||
gate_out,
|
||||
layer.n_group,
|
||||
@@ -329,28 +465,28 @@ class EPRunner:
|
||||
gate_out,
|
||||
layer.gate_correction_bias,
|
||||
self.top_k,
|
||||
True, # apply_norm_weight,
|
||||
True,
|
||||
False,
|
||||
)
|
||||
return topk_idx, topk_weights
|
||||
|
||||
@abstractmethod
|
||||
def dispatch(self, *args, **kwargs):
|
||||
"""
|
||||
dispatch
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def combine(self, *args, **kwargs):
|
||||
"""
|
||||
combine
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
def clean_low_latency_buffer(self):
|
||||
self.ep_engine.clean_low_latency_buffer()
|
||||
|
||||
def clear_deep_ep_buffer(self):
|
||||
self.ep_engine.clear_deep_ep_buffer()
|
||||
|
||||
def create_deep_ep_buffer(self):
|
||||
self.ep_engine.create_deep_ep_buffer()
|
||||
|
||||
|
||||
class EPPrefillRunner(EPRunner):
|
||||
"""
|
||||
@@ -360,7 +496,7 @@ class EPPrefillRunner(EPRunner):
|
||||
def __init__(
|
||||
self,
|
||||
top_k: int,
|
||||
hidden: int,
|
||||
hidden_size: int,
|
||||
num_experts: int,
|
||||
splitwise_role: str,
|
||||
num_max_dispatch_tokens_per_rank: int,
|
||||
@@ -368,10 +504,12 @@ class EPPrefillRunner(EPRunner):
|
||||
ep_rank: int = 0,
|
||||
redundant_experts_num: int = 0,
|
||||
moe_phase: MoEPhase = MoEPhase("prefill"),
|
||||
ep_group=None,
|
||||
use_internode_ll_two_stage: bool = False,
|
||||
):
|
||||
super().__init__(
|
||||
top_k,
|
||||
hidden,
|
||||
hidden_size,
|
||||
num_experts,
|
||||
splitwise_role,
|
||||
moe_phase,
|
||||
@@ -379,6 +517,8 @@ class EPPrefillRunner(EPRunner):
|
||||
ep_size=ep_size,
|
||||
ep_rank=ep_rank,
|
||||
redundant_experts_num=redundant_experts_num,
|
||||
ep_group=ep_group,
|
||||
use_internode_ll_two_stage=use_internode_ll_two_stage,
|
||||
)
|
||||
|
||||
def dispatch(
|
||||
@@ -389,6 +529,9 @@ class EPPrefillRunner(EPRunner):
|
||||
*args,
|
||||
**kwargs,
|
||||
):
|
||||
buffer = self.ep_engine.deepep_engine
|
||||
if buffer is None:
|
||||
raise RuntimeError("DeepEP buffer not initialized!")
|
||||
|
||||
(
|
||||
num_tokens_per_rank,
|
||||
@@ -396,7 +539,7 @@ class EPPrefillRunner(EPRunner):
|
||||
num_tokens_per_expert,
|
||||
is_token_in_rank,
|
||||
_,
|
||||
) = self.ep_engine.deepep_engine.get_dispatch_layout(topk_idx, self.num_experts)
|
||||
) = buffer.get_dispatch_layout(topk_idx, self.num_experts)
|
||||
|
||||
x_scale_tensor = kwargs.get("x_scale_tensor", None)
|
||||
dispatch_args = {
|
||||
@@ -405,12 +548,12 @@ class EPPrefillRunner(EPRunner):
|
||||
"num_tokens_per_rdma_rank": num_tokens_per_rdma_rank,
|
||||
"is_token_in_rank": is_token_in_rank,
|
||||
"num_tokens_per_expert": num_tokens_per_expert,
|
||||
"config": self.ep_engine.ep_config,
|
||||
"config": self.ep_engine.ep_config, # assuming ep_config still in engine
|
||||
"async_finish": self.ep_engine.async_finish,
|
||||
"topk_idx": topk_idx,
|
||||
"topk_weights": topk_weights,
|
||||
}
|
||||
return self.ep_engine.deepep_engine.dispatch(**dispatch_args)
|
||||
return buffer.dispatch(**dispatch_args)
|
||||
|
||||
def combine(
|
||||
self,
|
||||
@@ -418,6 +561,10 @@ class EPPrefillRunner(EPRunner):
|
||||
handle: tuple,
|
||||
recv_topk_weights: paddle.Tensor,
|
||||
):
|
||||
buffer = self.ep_engine.deepep_engine
|
||||
if buffer is None:
|
||||
raise RuntimeError("DeepEP buffer not initialized!")
|
||||
|
||||
combine_args = {
|
||||
"x": tmp_ffn_out,
|
||||
"handle": handle,
|
||||
@@ -425,8 +572,7 @@ class EPPrefillRunner(EPRunner):
|
||||
"async_finish": self.ep_engine.async_finish,
|
||||
"topk_weights": recv_topk_weights,
|
||||
}
|
||||
fused_moe_out, _, _ = self.ep_engine.deepep_engine.combine(**combine_args)
|
||||
|
||||
fused_moe_out, _, _ = buffer.combine(**combine_args)
|
||||
return fused_moe_out
|
||||
|
||||
|
||||
@@ -438,18 +584,20 @@ class EPDecoderRunner(EPRunner):
|
||||
def __init__(
|
||||
self,
|
||||
top_k: int,
|
||||
hidden: int,
|
||||
hidden_size: int,
|
||||
num_experts: int,
|
||||
splitwise_role: str,
|
||||
num_max_dispatch_tokens_per_rank: int,
|
||||
ep_size: int = 1,
|
||||
ep_rank: int = 0,
|
||||
redundant_experts_num: int = 0,
|
||||
ep_group=None,
|
||||
moe_phase: MoEPhase = MoEPhase("decode"),
|
||||
use_internode_ll_two_stage: bool = False,
|
||||
):
|
||||
super().__init__(
|
||||
top_k,
|
||||
hidden,
|
||||
hidden_size,
|
||||
num_experts,
|
||||
splitwise_role,
|
||||
moe_phase,
|
||||
@@ -457,6 +605,8 @@ class EPDecoderRunner(EPRunner):
|
||||
ep_size=ep_size,
|
||||
ep_rank=ep_rank,
|
||||
redundant_experts_num=redundant_experts_num,
|
||||
ep_group=ep_group,
|
||||
use_internode_ll_two_stage=use_internode_ll_two_stage,
|
||||
)
|
||||
|
||||
def dispatch(
|
||||
@@ -470,18 +620,30 @@ class EPDecoderRunner(EPRunner):
|
||||
expertwise_scale = kwargs.get("expertwise_scale", None)
|
||||
use_fp8 = kwargs.get("use_fp8", False)
|
||||
|
||||
recv_hidden_states, recv_expert_count, handle, dispatch_hook = self.ep_engine.low_latency_dispatch(
|
||||
x, topk_idx, expertwise_scale, use_fp8
|
||||
)
|
||||
if not self.use_internode_ll_two_stage:
|
||||
recv_hidden_states, recv_expert_count, handle, dispatch_hook = self.ep_engine.low_latency_dispatch(
|
||||
x, topk_idx, expertwise_scale, use_fp8
|
||||
)
|
||||
else:
|
||||
# only dispatch_use_fp8 = True is supported for now
|
||||
assert use_fp8 is True
|
||||
recv_hidden_states, recv_expert_count, handle, dispatch_hook = (
|
||||
self.ep_engine.low_latency_dispatch_two_stage(x, topk_idx, topk_weights, expertwise_scale, use_fp8)
|
||||
)
|
||||
if dispatch_hook is not None:
|
||||
dispatch_hook()
|
||||
|
||||
return recv_hidden_states, recv_expert_count, handle
|
||||
|
||||
def combine(self, ffn_out, topk_idx, topk_weights, handle):
|
||||
combined_hidden_states, combine_hook = self.ep_engine.low_latency_combine(
|
||||
ffn_out, topk_idx, topk_weights, handle
|
||||
)
|
||||
if not self.use_internode_ll_two_stage:
|
||||
combined_hidden_states, combine_hook = self.ep_engine.low_latency_combine(
|
||||
ffn_out, topk_idx, topk_weights, handle
|
||||
)
|
||||
else:
|
||||
combined_hidden_states, combine_hook = self.ep_engine.low_latency_combine_two_stage(
|
||||
ffn_out, topk_idx, topk_weights, True, handle  # only dispatch_use_fp8 = True is supported for now
|
||||
)
|
||||
if combine_hook is not None:
|
||||
combine_hook()
|
||||
|
||||
|
@@ -40,62 +40,53 @@ class MoEMethodBase(QuantMethodBase):
|
||||
"down_proj_weight_scale",
|
||||
]
|
||||
self.pack_num = 1
|
||||
self.ep_prefill_runner = None
|
||||
self.ep_decoder_runner = None
|
||||
|
||||
def init_ep(self, layer: nn.Layer) -> None:
|
||||
"""
|
||||
Init EP related module
|
||||
Initialize EP (Expert Parallel) related modules.
|
||||
"""
|
||||
if layer.ep_size > 1:
|
||||
if layer.fd_config.parallel_config.splitwise_role == "mixed":
|
||||
from .ep import EPDecoderRunner, EPPrefillRunner
|
||||
if layer.ep_size <= 1:
|
||||
return
|
||||
|
||||
self.ep_prefill_runner = EPPrefillRunner(
|
||||
layer.top_k,
|
||||
layer.hidden_size,
|
||||
layer.num_experts,
|
||||
layer.fd_config.parallel_config.splitwise_role,
|
||||
layer.fd_config.model_config.num_max_dispatch_tokens_per_rank,
|
||||
layer.ep_size,
|
||||
layer.ep_rank,
|
||||
layer.fd_config.model_config.redundant_experts_num,
|
||||
)
|
||||
self.ep_decoder_runner = EPDecoderRunner(
|
||||
layer.top_k,
|
||||
layer.hidden_size,
|
||||
layer.num_experts,
|
||||
layer.fd_config.parallel_config.splitwise_role,
|
||||
layer.fd_config.model_config.num_max_dispatch_tokens_per_rank,
|
||||
layer.ep_size,
|
||||
layer.ep_rank,
|
||||
layer.fd_config.model_config.redundant_experts_num,
|
||||
)
|
||||
# Lazy import to avoid circular dependency or unnecessary loading
|
||||
from .ep import EPDecoderRunner, EPPrefillRunner
|
||||
|
||||
# Common arguments for both runners
|
||||
common_args = {
|
||||
"top_k": layer.top_k,
|
||||
"hidden_size": layer.hidden_size,
|
||||
"num_experts": layer.num_experts,
|
||||
"splitwise_role": layer.fd_config.parallel_config.splitwise_role,
|
||||
"num_max_dispatch_tokens_per_rank": layer.fd_config.model_config.num_max_dispatch_tokens_per_rank,
|
||||
"ep_size": layer.ep_size,
|
||||
"ep_rank": layer.ep_rank,
|
||||
"redundant_experts_num": layer.fd_config.model_config.redundant_experts_num,
|
||||
"ep_group": layer.fd_config.parallel_config.ep_group,
|
||||
"use_internode_ll_two_stage": layer.fd_config.parallel_config.use_internode_ll_two_stage,
|
||||
}
|
||||
|
||||
config = layer.fd_config
|
||||
splitwise_role = config.parallel_config.splitwise_role
|
||||
load_strategy = config.load_config.load_strategy
|
||||
|
||||
# For "mixed" splitwise role: conditionally initialize both or none
|
||||
if splitwise_role == "mixed":
|
||||
if load_strategy == "meta":
|
||||
# For RL, initialize the model without allocating the DeepEP buffer
|
||||
return
|
||||
else:
|
||||
if layer.fd_config.parallel_config.moe_phase.phase == "prefill":
|
||||
from .ep import EPPrefillRunner
|
||||
self.ep_prefill_runner = EPPrefillRunner(**common_args)
|
||||
self.ep_decoder_runner = EPDecoderRunner(**common_args)
|
||||
return
|
||||
|
||||
self.ep_prefill_runner = EPPrefillRunner(
|
||||
layer.top_k,
|
||||
layer.hidden_size,
|
||||
layer.num_experts,
|
||||
layer.fd_config.parallel_config.splitwise_role,
|
||||
layer.fd_config.model_config.num_max_dispatch_tokens_per_rank,
|
||||
layer.ep_size,
|
||||
layer.ep_rank,
|
||||
layer.fd_config.model_config.redundant_experts_num,
|
||||
)
|
||||
else:
|
||||
from .ep import EPDecoderRunner
|
||||
|
||||
self.ep_decoder_runner = EPDecoderRunner(
|
||||
layer.top_k,
|
||||
layer.hidden_size,
|
||||
layer.num_experts,
|
||||
layer.fd_config.parallel_config.splitwise_role,
|
||||
layer.fd_config.model_config.num_max_dispatch_tokens_per_rank,
|
||||
layer.ep_size,
|
||||
layer.ep_rank,
|
||||
layer.fd_config.model_config.redundant_experts_num,
|
||||
)
|
||||
# For non-mixed ep
|
||||
phase = config.parallel_config.moe_phase.phase
|
||||
if phase == "prefill":
|
||||
self.ep_prefill_runner = EPPrefillRunner(**common_args)
|
||||
else:
|
||||
self.ep_decoder_runner = EPDecoderRunner(**common_args)
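The branching above reduces to a small decision table: which runners to construct given the splitwise role, the load strategy and the MoE phase. A self-contained sketch of that decision (function and string values chosen for illustration only):

```python
def select_runners(splitwise_role: str, load_strategy: str, phase: str):
    """Return the set of EP runners a layer should construct."""
    if splitwise_role == "mixed":
        if load_strategy == "meta":
            return set()                       # RL meta-init: defer buffer allocation
        return {"prefill", "decoder"}          # mixed role builds both runners
    return {"prefill"} if phase == "prefill" else {"decoder"}


assert select_runners("mixed", "meta", "prefill") == set()
assert select_runners("mixed", "default", "decode") == {"prefill", "decoder"}
assert select_runners("prefill", "default", "prefill") == {"prefill"}
```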
|
||||
|
||||
def process_loaded_weights(self, layer, weights) -> None:
|
||||
"""
|
||||
|
@@ -27,11 +27,7 @@ from ..utils import get_tensor
|
||||
from .fused_moe_backend_base import UnquantizedFusedMoEMethod
|
||||
|
||||
if current_platform.is_cuda():
|
||||
from fastdeploy.model_executor.ops.gpu import (
|
||||
moe_expert_dispatch,
|
||||
moe_expert_reduce,
|
||||
noaux_tc,
|
||||
)
|
||||
from fastdeploy.model_executor.ops.gpu import moe_expert_dispatch, moe_expert_reduce
|
||||
|
||||
try:
|
||||
from fastdeploy.model_executor.ops.gpu import w4afp8_gemm_scale_permute
|
||||
@@ -46,31 +42,6 @@ elif current_platform.is_iluvatar():
|
||||
from fastdeploy.model_executor.utils import TensorTracker, free_tensor, set_weight_attrs
|
||||
|
||||
|
||||
# used for deepseek_v3
|
||||
def get_moe_scores(
|
||||
gating_output: paddle.Tensor,
|
||||
n_group,
|
||||
topk_group,
|
||||
top_k,
|
||||
routed_scaling_factor,
|
||||
e_score_correction_bias,
|
||||
) -> paddle.Tensor:
|
||||
"""
|
||||
compute moe scores using e_score_correction_bias.
|
||||
"""
|
||||
scores = paddle.nn.functional.sigmoid(gating_output)
|
||||
scores_with_bias = scores + e_score_correction_bias
|
||||
scores, topk_values, topk_idx = noaux_tc(
|
||||
scores,
|
||||
scores_with_bias,
|
||||
n_group,
|
||||
topk_group,
|
||||
top_k,
|
||||
routed_scaling_factor,
|
||||
)
|
||||
return scores, topk_values, topk_idx
|
||||
|
||||
|
||||
class CutlassMoEMethod(UnquantizedFusedMoEMethod):
|
||||
"""
|
||||
Use Cutlass Group Gemm to compute Fused MoE.
|
||||
@@ -154,7 +125,7 @@ class CutlassMoEMethod(UnquantizedFusedMoEMethod):
|
||||
|
||||
# 3. Compute ffn
|
||||
if token_all_num > 0:
|
||||
logger.info(f"token_all_num {token_all_num}")
|
||||
logger.debug(f"token_all_num {token_all_num}")
|
||||
(
|
||||
permute_input,
|
||||
permute_indices_per_token,
|
||||
@@ -255,6 +226,8 @@ class CutlassMoEMethod(UnquantizedFusedMoEMethod):
|
||||
"""
|
||||
gate_out = gate(x.cast("float32"))
|
||||
if layer.topk_method == "noaux_tc":
|
||||
from .moe import get_moe_scores
|
||||
|
||||
gate_out, _, _ = get_moe_scores(
|
||||
gate_out,
|
||||
layer.n_group,
|
||||
|
@@ -319,7 +319,7 @@ class DeepGemmFusedMoeMethod(MoEMethodBase):
|
||||
|
||||
# 4. Compute ffn
|
||||
if token_all_num > 0:
|
||||
logger.info(f"token_all_num {token_all_num}")
|
||||
logger.debug(f"token_all_num {token_all_num}")
|
||||
(recv_x, recv_x_scale) = recv_x
|
||||
|
||||
token_nums_this_rank = count_tokens_per_expert_func(recv_topk_idx, layer.num_local_experts)
|
||||
@@ -481,7 +481,7 @@ class DeepGemmFusedMoeMethod(MoEMethodBase):
|
||||
gate_out = gate(x.cast("float32"))
|
||||
|
||||
if layer.topk_method == "noaux_tc":
|
||||
from .ep import get_moe_scores
|
||||
from .moe import get_moe_scores
|
||||
|
||||
_, topk_weights, topk_ids = get_moe_scores(
|
||||
gate_out,
|
||||
|
@@ -21,37 +21,12 @@ import fastdeploy
|
||||
from fastdeploy.distributed.communication import tensor_model_parallel_all_reduce
|
||||
from fastdeploy.model_executor.ops.gpu import (
|
||||
MoeWna16MarlinGemmApi,
|
||||
noaux_tc,
|
||||
tritonmoe_preprocess_func,
|
||||
)
|
||||
|
||||
from ..quantization.quant_base import QuantMethodBase
|
||||
|
||||
|
||||
def get_moe_scores(
|
||||
gating_output: paddle.Tensor,
|
||||
n_group,
|
||||
topk_group,
|
||||
top_k,
|
||||
routed_scaling_factor,
|
||||
e_score_correction_bias,
|
||||
) -> paddle.Tensor:
|
||||
"""
|
||||
compute moe scores using e_score_correction_bias.
|
||||
"""
|
||||
scores = paddle.nn.functional.sigmoid(gating_output)
|
||||
scores_with_bias = scores + e_score_correction_bias.unsqueeze(0)
|
||||
scores, topk_values, topk_idx = noaux_tc(
|
||||
scores,
|
||||
scores_with_bias,
|
||||
n_group,
|
||||
topk_group,
|
||||
top_k,
|
||||
routed_scaling_factor,
|
||||
)
|
||||
return scores, topk_values, topk_idx
|
||||
|
||||
|
||||
def gptq_marlin_moe_repack(
|
||||
b_q_weight: paddle.Tensor,
|
||||
perm: paddle.Tensor,
|
||||
@@ -279,6 +254,8 @@ class MarlinWeightOnlyMoEMethod(QuantMethodBase):
|
||||
topk_method = layer.topk_method
|
||||
|
||||
if topk_method == "noaux_tc":
|
||||
from .moe import get_moe_scores
|
||||
|
||||
gate_out, _, _ = get_moe_scores(
|
||||
gate_out,
|
||||
layer.n_group,
|
||||
|
@@ -24,6 +24,7 @@ from fastdeploy.model_executor.utils import TensorTracker, set_weight_attrs
|
||||
from fastdeploy.utils import ceil_div
|
||||
|
||||
from ..quantization.quant_base import QuantMethodBase
|
||||
from .moe import get_moe_scores
|
||||
|
||||
try:
|
||||
from fastdeploy.model_executor.ops.gpu import tritonmoe_preprocess_func
|
||||
@@ -167,13 +168,24 @@ class TritonWeightOnlyMoEMethod(QuantMethodBase):
|
||||
moe_intermediate_size = layer.moe_intermediate_size
|
||||
hidden_size = layer.hidden_size
|
||||
|
||||
topk_ids, topk_weights = fastdeploy.model_executor.ops.gpu.moe_topk_select(
|
||||
gate_out,
|
||||
layer.gate_correction_bias,
|
||||
top_k,
|
||||
True, # apply_norm_weight,
|
||||
False,
|
||||
)
|
||||
if layer.topk_method == "noaux_tc":
|
||||
_, topk_weights, topk_ids = get_moe_scores(
|
||||
gate_out,
|
||||
layer.n_group,
|
||||
layer.topk_group,
|
||||
layer.top_k,
|
||||
layer.routed_scaling_factor,
|
||||
layer.gate_correction_bias,
|
||||
)
|
||||
else:
|
||||
topk_ids, topk_weights = fastdeploy.model_executor.ops.gpu.moe_topk_select(
|
||||
gate_out,
|
||||
layer.gate_correction_bias,
|
||||
layer.top_k,
|
||||
True, # apply_norm_weight
|
||||
False,
|
||||
)
|
||||
|
||||
up_gate_proj_out = paddle.empty(
|
||||
[token_num * top_k, moe_intermediate_size * 2],
|
||||
dtype=x.dtype,
|
||||
@@ -419,13 +431,25 @@ class TensorWiseFP8MoEMethod(QuantMethodBase):
|
||||
moe_intermediate_size = layer.moe_intermediate_size
|
||||
hidden_size = layer.hidden_size
|
||||
|
||||
topk_ids, topk_weights = fastdeploy.model_executor.ops.gpu.moe_topk_select(
|
||||
gate_out,
|
||||
layer.gate_correction_bias,
|
||||
top_k,
|
||||
True, # apply_norm_weight,
|
||||
False,
|
||||
)
|
||||
if layer.topk_method == "noaux_tc":
|
||||
|
||||
_, topk_weights, topk_ids = get_moe_scores(
|
||||
gate_out,
|
||||
layer.n_group,
|
||||
layer.topk_group,
|
||||
layer.top_k,
|
||||
layer.routed_scaling_factor,
|
||||
layer.gate_correction_bias,
|
||||
)
|
||||
else:
|
||||
|
||||
topk_ids, topk_weights = fastdeploy.model_executor.ops.gpu.moe_topk_select(
|
||||
gate_out,
|
||||
layer.gate_correction_bias,
|
||||
top_k,
|
||||
True, # apply_norm_weight,
|
||||
False,
|
||||
)
|
||||
|
||||
up_gate_proj_out = paddle.empty(
|
||||
[token_num * top_k, moe_intermediate_size * 2],
|
||||
@@ -671,7 +695,7 @@ class BlockWiseFP8MoEMethod(QuantMethodBase):
|
||||
layer,
|
||||
down_proj_weight_name,
|
||||
layer.create_parameter(
|
||||
shape=self.up_gate_proj_weight_shape,
|
||||
shape=self.down_proj_weight_shape,
|
||||
dtype=self.weight_dtype,
|
||||
default_initializer=paddle.nn.initializer.Constant(0),
|
||||
),
|
||||
@@ -829,13 +853,23 @@ class BlockWiseFP8MoEMethod(QuantMethodBase):
|
||||
E, N1, _ = getattr(layer, self.added_weight_attrs[0]).shape
|
||||
N2 = getattr(layer, self.added_weight_attrs[1]).shape[1]
|
||||
|
||||
topk_ids, topk_weights = fastdeploy.model_executor.ops.gpu.moe_topk_select(
|
||||
gate_out,
|
||||
layer.gate_correction_bias,
|
||||
layer.top_k,
|
||||
True, # apply_norm_weight
|
||||
False,
|
||||
)
|
||||
if layer.topk_method == "noaux_tc":
|
||||
_, topk_weights, topk_ids = get_moe_scores(
|
||||
gate_out,
|
||||
layer.n_group,
|
||||
layer.topk_group,
|
||||
layer.top_k,
|
||||
layer.routed_scaling_factor,
|
||||
layer.gate_correction_bias,
|
||||
)
|
||||
else:
|
||||
topk_ids, topk_weights = fastdeploy.model_executor.ops.gpu.moe_topk_select(
|
||||
gate_out,
|
||||
layer.gate_correction_bias,
|
||||
layer.top_k,
|
||||
True, # apply_norm_weight
|
||||
False,
|
||||
)
|
||||
|
||||
config = {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
|
@@ -27,6 +27,11 @@ from fastdeploy.model_executor.utils import slice_fn
|
||||
from fastdeploy.platforms import current_platform
|
||||
from fastdeploy.worker.experts_manager import RedundantExpertManger
|
||||
|
||||
try:
|
||||
from fastdeploy.model_executor.ops.gpu import noaux_tc
|
||||
except:
|
||||
logger.warning("import noaux_tc Failed!")
|
||||
|
||||
|
||||
def get_moe_method():
|
||||
"""
|
||||
@@ -54,6 +59,31 @@ def get_moe_method():
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
def get_moe_scores(
|
||||
gating_output: paddle.Tensor,
|
||||
n_group,
|
||||
topk_group,
|
||||
top_k,
|
||||
routed_scaling_factor,
|
||||
e_score_correction_bias,
|
||||
) -> paddle.Tensor:
|
||||
"""
|
||||
compute moe scores using e_score_correction_bias.
|
||||
"""
|
||||
scores = paddle.nn.functional.sigmoid(gating_output)
|
||||
assert e_score_correction_bias is not None, "e_score_correction_bias is none!"
|
||||
scores_with_bias = scores + e_score_correction_bias
|
||||
scores, topk_values, topk_idx = noaux_tc(
|
||||
scores,
|
||||
scores_with_bias,
|
||||
n_group if n_group > 0 else 1,
|
||||
topk_group if topk_group > 0 else 1,
|
||||
top_k,
|
||||
routed_scaling_factor,
|
||||
)
|
||||
return scores, topk_values, topk_idx
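As a rough reference for what the fused noaux_tc kernel computes, the NumPy sketch below reproduces the bias-corrected selection idea in its simplest form: routing is decided on sigmoid(logits) plus the correction bias, while the returned weights come from the unbiased scores. It deliberately ignores the group-limited (n_group/topk_group) restriction and the routed scaling factor, so treat it as an illustration, not a drop-in equivalent.

```python
import numpy as np


def moe_scores_reference(gating_logits, e_score_correction_bias, top_k):
    scores = 1.0 / (1.0 + np.exp(-gating_logits))          # sigmoid gate scores
    scores_with_bias = scores + e_score_correction_bias    # bias only steers selection
    topk_idx = np.argsort(-scores_with_bias, axis=-1)[:, :top_k]
    topk_values = np.take_along_axis(scores, topk_idx, axis=-1)
    return scores, topk_values, topk_idx


logits = np.random.randn(2, 8).astype("float32")
bias = np.zeros(8, dtype="float32")
print(moe_scores_reference(logits, bias, top_k=2)[2])
```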
|
||||
|
||||
|
||||
class FusedMoE(nn.Layer):
|
||||
"""
|
||||
FusedMoE is a layer that performs MoE (Mixture of Experts) computation.
|
||||
@@ -176,17 +206,24 @@ class FusedMoE(nn.Layer):
|
||||
|
||||
if shard_id is None:
|
||||
# 1.gate up fused in disk
|
||||
model_format = getattr(param, "model_format", "")
|
||||
is_torch_model = model_format == "torch"
|
||||
output_size = param[expert_id - self.expert_id_offset].shape[SHARD_ID_TO_SHARDED_DIM["gate"]]
|
||||
shard_offsets = [
|
||||
# (shard_id, shard_offset, shard_size)
|
||||
("gate", 0, output_size // 2 * self.tp_size),
|
||||
("up", output_size // 2 * self.tp_size, output_size // 2 * self.tp_size),
|
||||
]
|
||||
for shard_id, shard_offset, shard_size in shard_offsets:
|
||||
loaded_weight_shard = slice_fn(
|
||||
loaded_weight, SHARD_ID_TO_SHARDED_DIM[shard_id], shard_offset, shard_offset + shard_size
|
||||
)
|
||||
self.weight_loader(param, loaded_weight_shard, expert_id, shard_id)
|
||||
per_rank = output_size // 2
|
||||
start = self.tp_rank * per_rank
|
||||
loaded_weight_shard_gate = slice_fn(
|
||||
loaded_weight, is_torch_model ^ SHARD_ID_TO_SHARDED_DIM["gate"], start, start + per_rank
|
||||
)
|
||||
self._load_gate_up_weight(
|
||||
param, expert_id, loaded_weight_shard_gate, "gate", SHARD_ID_TO_SHARDED_DIM["gate"], is_sharded=True
|
||||
)
|
||||
start_up = output_size // 2 * self.tp_size + self.tp_rank * per_rank
|
||||
loaded_weight_shard_up = slice_fn(
|
||||
loaded_weight, is_torch_model ^ SHARD_ID_TO_SHARDED_DIM["up"], start_up, start_up + per_rank
|
||||
)
|
||||
self._load_gate_up_weight(
|
||||
param, expert_id, loaded_weight_shard_up, "up", SHARD_ID_TO_SHARDED_DIM["up"], is_sharded=True
|
||||
)
|
||||
else:
|
||||
# 2.gate up splited in disk
|
||||
assert shard_id in ["gate", "down", "up"]
|
||||
@@ -198,22 +235,23 @@ class FusedMoE(nn.Layer):
|
||||
shard_dim=SHARD_ID_TO_SHARDED_DIM[shard_id],
|
||||
)
|
||||
|
||||
def _load_gate_up_weight(self, param, expert_id, loaded_weight, shard_id, shard_dim=None):
|
||||
def _load_gate_up_weight(self, param, expert_id, loaded_weight, shard_id, shard_dim=None, is_sharded=False):
|
||||
model_format = getattr(param, "model_format", "")
|
||||
if model_format == "torch":
|
||||
loaded_weight = loaded_weight.transpose([1, 0])
|
||||
dim = -1 if shard_dim else 0
|
||||
if self.tp_size > 1:
|
||||
is_torch_model = model_format == "torch"
|
||||
if self.tp_size > 1 and not is_sharded:
|
||||
tp_shard_dim = is_torch_model ^ shard_dim
|
||||
weight_dim = -1 if tp_shard_dim else 0
|
||||
if isinstance(loaded_weight, (np.ndarray, paddle.Tensor)):
|
||||
size = loaded_weight.shape[dim]
|
||||
size = loaded_weight.shape[weight_dim]
|
||||
else:
|
||||
size = loaded_weight.get_shape()[dim]
|
||||
size = loaded_weight.get_shape()[weight_dim]
|
||||
block_size = size // self.tp_size
|
||||
shard_offset = self.tp_rank * block_size
|
||||
shard_size = (self.tp_rank + 1) * block_size
|
||||
loaded_weight = slice_fn(loaded_weight, shard_dim, shard_offset, shard_size)
|
||||
|
||||
loaded_weight = slice_fn(loaded_weight, tp_shard_dim, shard_offset, shard_size)
|
||||
loaded_weight = get_tensor(loaded_weight)
|
||||
expert_param = param[expert_id - self.expert_id_offset]
|
||||
dim = -1 if shard_dim else 0
|
||||
param_shard_size = expert_param.shape[dim] // 2
|
||||
if shard_id == "gate":
|
||||
param_shard_offset = 0
|
||||
@@ -232,9 +270,8 @@ class FusedMoE(nn.Layer):
|
||||
)
|
||||
|
||||
# To ensure compatibility across backends, apply an extra transpose for GCU and XPU
|
||||
if current_platform.is_xpu() or current_platform.is_gcu():
|
||||
if expert_param.shape != loaded_weight.shape:
|
||||
loaded_weight = loaded_weight.transpose([1, 0])
|
||||
if expert_param.shape != loaded_weight.shape:
|
||||
loaded_weight = loaded_weight.transpose([1, 0])
|
||||
assert expert_param.shape == loaded_weight.shape, (
|
||||
f"Attempted to load weight ({loaded_weight.shape}) " f"into parameter ({expert_param.shape})"
|
||||
)
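The `is_torch_model ^ shard_dim` trick used above flips the axis that tensor parallelism slices when the checkpoint stores weights in the transposed (torch) layout. A tiny standalone check of that XOR, using made-up shapes:

```python
import numpy as np


def tp_slice(weight, shard_dim: int, is_torch_model: bool, tp_rank: int, tp_size: int):
    """Slice the TP shard along shard_dim, flipping the axis for transposed checkpoints."""
    axis = -1 if (int(is_torch_model) ^ shard_dim) else 0
    block = weight.shape[axis] // tp_size
    index = [slice(None)] * weight.ndim
    index[axis] = slice(tp_rank * block, (tp_rank + 1) * block)
    return weight[tuple(index)]


w_native = np.zeros((1024, 512))   # [in, out] layout, shard_dim=1 -> slice the last axis
w_torch = np.zeros((512, 1024))    # transposed layout, same shard -> slice the first axis
assert tp_slice(w_native, 1, False, tp_rank=0, tp_size=4).shape == (1024, 128)
assert tp_slice(w_torch, 1, True, tp_rank=0, tp_size=4).shape == (128, 1024)
```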
|
||||
@@ -242,26 +279,26 @@ class FusedMoE(nn.Layer):
|
||||
|
||||
def _load_down_weight(self, param, expert_id, loaded_weight, shard_id, shard_dim=None):
|
||||
model_format = getattr(param, "model_format", "")
|
||||
if model_format == "torch":
|
||||
loaded_weight = loaded_weight.transpose([1, 0])
|
||||
is_torch_model = model_format == "torch"
|
||||
if self.tp_size > 1 and shard_dim is not None:
|
||||
dim = -1 if shard_dim else 0
|
||||
if isinstance(loaded_weight, (np.ndarray, paddle.Tensor)):
|
||||
tp_shard_dim = is_torch_model ^ shard_dim
|
||||
dim = -1 if tp_shard_dim else 0
|
||||
if isinstance(loaded_weight, paddle.Tensor):
|
||||
size = loaded_weight.shape[dim]
|
||||
else:
|
||||
size = loaded_weight.get_shape()[dim]
|
||||
block_size = size // self.tp_size
|
||||
shard_offset = self.tp_rank * block_size
|
||||
shard_size = (self.tp_rank + 1) * block_size
|
||||
loaded_weight = slice_fn(loaded_weight, shard_dim, shard_offset, shard_size)
|
||||
loaded_weight = slice_fn(loaded_weight, tp_shard_dim, shard_offset, shard_size)
|
||||
loaded_weight = get_tensor(loaded_weight)
|
||||
expert_param = param[expert_id - self.expert_id_offset]
|
||||
if hasattr(param, "tensor_track"):
|
||||
# for dyn quant
|
||||
param.tensor_track.mark(start=0, batch_id=expert_id - self.expert_id_offset)
|
||||
# To ensure compatibility across backends, apply an extra transpose for GCU and XPU
|
||||
if current_platform.is_xpu or current_platform.is_gcu():
|
||||
if expert_param.shape != loaded_weight.shape:
|
||||
loaded_weight = loaded_weight.transpose([1, 0])
|
||||
# To ensure compatibility across backends, apply an extra transpose for GCU, XPU and open-source weights
|
||||
if expert_param.shape != loaded_weight.shape:
|
||||
loaded_weight = loaded_weight.transpose([1, 0])
|
||||
assert expert_param.shape == loaded_weight.shape, (
|
||||
f"Attempted to load weight ({loaded_weight.shape}) " f"into parameter ({expert_param.shape})"
|
||||
)
|
||||
|
@@ -18,7 +18,7 @@ import paddle
|
||||
from paddle import nn
|
||||
from paddle.distributed import fleet
|
||||
|
||||
from fastdeploy.model_executor.utils import set_weight_attrs
|
||||
from fastdeploy.model_executor.utils import default_weight_loader, set_weight_attrs
|
||||
|
||||
from .utils import get_tensor
|
||||
|
||||
@@ -53,44 +53,61 @@ class ParallelEHProjection(nn.Layer):
|
||||
self.bias_key = prefix + ".bias"
|
||||
else:
|
||||
self.bias_key = None
|
||||
self.use_ep = fd_config.parallel_config.use_ep
|
||||
self.fd_config = fd_config
|
||||
self.tp_group = fd_config.parallel_config.tp_group
|
||||
self.column_cut = True
|
||||
self.nranks = fd_config.parallel_config.tensor_parallel_size
|
||||
|
||||
ColumnParallelLinear = fleet.meta_parallel.ColumnParallelLinear
|
||||
RowParallelLinear = fleet.meta_parallel.RowParallelLinear
|
||||
|
||||
if self.use_ep:
|
||||
self.weight = self.create_parameter(
|
||||
shape=[embedding_dim, num_embeddings],
|
||||
dtype=paddle.get_default_dtype(),
|
||||
is_bias=False,
|
||||
if self.column_cut:
|
||||
need_gather = True
|
||||
self.linear = ColumnParallelLinear(
|
||||
embedding_dim,
|
||||
num_embeddings,
|
||||
mp_group=self.tp_group,
|
||||
weight_attr=None,
|
||||
has_bias=True if self.bias_key is not None else False,
|
||||
gather_output=need_gather,
|
||||
fuse_matmul_bias=False,  # False gives a smaller numerical diff
|
||||
)
|
||||
else:
|
||||
if self.column_cut:
|
||||
need_gather = True
|
||||
self.linear = ColumnParallelLinear(
|
||||
embedding_dim,
|
||||
num_embeddings,
|
||||
mp_group=fleet.get_hybrid_communicate_group().get_model_parallel_group(),
|
||||
weight_attr=None,
|
||||
has_bias=True if self.bias_key is not None else False,
|
||||
gather_output=need_gather,
|
||||
fuse_matmul_bias=False,  # False gives a smaller numerical diff
|
||||
set_weight_attrs(
|
||||
self.linear.weight,
|
||||
{
|
||||
"weight_loader": default_weight_loader(self.fd_config),
|
||||
"model_format": self.fd_config.model_config.model_format,
|
||||
},
|
||||
)
|
||||
if self.bias_key is not None:
|
||||
set_weight_attrs(
|
||||
self.linear.bias,
|
||||
{"rl_need_attr": {"rl_tp_degree": fd_config.parallel_config.tensor_parallel_size}},
|
||||
)
|
||||
if self.nranks > 1:
|
||||
set_weight_attrs(self.linear.weight, {"output_dim": True})
|
||||
if self.bias_key is not None:
|
||||
set_weight_attrs(self.linear.bias, {"output_dim": True})
|
||||
else:
|
||||
self.linear = RowParallelLinear(
|
||||
embedding_dim,
|
||||
num_embeddings,
|
||||
mp_group=fleet.get_hybrid_communicate_group().get_model_parallel_group(),
|
||||
weight_attr=None,
|
||||
has_bias=True if self.bias_key is not None else False,
|
||||
input_is_parallel=False,
|
||||
fuse_matmul_bias=False,  # False gives a smaller numerical diff
|
||||
)
|
||||
set_weight_attrs(self.linear.weight, {"output_dim": False})
|
||||
else:
|
||||
self.linear = RowParallelLinear(
|
||||
embedding_dim,
|
||||
num_embeddings,
|
||||
mp_group=self.tp_group,
|
||||
weight_attr=None,
|
||||
has_bias=True if self.bias_key is not None else False,
|
||||
input_is_parallel=False,
|
||||
fuse_matmul_bias=False,  # False gives a smaller numerical diff
|
||||
)
|
||||
set_weight_attrs(
|
||||
self.linear.weight,
|
||||
{
|
||||
"weight_loader": default_weight_loader(self.fd_config),
|
||||
"model_format": self.fd_config.model_config.model_format,
|
||||
},
|
||||
)
|
||||
if self.nranks > 1:
|
||||
set_weight_attrs(self.linear.weight, {"output_dim": True})
|
||||
set_weight_attrs(
|
||||
self.linear.weight, {"rl_need_attr": {"rl_tp_degree": fd_config.parallel_config.tensor_parallel_size}}
|
||||
)
|
||||
|
||||
def load_state_dict(self, state_dict):
|
||||
"""
|
||||
@@ -100,17 +117,14 @@ class ParallelEHProjection(nn.Layer):
|
||||
state_dict (dict): A dictionary containing the checkpoint weights and biases.
|
||||
"""
|
||||
|
||||
if self.use_ep:
|
||||
self.weight.set_value(get_tensor(state_dict.pop(self.weight_key)).astype(paddle.get_default_dtype()))
|
||||
else:
|
||||
weight_tensor = get_tensor(state_dict.pop(self.weight_key)).astype(paddle.get_default_dtype())
|
||||
if self.linear.weight.shape != weight_tensor.shape:
|
||||
weight_tensor = weight_tensor.transpose([1, 0])
|
||||
self.linear.weight.set_value(weight_tensor)
|
||||
weight_tensor = get_tensor(state_dict.pop(self.weight_key)).astype(paddle.get_default_dtype())
|
||||
if self.linear.weight.shape != weight_tensor.shape:
|
||||
weight_tensor = weight_tensor.transpose([1, 0])
|
||||
self.linear.weight.set_value(weight_tensor)
|
||||
|
||||
if self.bias_key is not None:
|
||||
bias = get_tensor(state_dict.pop(self.bias_key)).astype(paddle.get_default_dtype())
|
||||
self.linear.bias.set_value(bias)
|
||||
if self.bias_key is not None:
|
||||
bias = get_tensor(state_dict.pop(self.bias_key)).astype(paddle.get_default_dtype())
|
||||
self.linear.bias.set_value(bias)
|
||||
|
||||
def forward(self, input):
|
||||
"""
|
||||
@@ -123,8 +137,5 @@ class ParallelEHProjection(nn.Layer):
|
||||
Tensor: The output tensor after processing through the layer.
|
||||
"""
|
||||
logits = input
|
||||
if self.use_ep:
|
||||
logits = paddle.matmul(logits, self.weight)
|
||||
else:
|
||||
logits = self.linear(logits)
|
||||
logits = self.linear(logits)
|
||||
return logits
|
||||
|
@@ -33,6 +33,7 @@ class KvCacheQuantzationTypes(str, Enum):
|
||||
|
||||
INT8 = "int8"
|
||||
FP8 = "float8_e4m3fn"
|
||||
BLOCK_WISE_FP8 = "block_wise_fp8"
|
||||
INT8_ZP = "int8_zp"
|
||||
INT4_ZP = "int4_zp"
|
||||
FP8_ZP = "float8_e4m3fn_zp"
|
||||
@@ -62,7 +63,11 @@ class KvCacheQuantConfig(QuantConfigBase):
|
||||
|
||||
if self.quant_type == KvCacheQuantzationTypes.INT8 or self.quant_type == KvCacheQuantzationTypes.INT8_ZP:
|
||||
self.max_bound = 127.0
|
||||
elif self.quant_type == KvCacheQuantzationTypes.FP8 or self.quant_type == KvCacheQuantzationTypes.FP8_ZP:
|
||||
elif (
|
||||
self.quant_type == KvCacheQuantzationTypes.FP8
|
||||
or self.quant_type == KvCacheQuantzationTypes.FP8_ZP
|
||||
or self.quant_type == KvCacheQuantzationTypes.BLOCK_WISE_FP8
|
||||
):
|
||||
self.max_bound = 448.0
|
||||
elif self.quant_type == KvCacheQuantzationTypes.INT4_ZP:
|
||||
self.max_bound = 7.0
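For quick reference, the bound selection above amounts to a lookup table; a compact sketch with the values copied from the branches in this file, including the newly added block-wise FP8 entry:

```python
KV_CACHE_MAX_BOUND = {
    "int8": 127.0,
    "int8_zp": 127.0,
    "float8_e4m3fn": 448.0,
    "float8_e4m3fn_zp": 448.0,
    "block_wise_fp8": 448.0,   # block-wise FP8 shares the FP8 bound
    "int4_zp": 7.0,
}

assert KV_CACHE_MAX_BOUND["block_wise_fp8"] == 448.0
```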
|
||||
@@ -178,12 +183,17 @@ class KVCacheMethodBase(QuantMethodBase):
|
||||
layer.cache_quant_type_str = "cache_int4_zp"
|
||||
layer.quant_max_bound = 7.0
|
||||
layer.quant_min_bound = -7.0
|
||||
elif self.cache_quant_config.quant_type == KvCacheQuantzationTypes.BLOCK_WISE_FP8:
|
||||
layer.cache_quant_type_str = "block_wise_fp8"
|
||||
layer.quant_max_bound = 448.0
|
||||
layer.quant_min_bound = -448.0
|
||||
else:
|
||||
raise NotImplementedError(f"{self.cache_quant_config.quant_type} is not implemented")
|
||||
|
||||
self.load_scale(layer, state_dict)
|
||||
if self.cache_quant_config.has_zero_point:
|
||||
self.load_zp(layer, state_dict)
|
||||
if "block_wise" not in layer.cache_quant_type_str:
|
||||
self.load_scale(layer, state_dict)
|
||||
if self.cache_quant_config.has_zero_point:
|
||||
self.load_zp(layer, state_dict)
|
||||
|
||||
def apply(self, layer):
|
||||
"""
|
||||
|
@@ -193,14 +193,6 @@ def create_hadamard_matrix(hidden_size: int) -> paddle.Tensor:
|
||||
return hadamard_matrix
create_hadamard_matrix_map = {}
|
||||
# Zkk: below key are used in 4.5T fp8.
|
||||
create_hadamard_matrix_map[8192] = create_hadamard_matrix(8192)
|
||||
create_hadamard_matrix_map[448] = create_hadamard_matrix(448)
|
||||
create_hadamard_matrix_map[1024] = create_hadamard_matrix(1024)
|
||||
create_hadamard_matrix_map[3584] = create_hadamard_matrix(3584)
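The eager map of precomputed Hadamard matrices for a few fixed hidden sizes is removed here. One possible replacement pattern (purely a guess for illustration, not necessarily what this PR does) is to build and cache each matrix lazily on first use, reusing the create_hadamard_matrix defined earlier in this file:

```python
from functools import lru_cache


@lru_cache(maxsize=None)
def get_hadamard_matrix(hidden_size: int):
    # Hypothetical lazy cache: the matrix is built once per requested size
    # instead of eagerly at import time for a hard-coded list of sizes.
    return create_hadamard_matrix(hidden_size)
```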
|
||||
|
||||
|
||||
def ensure_divisibility(numerator, denominator):
|
||||
"""
|
||||
Ensure the numerator is divisible by the denominator.
|
||||
|
@@ -29,7 +29,6 @@ from safetensors import safe_open
|
||||
from tqdm import tqdm
|
||||
|
||||
from fastdeploy.config import FDConfig
|
||||
from fastdeploy.model_executor.layers.utils import get_tensor
|
||||
from fastdeploy.model_executor.models.tp_utils import (
|
||||
check_tensor_parallel_prerequisites,
|
||||
)
|
||||
@@ -61,7 +60,7 @@ def load_reordered_experts(model_path: str, key_name: str):
|
||||
return weight
|
||||
|
||||
|
||||
def load_ep_checkpoint(model_path: str, fd_config: FDConfig, return_numpy: bool = False):
|
||||
def load_ep_checkpoint(cls: PretrainedModel, model_path: str, fd_config: FDConfig, return_numpy: bool = False):
|
||||
"""
|
||||
load ep checkpoint
|
||||
"""
|
||||
@@ -139,6 +138,10 @@ def load_ep_checkpoint(model_path: str, fd_config: FDConfig, return_numpy: bool
|
||||
if k in weight_list:
|
||||
filtered_map[k] = weight_list[k]
|
||||
|
||||
if fd_config.parallel_config.tensor_parallel_size > 1:
|
||||
tp_actions = cls._get_tensor_parallel_mappings(fd_config.model_config.pretrained_config)
|
||||
new_actions = {k: v for k, v in tp_actions.items() if k not in num_local_ffn_keys}
|
||||
|
||||
state_dict = {}
|
||||
# Get all safetensor file paths that need to be opened
|
||||
safetensor_paths = set(filtered_map.values())
|
||||
@@ -154,6 +157,9 @@ def load_ep_checkpoint(model_path: str, fd_config: FDConfig, return_numpy: bool
|
||||
for k in filtered_map:
|
||||
if filtered_map[k] == safetensor_path and k in f.keys():
|
||||
weight = f.get_tensor(k)
|
||||
if fd_config.parallel_config.tensor_parallel_size > 1:
|
||||
if k in new_actions:
|
||||
weight = new_actions[k](weight)
|
||||
if not return_numpy:
|
||||
weight = paddle.Tensor(weight, zero_copy=True)
|
||||
weight = weight._copy_to(paddle.framework._current_expected_place(), False)
|
||||
@@ -186,8 +192,7 @@ def fast_weights_iterator(safe_tensor_list: list[str]):
|
||||
with fast_safe_open(st_file, framework="np") as f:
|
||||
for name in f.keys():
|
||||
param_slice = f.get_slice(name)
|
||||
paddle_tensor = get_tensor(param_slice)
|
||||
yield name, paddle_tensor
|
||||
yield name, param_slice
|
||||
|
||||
|
||||
def fastsafetensors_weights_iterator(
|
||||
@@ -326,12 +331,8 @@ def load_composite_checkpoint(
|
||||
# 3. Pre-sharded (pre-split)
|
||||
"""
|
||||
# (TODO: remove in the future)
|
||||
if (
|
||||
fd_config.parallel_config.use_ep
|
||||
and fd_config.speculative_config.model_type != "mtp"
|
||||
and fd_config.parallel_config.tensor_parallel_size == 1
|
||||
):
|
||||
state_dict = load_ep_checkpoint(model_path, fd_config, return_numpy=True)
|
||||
if fd_config.parallel_config.use_ep and fd_config.speculative_config.model_type != "mtp":
|
||||
state_dict = load_ep_checkpoint(cls, model_path, fd_config, return_numpy=True)
|
||||
else:
|
||||
rank_dirs = [
|
||||
f for f in os.listdir(model_path) if f.startswith("rank") and os.path.isdir(os.path.join(model_path, f))
|
||||
|
@@ -71,6 +71,11 @@ class DefaultModelLoader(BaseModelLoader):
|
||||
# register rl model
|
||||
import fastdeploy.rl # noqa
|
||||
|
||||
if fd_config.speculative_config.model_type != "mtp":
|
||||
architectures = architectures.replace("Ernie5ForCausalLM", "Ernie5MoeForCausalLM")
|
||||
else:
|
||||
architectures = architectures.replace("Ernie5ForCausalLM", "Ernie5MTPForCausalLM")
|
||||
|
||||
architectures = architectures + "RL"
|
||||
context = paddle.LazyGuard()
|
||||
else:
|
||||
|
@@ -59,6 +59,11 @@ class DefaultModelLoaderV1(BaseModelLoader):
|
||||
# register rl model
|
||||
import fastdeploy.rl # noqa
|
||||
|
||||
if fd_config.speculative_config.model_type != "mtp":
|
||||
architectures = architectures.replace("Ernie5ForCausalLM", "Ernie5MoeForCausalLM")
|
||||
else:
|
||||
architectures = architectures.replace("Ernie5ForCausalLM", "Ernie5MTPForCausalLM")
|
||||
|
||||
architectures = architectures + "RL"
|
||||
|
||||
with context:
|
||||
|
@@ -637,6 +637,19 @@ class Ernie4_5_ForCausalLM(Ernie4_5_MoeForCausalLM):
         return "Ernie4_5_ForCausalLM"


+class Ernie4_5ForCausalLM(Ernie4_5_ForCausalLM):
+    """
+    Ernie4_5ForCausalLM 0.3B-PT
+    """
+
+    @classmethod
+    def name(self):
+        """
+        Model Architecture Name
+        """
+        return "Ernie4_5ForCausalLM"
+
+
 class Ernie4_5_MoePretrainedModel(PretrainedModel):
     """
     Ernie4_5_MoePretrainedModel
@@ -788,3 +801,16 @@ class Ernie4_5_PretrainedModel(Ernie4_5_MoePretrainedModel):
         Model Architecture Name
         """
         return "Ernie4_5_ForCausalLM"
+
+
+class Ernie4_5PretrainedModel(Ernie4_5_PretrainedModel):
+    """
+    Ernie4_5PretrainedModel 0.3B-PT
+    """
+
+    @classmethod
+    def arch_name(self):
+        """
+        Model Architecture Name
+        """
+        return "Ernie4_5ForCausalLM"
@@ -23,7 +23,7 @@ from paddle import nn
 from paddle.autograd import PyLayer
 from paddle.distributed.fleet.utils import recompute

-from fastdeploy.model_executor.layers.utils import _set_var_distributed, get_tensor
+from fastdeploy.model_executor.layers.utils import get_tensor
 from fastdeploy.model_executor.models.ernie4_5_vl.dist_utils import (
     RowSequenceParallelLinear,
     all_gather_group,
@@ -197,19 +197,6 @@ class VariableResolutionResamplerModel(nn.Layer):
         self.after_norm = RMSNorm(out_config)

         if self.tensor_parallel_degree > 1:
-            for idx in [2, 3]:
-                mark_as_sequence_parallel_parameter(self.spatial_linear[idx].weight)
-                mark_as_sequence_parallel_parameter(self.spatial_linear[idx].bias)
-                _set_var_distributed(self.spatial_linear[idx].weight, split_axis=0)
-                _set_var_distributed(self.spatial_linear[idx].bias, split_axis=0)
-            if self.use_temporal_conv:
-                for idx in [0, 2, 3]:
-                    mark_as_sequence_parallel_parameter(self.temporal_linear[idx].weight)
-                    mark_as_sequence_parallel_parameter(self.temporal_linear[idx].bias)
-
-            mark_as_sequence_parallel_parameter(self.mlp.weight)
-            mark_as_sequence_parallel_parameter(self.mlp.bias)
-            mark_as_sequence_parallel_parameter(self.after_norm.weight)
             set_weight_attrs(self.spatial_linear[0].weight, {"output_dim": False})

     def spatial_conv_reshape(self, x, spatial_conv_size):
@@ -306,7 +306,9 @@ def post_process_normal(
     )


-def post_process_specualate(model_output, save_each_rank: bool = False, skip_save_output: bool = False):
+def post_process_specualate(
+    model_output: ModelOutputData, save_each_rank: bool = False, skip_save_output: bool = False
+):
     """"""
     speculate_update(
         model_output.seq_lens_encoder,
@@ -160,6 +160,7 @@ def default_weight_loader(fd_config: FDConfig) -> None:
         output_dim = getattr(param, "output_dim", None)
         model_format = getattr(param, "model_format", "")
         if model_format == "torch":
+            loaded_weight = get_tensor(loaded_weight)
             loaded_weight = loaded_weight.transpose([1, 0])
         # Tensor parallelism splits the weight along the output_dim
         if output_dim is not None and fd_config.parallel_config.tensor_parallel_size > 1:
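The extra get_tensor call materializes the weight before the transpose that torch-format checkpoints need: torch stores nn.Linear weights as [out_features, in_features], while Paddle expects [in_features, out_features]. A standalone sketch of the same transform, with numpy standing in for the repository's get_tensor helper:

import numpy as np

torch_style = np.arange(6, dtype="float32").reshape(3, 2)  # [out_features, in_features]
paddle_style = torch_style.transpose(1, 0)                 # [in_features, out_features]
assert paddle_style.shape == (2, 3)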
@@ -24,6 +24,7 @@ class MultimodalRegistry:
         "Ernie4_5_VLMoeForConditionalGeneration",
         "Ernie5MoeForCausalLM",
         "Qwen2_5_VLForConditionalGeneration",
+        "Ernie5ForCausalLM",
     }

     @classmethod
@@ -161,24 +161,23 @@ class TokenProcessor:
                     continue

             else:
-                if (
+                if self.use_logprobs:
+                    get_output_topk(
+                        self.output_tokens,
+                        self.output_scores,
+                        self.output_ranks,
+                        K,
+                        rank_id,
+                        is_blocking,
+                    )
+                elif (
                     self.cfg.parallel_config.enable_expert_parallel
                     and self.cfg.parallel_config.data_parallel_size > 1
                 ):
                     get_output_ep(self.output_tokens, rank_id, is_blocking)

                 else:
-                    if self.use_logprobs:
-                        get_output_topk(
-                            self.output_tokens,
-                            self.output_scores,
-                            self.output_ranks,
-                            K,
-                            rank_id,
-                            is_blocking,
-                        )
-                    else:
-                        get_output(self.output_tokens, rank_id, is_blocking)
+                    get_output(self.output_tokens, rank_id, is_blocking)

             if self.output_tokens[0, 0] == -2:
                 continue
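The reordering above changes which fetcher wins when several conditions hold: use_logprobs now takes priority over the expert-parallel path instead of being reachable only inside its else branch. A small stand-alone sketch of the new precedence (plain booleans stand in for the flags read from self and self.cfg):

def pick_fetcher(use_logprobs: bool, expert_parallel: bool, data_parallel_size: int) -> str:
    if use_logprobs:
        return "get_output_topk"
    if expert_parallel and data_parallel_size > 1:
        return "get_output_ep"
    return "get_output"

# With expert parallelism and DP > 1, the old code never reached get_output_topk:
assert pick_fetcher(True, True, 8) == "get_output_topk"
assert pick_fetcher(False, True, 8) == "get_output_ep"
assert pick_fetcher(False, False, 1) == "get_output"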
@@ -261,7 +260,7 @@

     def _compute_speculative_status(self):
         # TODO(liuzichang): Supplement more statistics
-        interval = 10
+        interval = 1
         if self.speculative_stats_step % interval == 0:
             accept_ratio = 1 - self.total_step * 1.0 / self.number_of_output_tokens
             spec_logger.info(
@@ -333,6 +332,9 @@
                             + accept_num[i]
                         ].tolist()
                     if len(token_ids) == 0 or token_ids[-1] <= 0:
+                        if envs.ENABLE_V1_KVCACHE_SCHEDULER:
+                            if task_id in self.resource_manager.to_be_rescheduled_request_id_set:
+                                self.resource_manager.reschedule_preempt_task(task_id)
                         continue
                 else:
                     token_id = int(tokens[i, 0])
@@ -517,6 +519,31 @@
                     single_head_acceptance_rate
                 )

+    def clear_data(self):
+        if envs.ENABLE_V1_KVCACHE_SCHEDULER:
+            self.resource_manager.clear_data()
+        for i in range(self.cfg.max_num_seqs):
+            if self.resource_manager.stop_flags[i]:
+                continue
+            task = self.resource_manager.tasks_list[i]
+            result = RequestOutput(
+                request_id=task.request_id,
+                outputs=CompletionOutput(
+                    index=i,
+                    send_idx=self.tokens_counter[task.request_id],
+                    token_ids=task.eos_token_ids,
+                    draft_token_ids=[],
+                ),
+                finished=True,
+                metrics=RequestMetrics(
+                    arrival_time=time.time(),
+                    request_start_time=task.arrival_time,
+                ),
+            )
+            is_prefill = task.disaggregate_info is not None and task.disaggregate_info["role"] == "prefill"
+            self._recycle_resources(task.request_id, i, task, result, is_prefill)
+            llm_logger.warning(f"clear data for task {task.request_id}")
+

 class WarmUpTokenProcessor(TokenProcessor):
     """
@@ -17,5 +17,11 @@
 from .input_processor import load_input_processor_plugins
 from .model_register import load_model_register_plugins
 from .model_runner import load_model_runner_plugins
+from .reasoning_parser import load_reasoning_parser_plugins

-__all__ = ["load_model_register_plugins", "load_model_runner_plugins", "load_input_processor_plugins"]
+__all__ = [
+    "load_model_register_plugins",
+    "load_model_runner_plugins",
+    "load_input_processor_plugins",
+    "load_reasoning_parser_plugins",
+]
@@ -23,5 +23,5 @@ PLUGINS_GROUP = "fastdeploy.input_processor_plugins"
 def load_input_processor_plugins():
     """load_input_processor_plugins"""
     plugins = load_plugins_by_group(group=PLUGINS_GROUP)
-    assert len(plugins) <= 1, "Most one plugin is allowed to be loaded."
+    assert len(plugins) == 1, "Only one plugin is allowed to be loaded."
     return next(iter(plugins.values()))()
@@ -14,9 +14,10 @@
 # limitations under the License.
 """

-from fastdeploy.plugins.utils import load_plugins_by_group, plugins_loaded
+from fastdeploy.plugins.utils import load_plugins_by_group

+# make sure one process only loads plugins once
+plugins_loaded = False
 PLUGINS_GROUP = "fastdeploy.model_register_plugins"


@@ -14,7 +14,7 @@
 # limitations under the License.
 """

-from fastdeploy.plugins.utils import load_plugins_by_group, plugins_loaded
+from fastdeploy.plugins.utils import load_plugins_by_group

 # use for modle runner
 PLUGINS_GROUP = "fastdeploy.model_runner_plugins"
@@ -22,11 +22,6 @@ PLUGINS_GROUP = "fastdeploy.model_runner_plugins"

 def load_model_runner_plugins():
     """load_model_runner_plugins"""
-    global plugins_loaded
-    if plugins_loaded:
-        return
-    plugins_loaded = True
-
     plugins = load_plugins_by_group(group=PLUGINS_GROUP)
-    assert len(plugins) <= 1, "Most one plugin is allowed to be loaded."
+    assert len(plugins) == 1, "Only one plugin is allowed to be loaded."
     return next(iter(plugins.values()))()
34	fastdeploy/plugins/reasoning_parser/__init__.py	Normal file
@@ -0,0 +1,34 @@
+"""
+# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+
+from fastdeploy.plugins.utils import load_plugins_by_group
+
+# make sure one process only loads plugins once
+plugins_loaded = False
+PLUGINS_GROUP = "fastdeploy.reasoning_parser_plugins"
+
+
+def load_reasoning_parser_plugins():
+    """load_reasoning_parser_plugins"""
+    global plugins_loaded
+    if plugins_loaded:
+        return
+    plugins_loaded = True
+
+    plugins = load_plugins_by_group(group=PLUGINS_GROUP)
+    # general plugins, we only need to execute the loaded functions
+    for func in plugins.values():
+        func()
@@ -19,8 +19,6 @@ from typing import Any, Callable
 from fastdeploy import envs
 from fastdeploy.utils import llm_logger as logger

-plugins_loaded = False
-

 def load_plugins_by_group(group: str) -> dict[str, Callable[[], Any]]:
     import sys
@@ -14,6 +14,8 @@
 # limitations under the License.
 """

+from fastdeploy.plugins import load_reasoning_parser_plugins
+
 from .abs_reasoning_parsers import ReasoningParser, ReasoningParserManager
 from .ernie_vl_reasoning_parsers import ErnieVLReasoningParser
 from .ernie_x1_reasoning_parsers import ErnieX1ReasoningParser
@@ -26,3 +28,5 @@ __all__ = [
     "Qwen3ReasoningParser",
     "ErnieX1ReasoningParser",
 ]
+
+load_reasoning_parser_plugins()
Some files were not shown because too many files have changed in this diff.