Compare commits
29 Commits
Jason/expe...remove_use
SHA1
91536c8279
f38b174a75
6b47773bd6
0358329946
01f6934162
7bdc6f41e5
bba279cf38
4f460db556
74d7b9151d
0fa28b1068
cffde70949
7f9a9b37f3
b41988f4bc
7ccbcc5a62
fbb4e0f8d1
4e8ba62241
4874e13e01
7e3148ed81
4f8ff478b3
c4098d56a0
a6b161b007
7272afe3dc
dfc94371ee
35b8362804
d43c2f2577
14df2c59da
934071578a
36a58f487c
f25bbefea1
.github/workflows/Codestyle-Check.yml (1 change)
@@ -5,7 +5,6 @@ on:
branches:
- develop
- 'release/*'
- 'feature/*'

jobs:
pre-commit:
.github/workflows/_accuracy_test.yml (2 changes)
@@ -143,7 +143,7 @@ jobs:
-v "${CACHE_DIR}/ConfigDir:/root/.config" \
-e TZ="Asia/Shanghai" \
--gpus '"device='"${DEVICES}"'"' ${docker_image} /bin/bash -xc '
python -m pip install --pre paddlepaddle-gpu -i https://www.paddlepaddle.org.cn/packages/nightly/cu126/
python -m pip install paddlepaddle-gpu==3.2.0 -i https://www.paddlepaddle.org.cn/packages/stable/cu126/

pip config set global.index-url https://mirrors.tuna.tsinghua.edu.cn/pypi/web/simple
.github/workflows/_base_test.yml (2 changes)
@@ -143,7 +143,7 @@ jobs:
-v "${CACHE_DIR}/ConfigDir:/root/.config" \
-e TZ="Asia/Shanghai" \
--gpus '"device='"${DEVICES}"'"' ${docker_image} /bin/bash -xc '
python -m pip install --pre paddlepaddle-gpu -i https://www.paddlepaddle.org.cn/packages/nightly/cu126/
python -m pip install paddlepaddle-gpu==3.2.0 -i https://www.paddlepaddle.org.cn/packages/stable/cu126/

pip config set global.index-url https://mirrors.tuna.tsinghua.edu.cn/pypi/web/simple
.github/workflows/_build_linux.yml (3 changes)
@@ -134,6 +134,7 @@ jobs:
fi

git config --global --add safe.directory /workspace/FastDeploy
chown -R $(whoami) /workspace/FastDeploy
cd FastDeploy
if [[ "${WITH_NIGHTLY_BUILD}" == "ON" ]];then
GIT_COMMIT_TIME=$(git --no-pager show -s --format=%ci HEAD)
@@ -148,7 +149,7 @@ jobs:
elif [[ "${PADDLEVERSION}" != "" ]];then
python -m pip install paddlepaddle-gpu==${PADDLEVERSION} -i https://www.paddlepaddle.org.cn/packages/stable/cu126/
else
python -m pip install --pre paddlepaddle-gpu -i https://www.paddlepaddle.org.cn/packages/nightly/cu126/
python -m pip install paddlepaddle-gpu==3.2.0 -i https://www.paddlepaddle.org.cn/packages/stable/cu126/
fi

pip config set global.index-url https://mirrors.tuna.tsinghua.edu.cn/pypi/web/simple
.github/workflows/_logprob_test_linux.yml (2 changes)
@@ -133,7 +133,7 @@ jobs:
-v "${CACHE_DIR}/ConfigDir:/root/.config" \
-e TZ="Asia/Shanghai" \
--gpus '"device='"${DEVICES}"'"' ${docker_image} /bin/bash -xc '
python -m pip install --pre paddlepaddle-gpu -i https://www.paddlepaddle.org.cn/packages/nightly/cu126/
python -m pip install paddlepaddle-gpu==3.2.0 -i https://www.paddlepaddle.org.cn/packages/stable/cu126/

pip config set global.index-url https://mirrors.tuna.tsinghua.edu.cn/pypi/web/simple
.github/workflows/_pre_ce_test.yml (2 changes)
@@ -142,7 +142,7 @@ jobs:
--gpus "\"device=${DEVICES}\"" ${docker_image} /bin/bash -c '
git config --global --add safe.directory /workspace/FastDeploy
cd FastDeploy
python -m pip install --pre paddlepaddle-gpu -i https://www.paddlepaddle.org.cn/packages/nightly/cu126/
python -m pip install paddlepaddle-gpu==3.2.0 -i https://www.paddlepaddle.org.cn/packages/stable/cu126/
python -m pip install ${fd_wheel_url}
bash scripts/run_pre_ce.sh
'
.github/workflows/_stable_test.yml (2 changes)
@@ -146,7 +146,7 @@ jobs:
-v "${CACHE_DIR}/ConfigDir:/root/.config" \
-e TZ="Asia/Shanghai" \
--gpus '"device='"${DEVICES}"'"' ${docker_image} /bin/bash -xc '
python -m pip install --pre paddlepaddle-gpu -i https://www.paddlepaddle.org.cn/packages/nightly/cu126/
python -m pip install paddlepaddle-gpu==3.2.0 -i https://www.paddlepaddle.org.cn/packages/stable/cu126/

pip config set global.index-url https://mirrors.tuna.tsinghua.edu.cn/pypi/web/simple
.github/workflows/_unit_test_coverage.yml (2 changes)
@@ -168,7 +168,7 @@ jobs:
git config --global --add safe.directory /workspace/FastDeploy
cd FastDeploy
git diff origin/${BASE_REF}..HEAD --unified=0 > diff.txt
python -m pip install --pre paddlepaddle-gpu -i https://www.paddlepaddle.org.cn/packages/nightly/cu126/
python -m pip install paddlepaddle-gpu==3.2.0 -i https://www.paddlepaddle.org.cn/packages/stable/cu126/
pip config set global.extra-index-url https://mirrors.tuna.tsinghua.edu.cn/pypi/web/simple

python -m pip install coverage
.github/workflows/approve.yml (1 change)
@@ -5,7 +5,6 @@ on:
branches:
- develop
- 'release/*'
- 'feature/*'

env:
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
.github/workflows/ce_job.yml (1 change)
@@ -6,7 +6,6 @@ on:
branches:
- develop
- 'release/*'
- 'feature/experimental_feature*'
permissions: read-all

concurrency:
.github/workflows/ci_xpu.yml (1 change)
@@ -5,7 +5,6 @@ on:
branches:
- develop
- 'release/*'
- 'feature/*'
workflow_dispatch:

concurrency:
.github/workflows/pr_build_and_test.yml (2 changes)
@@ -2,7 +2,7 @@ name: PR Build and Test
on:
pull_request:
types: [opened, synchronize]
branches: [develop, release/**, feature/**]
branches: [develop, release/**]
permissions: read-all

concurrency:
.gitmodules (new file, 9 additions)
@@ -0,0 +1,9 @@
[submodule "custom_ops/third_party/DeepGEMM"]
path = custom_ops/third_party/DeepGEMM
url = https://github.com/deepseek-ai/DeepGEMM.git
[submodule "custom_ops/third_party/cutlass"]
path = custom_ops/third_party/cutlass
url = https://github.com/NVIDIA/cutlass.git
[submodule "custom_ops/third_party/nlohmann_json"]
path = custom_ops/third_party/nlohmann_json
url = https://github.com/nlohmann/json.git
README.md (17 changes)
@@ -26,6 +26,8 @@ English | [简体中文](README_CN.md)
# FastDeploy : Inference and Deployment Toolkit for LLMs and VLMs based on PaddlePaddle

## News
**[2025-09] 🔥 FastDeploy v2.2 is newly released!** It now offers compatibility with models in the HuggingFace ecosystem, has further optimized performance, and newly adds support for [baidu/ERNIE-21B-A3B-Thinking](https://huggingface.co/baidu/ERNIE-4.5-21B-A3B-Thinking)!

**[2025-08] 🔥 Released FastDeploy v2.1:** A brand-new KV Cache scheduling strategy has been introduced, and expanded support for PD separation and CUDA Graph across more models. Enhanced hardware support has been added for platforms like Kunlun and Hygon, along with comprehensive optimizations to improve the performance of both the service and inference engine.

**[2025-07] The FastDeploy 2.0 Inference Deployment Challenge is now live!** Complete the inference deployment task for the ERNIE 4.5 series open-source models to win official FastDeploy 2.0 merch and generous prizes! 🎁 You're welcome to try it out and share your feedback! 📌[Sign up here](https://www.wjx.top/vm/meSsp3L.aspx#) 📌[Event details](https://github.com/PaddlePaddle/FastDeploy/discussions/2728)
@@ -57,8 +59,9 @@ FastDeploy supports inference deployment on **NVIDIA GPUs**, **Kunlunxin XPUs**,
- [Iluvatar GPU](./docs/get_started/installation/iluvatar_gpu.md)
- [Enflame GCU](./docs/get_started/installation/Enflame_gcu.md)
- [Hygon DCU](./docs/get_started/installation/hygon_dcu.md)
- [MetaX GPU](./docs/get_started/installation/metax_gpu.md.md)

**Note:** We are actively working on expanding hardware support. Additional hardware platforms including Ascend NPU and MetaX GPU are currently under development and testing. Stay tuned for updates!
**Note:** We are actively working on expanding hardware support. Additional hardware platforms including Ascend NPU are currently under development and testing. Stay tuned for updates!

## Get Started

@@ -68,20 +71,12 @@ Learn how to use FastDeploy through our documentation:
- [ERNIE-4.5-VL Multimodal Model Deployment](./docs/get_started/ernie-4.5-vl.md)
- [Offline Inference Development](./docs/offline_inference.md)
- [Online Service Deployment](./docs/online_serving/README.md)
- [Full Supported Models List](./docs/supported_models.md)
- [Best Practices](./docs/best_practices/README.md)

## Supported Models

| Model | Data Type | PD Disaggregation | Chunked Prefill | Prefix Caching | MTP | CUDA Graph | Maximum Context Length |
|:--- | :------- | :---------- | :-------- | :-------- | :----- | :----- | :----- |
|ERNIE-4.5-300B-A47B | BF16/WINT4/WINT8/W4A8C8/WINT2/FP8 | ✅| ✅ | ✅|✅| ✅ |128K |
|ERNIE-4.5-300B-A47B-Base| BF16/WINT4/WINT8 | ✅| ✅ | ✅|❌| ✅ | 128K |
|ERNIE-4.5-VL-424B-A47B | BF16/WINT4/WINT8 | WIP | ✅ | WIP | ❌ | WIP |128K |
|ERNIE-4.5-VL-28B-A3B | BF16/WINT4/WINT8 | ❌ | ✅ | WIP | ❌ | WIP |128K |
|ERNIE-4.5-21B-A3B | BF16/WINT4/WINT8/FP8 | ❌ | ✅ | ✅ | ✅ | ✅|128K |
|ERNIE-4.5-21B-A3B-Base | BF16/WINT4/WINT8/FP8 | ✅ | ✅ | ✅ | ❌ | ✅|128K |
|ERNIE-4.5-0.3B | BF16/WINT8/FP8 | ✅ | ✅ | ✅ | ❌ | ✅| 128K |
Learn how to download models, enable using the torch format, and more:
- [Full Supported Models List](./docs/supported_models.md)

## Advanced Usage
README_CN.md (19 changes)
@@ -26,7 +26,9 @@
# FastDeploy :基于飞桨的大语言模型与视觉语言模型推理部署工具包

## 最新活动
**[2025-08] 🔥 FastDeploy v2.1 全新发布:** 全新的KV Cache调度策略,更多模型支持PD分离和CUDA Graph,昆仑、海光等更多硬件支持增强,全方面优化服务和推理引擎的性能。
**[2025-09] 🔥 FastDeploy v2.2 全新发布**: HuggingFace生态模型兼容,性能进一步优化,更新增对[baidu/ERNIE-21B-A3B-Thinking](https://huggingface.co/baidu/ERNIE-4.5-21B-A3B-Thinking)支持!

**[2025-08] FastDeploy v2.1 发布**:全新的KV Cache调度策略,更多模型支持PD分离和CUDA Graph,昆仑、海光等更多硬件支持增强,全方面优化服务和推理引擎的性能。

**[2025-07] 《FastDeploy2.0推理部署实测》专题活动已上线!** 完成文心4.5系列开源模型的推理部署等任务,即可获得骨瓷马克杯等FastDeploy2.0官方周边及丰富奖金!🎁 欢迎大家体验反馈~ 📌[报名地址](https://www.wjx.top/vm/meSsp3L.aspx#) 📌[活动详情](https://github.com/PaddlePaddle/FastDeploy/discussions/2728)

@@ -55,8 +57,9 @@ FastDeploy 支持在**英伟达(NVIDIA)GPU**、**昆仑芯(Kunlunxin)XPU
- [天数 CoreX](./docs/zh/get_started/installation/iluvatar_gpu.md)
- [燧原 S60](./docs/zh/get_started/installation/Enflame_gcu.md)
- [海光 DCU](./docs/zh/get_started/installation/hygon_dcu.md)
- [沐曦 GPU](./docs/zh/get_started/installation/metax_gpu.md.md)

**注意:** 我们正在积极拓展硬件支持范围。目前,包括昇腾(Ascend)NPU 和 沐曦(MetaX)GPU 在内的其他硬件平台正在开发测试中。敬请关注更新!
**注意:** 我们正在积极拓展硬件支持范围。目前,包括昇腾(Ascend)NPU 等其他硬件平台正在开发测试中。敬请关注更新!

## 入门指南

@@ -66,20 +69,12 @@ FastDeploy 支持在**英伟达(NVIDIA)GPU**、**昆仑芯(Kunlunxin)XPU
- [ERNIE-4.5-VL 部署](./docs/zh/get_started/ernie-4.5-vl.md)
- [离线推理](./docs/zh/offline_inference.md)
- [在线服务](./docs/zh/online_serving/README.md)
- [模型支持列表](./docs/zh/supported_models.md)
- [最佳实践](./docs/zh/best_practices/README.md)

## 支持模型列表

| Model | Data Type | PD Disaggregation | Chunked Prefill | Prefix Caching | MTP | CUDA Graph | Maximum Context Length |
|:--- | :------- | :---------- | :-------- | :-------- | :----- | :----- | :----- |
|ERNIE-4.5-300B-A47B | BF16/WINT4/WINT8/W4A8C8/WINT2/FP8 | ✅| ✅ | ✅|✅| ✅ |128K |
|ERNIE-4.5-300B-A47B-Base| BF16/WINT4/WINT8 | ✅| ✅ | ✅|❌| ✅ | 128K |
|ERNIE-4.5-VL-424B-A47B | BF16/WINT4/WINT8 | WIP | ✅ | WIP | ❌ | WIP |128K |
|ERNIE-4.5-VL-28B-A3B | BF16/WINT4/WINT8 | ❌ | ✅ | WIP | ❌ | WIP |128K |
|ERNIE-4.5-21B-A3B | BF16/WINT4/WINT8/FP8 | ❌ | ✅ | ✅ | ✅ | ✅|128K |
|ERNIE-4.5-21B-A3B-Base | BF16/WINT4/WINT8/FP8 | ✅ | ✅ | ✅ | ❌ | ✅|128K |
|ERNIE-4.5-0.3B | BF16/WINT8/FP8 | ✅ | ✅ | ✅ | ❌ | ✅| 128K |
通过我们的文档了解如何下载模型,如何支持torch格式等:
- [模型支持列表](./docs/zh/supported_models.md)

## 进阶用法
@@ -140,8 +140,8 @@ void AppendAttentionKernel(
key_cache,
value_cache,
attn_mask,
cache_quant_type_str == "block_wise_fp8" ? cache_k_quant_scales : cache_k_dequant_scales,
cache_quant_type_str == "block_wise_fp8" ? cache_v_quant_scales : cache_v_dequant_scales,
cache_k_dequant_scales,
cache_v_dequant_scales,
cache_k_zp,
cache_v_zp,
out_linear_shifts,
@@ -273,15 +273,11 @@ void AppendAttentionKernel(
cache_v_zp,
cache_quant_type_str,
use_neox_rotary_style,
rope_3d,
max_input_length,
exec_stream,
&qkv_out,
const_cast<paddle::Tensor*>(&key_cache),
const_cast<paddle::Tensor*>(&value_cache),
q_norm_weight,
k_norm_weight,
rms_norm_eps);
const_cast<paddle::Tensor*>(&value_cache));
} else {
SpeculateWriteCacheWithRoPEKernel<data_t, data_t>(
meta_data,
@@ -300,15 +296,11 @@ void AppendAttentionKernel(
cache_v_zp,
cache_quant_type_str,
use_neox_rotary_style,
rope_3d,
max_input_length,
exec_stream,
&qkv_out,
const_cast<paddle::Tensor*>(&key_cache),
const_cast<paddle::Tensor*>(&value_cache),
q_norm_weight,
k_norm_weight,
rms_norm_eps);
const_cast<paddle::Tensor*>(&value_cache));
}
} else {
if (qkv_out_scales) {
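The hunks above show two variants of the scale arguments handed to the cache-write path: one passes quantization scales when cache_quant_type_str is "block_wise_fp8" and dequantization scales otherwise, the other always passes the dequantization scales. The standalone sketch below illustrates only that selection pattern; the FakeTensor type and the function name are invented for the example and are not the FastDeploy or Paddle API.

```cpp
#include <iostream>
#include <string>

// Illustrative stand-in for a tensor handle; not the paddle::Tensor API.
struct FakeTensor {
  std::string name;
};

// Sketch: pick quantization scales for "block_wise_fp8" caches and
// dequantization scales for every other cache layout, mirroring the
// ternary expressions in the hunk above.
const FakeTensor& select_k_scales(const std::string& cache_quant_type,
                                  const FakeTensor& quant_scales,
                                  const FakeTensor& dequant_scales) {
  return cache_quant_type == "block_wise_fp8" ? quant_scales : dequant_scales;
}

int main() {
  FakeTensor quant{"cache_k_quant_scales"}, dequant{"cache_k_dequant_scales"};
  std::cout << select_k_scales("block_wise_fp8", quant, dequant).name << "\n";
  std::cout << select_k_scales("cache_int8", quant, dequant).name << "\n";
}
```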
@@ -32,15 +32,14 @@ template <typename T,
typename OutT = T,
bool ENABLE_PREFILL = true,
bool is_scale_channel_wise = false,
bool IsFP8 = false,
bool IsDynamicC8 = false>
bool IsFP8=false>
__global__ void multi_query_append_attention_c8_kernel(
T *__restrict__ q, // [token_num, (num_heads + 2* kv_num_head) * head_dim]
CacheT *__restrict__ cache_k, // [max_block_num, num_heads, block_size,
// head_dim]
CacheT *__restrict__ cache_v,
const T *__restrict__ cache_k_scale, // [num_kv_heads] or [max_block_num, num_heads, block_size]
const T *__restrict__ cache_v_scale, // [num_kv_heads] or [max_block_num, num_heads, block_size]
const T *__restrict__ cache_k_scale, // [num_kv_heads]
const T *__restrict__ cache_v_scale, // [num_kv_heads]
const T *__restrict__ shift_bias, // [q_num_heads * HEAD_DIM]
const T *__restrict__ smooth_weight, // [q_num_heads * HEAD_DIM]
const int *__restrict__ seq_lens,
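This hunk shows the kernel template with and without a trailing IsDynamicC8 flag. The standalone sketch below illustrates, under assumed example sizes, how such a compile-time boolean can resize per-thread scale arrays and gate the static-scale setup with if constexpr; the names and fragment counts are made up for the illustration and are not the kernel's actual values.

```cpp
#include <cstdio>

// Hypothetical illustration of an IsDynamicC8-style compile-time switch.
// Fragment counts are arbitrary example values, not the kernel's.
template <int num_frags_y, int num_frags_z, bool IsDynamicC8>
void setup_scales(const float* per_channel_scale, float head_scale) {
  // Dynamic per-token scales need num_frags_z-sized storage that is refreshed
  // on every KV iteration; static scales need num_frags_y-sized storage that
  // is filled once before the main loop.
  float k_scale_reg[IsDynamicC8 ? num_frags_z * 2 : num_frags_y * 4];
  float v_scale_reg[IsDynamicC8 ? num_frags_z * 4 : num_frags_y * 2];

  if constexpr (!IsDynamicC8) {
    // Static path: load scales once up front.
    for (int i = 0; i < num_frags_y * 4; ++i) k_scale_reg[i] = per_channel_scale[i % 4];
    for (int i = 0; i < num_frags_y * 2; ++i) v_scale_reg[i] = head_scale;
  } else {
    // Dynamic path: registers are refilled from shared memory inside the
    // attention loop, so nothing meaningful is loaded here.
    k_scale_reg[0] = 0.0f;
    v_scale_reg[0] = 0.0f;
  }
  std::printf("k regs: %zu, v regs: %zu\n",
              sizeof(k_scale_reg) / sizeof(float),
              sizeof(v_scale_reg) / sizeof(float));
}

int main() {
  const float chan[4] = {1.f, 2.f, 3.f, 4.f};
  setup_scales<4, 2, false>(chan, 0.5f);  // per-channel / per-head scales
  setup_scales<4, 2, true>(chan, 0.5f);   // per-token ("dynamic C8") scales
}
```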
@@ -92,30 +91,28 @@ __global__ void multi_query_append_attention_c8_kernel(
|
||||
return;
|
||||
}
|
||||
|
||||
T cache_k_scale_reg[IsDynamicC8 ? num_frags_z * 2 : num_frags_y * 4];
|
||||
T cache_v_scale_reg[IsDynamicC8 ? num_frags_z * 4 : num_frags_y * 2];
|
||||
if constexpr (!IsDynamicC8) {
|
||||
if constexpr (is_scale_channel_wise) {
|
||||
int scale_col_base = threadIdx.x % 4 * 2 + kv_head_idx * HEAD_DIM;
|
||||
const T *cache_k_scale_cur_head = cache_k_scale + scale_col_base;
|
||||
for (int i = 0; i < num_frags_y; ++i) {
|
||||
const int scale_idx = i * 16;
|
||||
cache_k_scale_reg[i * 4] = cache_k_scale_cur_head[scale_idx];
|
||||
cache_k_scale_reg[i * 4 + 1] = cache_k_scale_cur_head[scale_idx + 1];
|
||||
cache_k_scale_reg[i * 4 + 2] = cache_k_scale_cur_head[scale_idx + 8];
|
||||
cache_k_scale_reg[i * 4 + 3] = cache_k_scale_cur_head[scale_idx + 9];
|
||||
}
|
||||
scale_col_base = threadIdx.x / 4 + kv_head_idx * HEAD_DIM;
|
||||
const T *cache_v_scale_cur_head = cache_v_scale + scale_col_base;
|
||||
for (int i = 0; i < num_frags_y; ++i) {
|
||||
const int scale_idx = i * 16;
|
||||
cache_v_scale_reg[i * 2] = cache_v_scale_cur_head[scale_idx];
|
||||
cache_v_scale_reg[i * 2 + 1] = cache_v_scale_cur_head[scale_idx + 8];
|
||||
}
|
||||
} else {
|
||||
cache_k_scale_reg[0] = cache_k_scale[kv_head_idx];
|
||||
cache_v_scale_reg[0] = cache_v_scale[kv_head_idx];
|
||||
T cache_k_scale_reg[num_frags_y * 4];
|
||||
T cache_v_scale_reg[num_frags_y * 2];
|
||||
if (is_scale_channel_wise) {
|
||||
int scale_col_base = threadIdx.x % 4 * 2 + kv_head_idx * HEAD_DIM;
|
||||
const T *cache_k_scale_cur_head = cache_k_scale + scale_col_base;
|
||||
for (int i = 0; i < num_frags_y; ++i) {
|
||||
const int scale_idx = i * 16;
|
||||
cache_k_scale_reg[i * 4] = cache_k_scale_cur_head[scale_idx];
|
||||
cache_k_scale_reg[i * 4 + 1] = cache_k_scale_cur_head[scale_idx + 1];
|
||||
cache_k_scale_reg[i * 4 + 2] = cache_k_scale_cur_head[scale_idx + 8];
|
||||
cache_k_scale_reg[i * 4 + 3] = cache_k_scale_cur_head[scale_idx + 9];
|
||||
}
|
||||
scale_col_base = threadIdx.x / 4 + kv_head_idx * HEAD_DIM;
|
||||
const T *cache_v_scale_cur_head = cache_v_scale + scale_col_base;
|
||||
for (int i = 0; i < num_frags_y; ++i) {
|
||||
const int scale_idx = i * 16;
|
||||
cache_v_scale_reg[i * 2] = cache_v_scale_cur_head[scale_idx];
|
||||
cache_v_scale_reg[i * 2 + 1] = cache_v_scale_cur_head[scale_idx + 8];
|
||||
}
|
||||
} else {
|
||||
cache_k_scale_reg[0] = cache_k_scale[kv_head_idx];
|
||||
cache_v_scale_reg[0] = cache_v_scale[kv_head_idx];
|
||||
}
|
||||
|
||||
const uint32_t q_end =
|
||||
@@ -204,17 +201,6 @@ __global__ void multi_query_append_attention_c8_kernel(
|
||||
smem_t k_smem(smem + NUM_WARPS * num_frags_x * 16 * HEAD_DIM * sizeof(T)),
|
||||
v_smem(smem + NUM_WARPS * num_frags_x * 16 * HEAD_DIM * sizeof(T) +
|
||||
num_frags_z * 16 * HEAD_DIM * sizeof(CacheT));
|
||||
T* k_smem_scale_ptr = nullptr;
|
||||
T* v_smem_scale_ptr = nullptr;
|
||||
smem_t k_scale_smem;
|
||||
smem_t v_scale_smem;
|
||||
if constexpr (IsDynamicC8) {
|
||||
k_smem_scale_ptr = reinterpret_cast<T*>(smem + NUM_WARPS * num_frags_x * 16 * HEAD_DIM * sizeof(T) +
|
||||
num_frags_z * 16 * HEAD_DIM * sizeof(CacheT) * 2);
|
||||
v_smem_scale_ptr = k_smem_scale_ptr + num_frags_z * 16;
|
||||
k_scale_smem.base = reinterpret_cast<b128_t*>(k_smem_scale_ptr);
|
||||
v_scale_smem.base = reinterpret_cast<b128_t*>(v_smem_scale_ptr);
|
||||
}
|
||||
|
||||
|
||||
const uint32_t num_iterations = div_up(
|
||||
@@ -275,20 +261,6 @@ __global__ void multi_query_append_attention_c8_kernel(
|
||||
kv_idx_base,
|
||||
chunk_end,
|
||||
const_k_offset);
|
||||
if constexpr (IsDynamicC8) {
|
||||
produce_kv_dynamic_scale_gmem2smem_async<SharedMemFillMode::kFillZero,
|
||||
BLOCK_SIZE,
|
||||
num_frags_z,
|
||||
NUM_WARP_Q>(
|
||||
k_scale_smem,
|
||||
block_table_now,
|
||||
cache_k_scale,
|
||||
kv_idx_base,
|
||||
kv_num_heads,
|
||||
kv_head_idx,
|
||||
chunk_end
|
||||
);
|
||||
}
|
||||
commit_group();
|
||||
produce_v_blockwise_c8<SharedMemFillMode::kNoFill,
|
||||
NUM_WARPS,
|
||||
@@ -306,34 +278,14 @@ __global__ void multi_query_append_attention_c8_kernel(
|
||||
kv_idx_base,
|
||||
chunk_end,
|
||||
const_v_offset);
|
||||
if constexpr (IsDynamicC8) {
|
||||
produce_kv_dynamic_scale_gmem2smem_async<SharedMemFillMode::kFillZero,
|
||||
BLOCK_SIZE,
|
||||
num_frags_z,
|
||||
NUM_WARP_Q>(
|
||||
v_scale_smem,
|
||||
block_table_now,
|
||||
cache_v_scale,
|
||||
kv_idx_base,
|
||||
kv_num_heads,
|
||||
kv_head_idx,
|
||||
chunk_end
|
||||
);
|
||||
}
|
||||
commit_group();
|
||||
|
||||
#pragma unroll 1
|
||||
for (uint32_t iter = 0; iter < num_iterations; ++iter) {
|
||||
wait_group<1>();
|
||||
__syncthreads();
|
||||
if constexpr (IsDynamicC8) {
|
||||
produce_k_dynamic_scale_smem2reg<BLOCK_SIZE, num_frags_z, NUM_WARP_Q, T>(
|
||||
k_smem_scale_ptr,
|
||||
cache_k_scale_reg
|
||||
);
|
||||
}
|
||||
// s = qk
|
||||
compute_qk_c8<num_frags_x, num_frags_y, num_frags_z, T, CacheT, is_scale_channel_wise, IsFP8, IsDynamicC8>(
|
||||
compute_qk_c8<num_frags_x, num_frags_y, num_frags_z, T, CacheT, is_scale_channel_wise, IsFP8>(
|
||||
&qo_smem,
|
||||
&q_smem_offset_r,
|
||||
&k_smem,
|
||||
@@ -366,7 +318,6 @@ __global__ void multi_query_append_attention_c8_kernel(
|
||||
s_frag, o_frag, m_frag, d_frag);
|
||||
__syncthreads();
|
||||
|
||||
const int ori_kv_idx_base = kv_idx_base;
|
||||
kv_idx_base += num_frags_z * 16;
|
||||
produce_k_blockwise_c8<SharedMemFillMode::kNoFill,
|
||||
NUM_WARPS,
|
||||
@@ -384,29 +335,9 @@ __global__ void multi_query_append_attention_c8_kernel(
|
||||
kv_idx_base,
|
||||
chunk_end,
|
||||
const_k_offset);
|
||||
if constexpr (IsDynamicC8) {
|
||||
produce_kv_dynamic_scale_gmem2smem_async<SharedMemFillMode::kFillZero,
|
||||
BLOCK_SIZE,
|
||||
num_frags_z,
|
||||
NUM_WARP_Q>(
|
||||
k_scale_smem,
|
||||
block_table_now,
|
||||
cache_k_scale,
|
||||
kv_idx_base,
|
||||
kv_num_heads,
|
||||
kv_head_idx,
|
||||
chunk_end
|
||||
);
|
||||
}
|
||||
commit_group();
|
||||
wait_group<1>();
|
||||
__syncthreads();
|
||||
if constexpr (IsDynamicC8) {
|
||||
produce_v_dynamic_scale_smem2reg<BLOCK_SIZE, num_frags_z, NUM_WARP_Q, T>(
|
||||
v_smem_scale_ptr,
|
||||
cache_v_scale_reg
|
||||
);
|
||||
}
|
||||
|
||||
// compute sfm*v
|
||||
compute_sfm_v_c8<num_frags_x,
|
||||
@@ -415,9 +346,7 @@ __global__ void multi_query_append_attention_c8_kernel(
|
||||
BLOCK_SIZE,
|
||||
T,
|
||||
CacheT,
|
||||
is_scale_channel_wise,
|
||||
IsFP8,
|
||||
IsDynamicC8>(
|
||||
is_scale_channel_wise, IsFP8>(
|
||||
&v_smem, &v_smem_offset_r, s_frag, o_frag, d_frag, cache_v_scale_reg);
|
||||
__syncthreads();
|
||||
|
||||
@@ -437,20 +366,6 @@ __global__ void multi_query_append_attention_c8_kernel(
|
||||
kv_idx_base,
|
||||
chunk_end,
|
||||
const_v_offset);
|
||||
if constexpr (IsDynamicC8) {
|
||||
produce_kv_dynamic_scale_gmem2smem_async<SharedMemFillMode::kFillZero,
|
||||
BLOCK_SIZE,
|
||||
num_frags_z,
|
||||
NUM_WARP_Q>(
|
||||
v_scale_smem,
|
||||
block_table_now,
|
||||
cache_v_scale,
|
||||
kv_idx_base,
|
||||
kv_num_heads,
|
||||
kv_head_idx,
|
||||
chunk_end
|
||||
);
|
||||
}
|
||||
commit_group();
|
||||
|
||||
}
|
||||
@@ -548,15 +463,14 @@ template <typename T,
|
||||
typename OutT = T,
|
||||
bool ENABLE_PREFILL = true,
|
||||
bool is_scale_channel_wise=false,
|
||||
bool IsFP8 = false,
|
||||
bool IsDynamicC8 = false>
|
||||
bool IsFP8=false>
|
||||
__global__ void multi_query_append_attention_c8_warp1_4_kernel(
|
||||
T *__restrict__ q, // [token_num, (num_heads + 2* kv_num_head) * head_dim]
|
||||
CacheT *__restrict__ cache_k, // [max_block_num, num_heads, block_size,
|
||||
// head_dim]
|
||||
CacheT *__restrict__ cache_v,
|
||||
const T *__restrict__ cache_k_scale, // [num_kv_heads] or [max_block_num, num_heads, block_size]
|
||||
const T *__restrict__ cache_v_scale, // [num_kv_heads] or [max_block_num, num_heads, block_size]
|
||||
const T *__restrict__ cache_k_scale, // [num_kv_heads, head_dim]
|
||||
const T *__restrict__ cache_v_scale, // [num_kv_heads, head_dim]
|
||||
const T *__restrict__ shift_bias, // [q_num_heads * HEAD_DIM]
|
||||
const T *__restrict__ smooth_weight, // [q_num_heads * HEAD_DIM]
|
||||
const int *__restrict__ seq_lens,
|
||||
@@ -608,30 +522,28 @@ __global__ void multi_query_append_attention_c8_warp1_4_kernel(
|
||||
if (q_len <= 0) {
|
||||
return;
|
||||
}
|
||||
T cache_k_scale_reg[IsDynamicC8 ? num_frags_z * 2 : num_frags_y * 4];
|
||||
T cache_v_scale_reg[IsDynamicC8 ? num_frags_z * 4 : num_frags_y * 2];
|
||||
if constexpr (!IsDynamicC8) {
|
||||
if constexpr (is_scale_channel_wise) {
|
||||
int scale_col_base = threadIdx.x % 4 * 2 + kv_head_idx * HEAD_DIM;
|
||||
const T *cache_k_scale_cur_head = cache_k_scale + scale_col_base;
|
||||
for (int i = 0; i < num_frags_y; ++i) {
|
||||
const int scale_idx = i * 16;
|
||||
cache_k_scale_reg[i * 4] = cache_k_scale_cur_head[scale_idx];
|
||||
cache_k_scale_reg[i * 4 + 1] = cache_k_scale_cur_head[scale_idx + 1];
|
||||
cache_k_scale_reg[i * 4 + 2] = cache_k_scale_cur_head[scale_idx + 8];
|
||||
cache_k_scale_reg[i * 4 + 3] = cache_k_scale_cur_head[scale_idx + 9];
|
||||
}
|
||||
scale_col_base = threadIdx.x / 4 + kv_head_idx * HEAD_DIM;
|
||||
const T *cache_v_scale_cur_head = cache_v_scale + scale_col_base;
|
||||
for (int i = 0; i < num_frags_y; ++i) {
|
||||
const int scale_idx = i * 16;
|
||||
cache_v_scale_reg[i * 2] = cache_v_scale_cur_head[scale_idx];
|
||||
cache_v_scale_reg[i * 2 + 1] = cache_v_scale_cur_head[scale_idx + 8];
|
||||
}
|
||||
} else {
|
||||
cache_k_scale_reg[0] = cache_k_scale[kv_head_idx];
|
||||
cache_v_scale_reg[0] = cache_v_scale[kv_head_idx];
|
||||
T cache_k_scale_reg[num_frags_y * 4];
|
||||
T cache_v_scale_reg[num_frags_y * 2];
|
||||
if (is_scale_channel_wise) {
|
||||
int scale_col_base = threadIdx.x % 4 * 2 + kv_head_idx * HEAD_DIM;
|
||||
const T *cache_k_scale_cur_head = cache_k_scale + scale_col_base;
|
||||
for (int i = 0; i < num_frags_y; ++i) {
|
||||
const int scale_idx = i * 16;
|
||||
cache_k_scale_reg[i * 4] = cache_k_scale_cur_head[scale_idx];
|
||||
cache_k_scale_reg[i * 4 + 1] = cache_k_scale_cur_head[scale_idx + 1];
|
||||
cache_k_scale_reg[i * 4 + 2] = cache_k_scale_cur_head[scale_idx + 8];
|
||||
cache_k_scale_reg[i * 4 + 3] = cache_k_scale_cur_head[scale_idx + 9];
|
||||
}
|
||||
scale_col_base = threadIdx.x / 4 + kv_head_idx * HEAD_DIM;
|
||||
const T *cache_v_scale_cur_head = cache_v_scale + scale_col_base;
|
||||
for (int i = 0; i < num_frags_y; ++i) {
|
||||
const int scale_idx = i * 16;
|
||||
cache_v_scale_reg[i * 2] = cache_v_scale_cur_head[scale_idx];
|
||||
cache_v_scale_reg[i * 2 + 1] = cache_v_scale_cur_head[scale_idx + 8];
|
||||
}
|
||||
} else {
|
||||
cache_k_scale_reg[0] = cache_k_scale[kv_head_idx];
|
||||
cache_v_scale_reg[0] = cache_v_scale[kv_head_idx];
|
||||
}
|
||||
const uint32_t q_end =
|
||||
min(q_len, div_up((tile_id + 1) * num_rows_per_block, GROUP_SIZE));
|
||||
@@ -722,17 +634,6 @@ __global__ void multi_query_append_attention_c8_warp1_4_kernel(
|
||||
smem_t k_smem(smem + num_frags_x * 16 * HEAD_DIM * sizeof(T)),
|
||||
v_smem(smem + num_frags_x * 16 * HEAD_DIM * sizeof(T) +
|
||||
NUM_WARP_KV * num_frags_z * 16 * HEAD_DIM * sizeof(CacheT));
|
||||
T* k_smem_scale_ptr = nullptr;
|
||||
T* v_smem_scale_ptr = nullptr;
|
||||
smem_t k_scale_smem;
|
||||
smem_t v_scale_smem;
|
||||
if constexpr (IsDynamicC8) {
|
||||
k_smem_scale_ptr = reinterpret_cast<T*>(smem + num_frags_x * 16 * HEAD_DIM * sizeof(T) +
|
||||
NUM_WARP_KV * num_frags_z * 16 * HEAD_DIM * sizeof(CacheT) * 2);
|
||||
v_smem_scale_ptr = k_smem_scale_ptr + NUM_WARP_KV * num_frags_z * 16;
|
||||
k_scale_smem.base = reinterpret_cast<b128_t*>(k_smem_scale_ptr);
|
||||
v_scale_smem.base = reinterpret_cast<b128_t*>(v_smem_scale_ptr);
|
||||
}
|
||||
|
||||
const uint32_t num_iterations = div_up(
|
||||
CAUSAL
|
||||
@@ -795,20 +696,6 @@ __global__ void multi_query_append_attention_c8_warp1_4_kernel(
|
||||
kv_idx_base,
|
||||
chunk_end,
|
||||
const_k_offset);
|
||||
if constexpr (IsDynamicC8) {
|
||||
produce_kv_dynamic_scale_gmem2smem_async<SharedMemFillMode::kFillZero,
|
||||
BLOCK_SIZE,
|
||||
num_frags_z,
|
||||
NUM_WARP_Q>(
|
||||
k_scale_smem,
|
||||
block_table_now,
|
||||
cache_k_scale,
|
||||
kv_idx_base,
|
||||
kv_num_heads,
|
||||
kv_head_idx,
|
||||
chunk_end
|
||||
);
|
||||
}
|
||||
commit_group();
|
||||
produce_v_blockwise_c8<SharedMemFillMode::kNoFill,
|
||||
NUM_WARPS,
|
||||
@@ -826,34 +713,14 @@ __global__ void multi_query_append_attention_c8_warp1_4_kernel(
|
||||
kv_idx_base,
|
||||
chunk_end,
|
||||
const_v_offset);
|
||||
if constexpr (IsDynamicC8) {
|
||||
produce_kv_dynamic_scale_gmem2smem_async<SharedMemFillMode::kFillZero,
|
||||
BLOCK_SIZE,
|
||||
num_frags_z,
|
||||
NUM_WARP_Q>(
|
||||
v_scale_smem,
|
||||
block_table_now,
|
||||
cache_v_scale,
|
||||
kv_idx_base,
|
||||
kv_num_heads,
|
||||
kv_head_idx,
|
||||
chunk_end
|
||||
);
|
||||
}
|
||||
commit_group();
|
||||
#pragma unroll 1
|
||||
for (uint32_t iter = 0; iter < num_iterations; ++iter) {
|
||||
wait_group<1>();
|
||||
__syncthreads();
|
||||
if constexpr (IsDynamicC8) {
|
||||
produce_k_dynamic_scale_smem2reg<BLOCK_SIZE, num_frags_z, NUM_WARP_Q, T>(
|
||||
k_smem_scale_ptr,
|
||||
cache_k_scale_reg
|
||||
);
|
||||
}
|
||||
|
||||
// s = qk
|
||||
compute_qk_c8<num_frags_x, num_frags_y, num_frags_z, T, CacheT, is_scale_channel_wise, IsFP8, IsDynamicC8>(
|
||||
compute_qk_c8<num_frags_x, num_frags_y, num_frags_z, T, CacheT, is_scale_channel_wise, IsFP8>(
|
||||
&qo_smem,
|
||||
&q_smem_offset_r,
|
||||
&k_smem,
|
||||
@@ -886,7 +753,6 @@ __global__ void multi_query_append_attention_c8_warp1_4_kernel(
|
||||
s_frag, o_frag, m_frag, d_frag);
|
||||
__syncthreads();
|
||||
|
||||
const uint32_t ori_kv_idx_base = kv_idx_base;
|
||||
kv_idx_base += NUM_WARP_KV * num_frags_z * 16;
|
||||
produce_k_blockwise_c8<SharedMemFillMode::kNoFill,
|
||||
NUM_WARPS,
|
||||
@@ -904,29 +770,9 @@ __global__ void multi_query_append_attention_c8_warp1_4_kernel(
|
||||
kv_idx_base,
|
||||
chunk_end,
|
||||
const_k_offset);
|
||||
if constexpr (IsDynamicC8) {
|
||||
produce_kv_dynamic_scale_gmem2smem_async<SharedMemFillMode::kFillZero,
|
||||
BLOCK_SIZE,
|
||||
num_frags_z,
|
||||
NUM_WARP_Q>(
|
||||
k_scale_smem,
|
||||
block_table_now,
|
||||
cache_k_scale,
|
||||
kv_idx_base,
|
||||
kv_num_heads,
|
||||
kv_head_idx,
|
||||
chunk_end
|
||||
);
|
||||
}
|
||||
commit_group();
|
||||
wait_group<1>();
|
||||
__syncthreads();
|
||||
if constexpr (IsDynamicC8) {
|
||||
produce_v_dynamic_scale_smem2reg<BLOCK_SIZE, num_frags_z, NUM_WARP_Q, T>(
|
||||
v_smem_scale_ptr,
|
||||
cache_v_scale_reg
|
||||
);
|
||||
}
|
||||
|
||||
// compute sfm * v
|
||||
compute_sfm_v_c8_iter_sq_bvec<num_frags_x,
|
||||
@@ -935,9 +781,7 @@ __global__ void multi_query_append_attention_c8_warp1_4_kernel(
|
||||
BLOCK_SIZE,
|
||||
T,
|
||||
CacheT,
|
||||
is_scale_channel_wise,
|
||||
IsFP8,
|
||||
IsDynamicC8>(
|
||||
is_scale_channel_wise, IsFP8>(
|
||||
&v_smem, &v_smem_offset_r, s_frag, o_frag, d_frag, cache_v_scale_reg);
|
||||
__syncthreads();
|
||||
|
||||
@@ -957,20 +801,6 @@ __global__ void multi_query_append_attention_c8_warp1_4_kernel(
|
||||
kv_idx_base,
|
||||
chunk_end,
|
||||
const_v_offset);
|
||||
if constexpr (IsDynamicC8) {
|
||||
produce_kv_dynamic_scale_gmem2smem_async<SharedMemFillMode::kFillZero,
|
||||
BLOCK_SIZE,
|
||||
num_frags_z,
|
||||
NUM_WARP_Q>(
|
||||
v_scale_smem,
|
||||
block_table_now,
|
||||
cache_v_scale,
|
||||
kv_idx_base,
|
||||
kv_num_heads,
|
||||
kv_head_idx,
|
||||
chunk_end
|
||||
);
|
||||
}
|
||||
commit_group();
|
||||
}
|
||||
wait_group<0>();
|
||||
@@ -1065,8 +895,7 @@ template <typename T,
|
||||
uint32_t NUM_WARP_Q,
|
||||
typename OutT = T,
|
||||
bool ENABLE_PREFILL = true,
|
||||
bool IsFP8 = false,
|
||||
bool IsDynamicC8 = false>
|
||||
bool IsFP8=false>
|
||||
void MultiQueryAppendC8Attention(
|
||||
const AppendAttnMetaData &meta_data,
|
||||
const paddle::Tensor &qkv,
|
||||
@@ -1124,8 +953,7 @@ void MultiQueryAppendC8Attention(
|
||||
constexpr uint32_t num_frags_z = BLOCK_SIZE / 16;
|
||||
constexpr uint32_t smem_size =
|
||||
num_warps * num_frags_x * 16 * HEAD_DIM * sizeof(T) +
|
||||
num_frags_z * 16 * HEAD_DIM * sizeof(uint8_t) * 2 +
|
||||
num_frags_z * 16 * sizeof(T) * 2;
|
||||
num_frags_z * 16 * HEAD_DIM * sizeof(uint8_t) * 2;
|
||||
auto split_kv_kernel =
|
||||
multi_query_append_attention_c8_kernel<NV_TYPE,
|
||||
uint8_t,
|
||||
@@ -1142,9 +970,7 @@ void MultiQueryAppendC8Attention(
|
||||
num_frags_y,
|
||||
OUT_NV_TYPE,
|
||||
ENABLE_PREFILL,
|
||||
false,
|
||||
IsFP8,
|
||||
IsDynamicC8>;
|
||||
false, IsFP8>;
|
||||
if (is_scale_channel_wise) {
|
||||
split_kv_kernel =
|
||||
multi_query_append_attention_c8_kernel<NV_TYPE,
|
||||
@@ -1162,9 +988,7 @@ void MultiQueryAppendC8Attention(
|
||||
num_frags_y,
|
||||
OUT_NV_TYPE,
|
||||
ENABLE_PREFILL,
|
||||
true,
|
||||
IsFP8,
|
||||
IsDynamicC8>;
|
||||
true, IsFP8>;
|
||||
}
|
||||
if (smem_size >= 48 * 1024) {
|
||||
cudaFuncSetAttribute(split_kv_kernel,
|
||||
@@ -1198,9 +1022,7 @@ void MultiQueryAppendC8Attention(
|
||||
num_frags_y,
|
||||
OUT_NV_TYPE,
|
||||
ENABLE_PREFILL,
|
||||
false,
|
||||
IsFP8,
|
||||
IsDynamicC8>;
|
||||
false, IsFP8>;
|
||||
if (is_scale_channel_wise) {
|
||||
nosplit_kv_kernel =
|
||||
multi_query_append_attention_c8_kernel<NV_TYPE,
|
||||
@@ -1218,9 +1040,7 @@ void MultiQueryAppendC8Attention(
|
||||
num_frags_y,
|
||||
OUT_NV_TYPE,
|
||||
ENABLE_PREFILL,
|
||||
true,
|
||||
IsFP8,
|
||||
IsDynamicC8>;
|
||||
true, IsFP8>;
|
||||
}
|
||||
if (smem_size >= 48 * 1024) {
|
||||
cudaFuncSetAttribute(nosplit_kv_kernel,
|
||||
@@ -1398,8 +1218,7 @@ void MultiQueryAppendC8Attention(
|
||||
constexpr uint32_t num_frags_z = BLOCK_SIZE / 16 / NUM_WARP_KV * 2;
|
||||
constexpr uint32_t smem_size =
|
||||
num_frags_x * 16 * HEAD_DIM * sizeof(T) +
|
||||
NUM_WARP_KV * num_frags_z * 16 * HEAD_DIM * sizeof(uint8_t) * 2 +
|
||||
NUM_WARP_KV * num_frags_z * 16 * sizeof(T) * 2;
|
||||
NUM_WARP_KV * num_frags_z * 16 * HEAD_DIM * sizeof(uint8_t) * 2;
|
||||
auto split_kv_kernel =
|
||||
multi_query_append_attention_c8_warp1_4_kernel<NV_TYPE,
|
||||
uint8_t,
|
||||
@@ -1416,9 +1235,7 @@ void MultiQueryAppendC8Attention(
|
||||
num_frags_y,
|
||||
OUT_NV_TYPE,
|
||||
ENABLE_PREFILL,
|
||||
false,
|
||||
IsFP8,
|
||||
IsDynamicC8>;
|
||||
false, IsFP8>;
|
||||
if (is_scale_channel_wise) {
|
||||
split_kv_kernel =
|
||||
multi_query_append_attention_c8_warp1_4_kernel<NV_TYPE,
|
||||
@@ -1436,9 +1253,7 @@ void MultiQueryAppendC8Attention(
|
||||
num_frags_y,
|
||||
OUT_NV_TYPE,
|
||||
ENABLE_PREFILL,
|
||||
true,
|
||||
IsFP8,
|
||||
IsDynamicC8>;
|
||||
true, IsFP8>;
|
||||
}
|
||||
if (smem_size >= 48 * 1024) {
|
||||
cudaFuncSetAttribute(split_kv_kernel,
|
||||
@@ -1480,9 +1295,7 @@ void MultiQueryAppendC8Attention(
|
||||
num_frags_y,
|
||||
OUT_NV_TYPE,
|
||||
ENABLE_PREFILL,
|
||||
false,
|
||||
IsFP8,
|
||||
IsDynamicC8>;
|
||||
false, IsFP8>;
|
||||
if (is_scale_channel_wise) {
|
||||
nosplit_kv_kernel =
|
||||
multi_query_append_attention_c8_warp1_4_kernel<NV_TYPE,
|
||||
@@ -1500,9 +1313,7 @@ void MultiQueryAppendC8Attention(
|
||||
num_frags_y,
|
||||
OUT_NV_TYPE,
|
||||
ENABLE_PREFILL,
|
||||
true,
|
||||
IsFP8,
|
||||
IsDynamicC8>;
|
||||
true, IsFP8>;
|
||||
}
|
||||
if (smem_size >= 48 * 1024) {
|
||||
cudaFuncSetAttribute(nosplit_kv_kernel,
|
||||
@@ -1735,7 +1546,6 @@ void CascadeAppendAttentionC8Kernel(
|
||||
const bool causal,
|
||||
const bool is_decoder,
|
||||
const bool enable_prefill,
|
||||
const std::string& cache_quant_type_str,
|
||||
cudaStream_t& stream,
|
||||
paddle::Tensor* out) {
|
||||
const auto token_num = meta_data.token_nums;
|
||||
@@ -1744,7 +1554,6 @@ void CascadeAppendAttentionC8Kernel(
|
||||
const auto num_heads = meta_data.q_num_heads;
|
||||
const auto group_size = meta_data.q_num_heads / meta_data.kv_num_heads;
|
||||
const auto head_dim = meta_data.head_dims;
|
||||
bool is_dynamic_cfp8 = cache_quant_type_str == "block_wise_fp8";
|
||||
|
||||
DISPATCH_CAUSAL(
|
||||
causal,
|
||||
@@ -1763,46 +1572,43 @@ void CascadeAppendAttentionC8Kernel(
|
||||
BLOCK_SIZE,
|
||||
{DISPATCH_BLOCKSHAPE_Q(
|
||||
block_shape_q, BLOCK_SHAPE_Q, NUM_WARP_Q, {
|
||||
DISPATCH_DyCfp8(is_dynamic_cfp8, IsDynamicC8, {
|
||||
MultiQueryAppendC8Attention<T,
|
||||
GROUP_SIZE,
|
||||
HEAD_DIM,
|
||||
BLOCK_SIZE,
|
||||
CAUSAL,
|
||||
BLOCK_SHAPE_Q,
|
||||
NUM_WARP_Q,
|
||||
OutT,
|
||||
ENABLE_PREFILL,
|
||||
IsFP8,
|
||||
IsDynamicC8>(
|
||||
meta_data,
|
||||
qkv,
|
||||
cache_k,
|
||||
cache_v,
|
||||
attn_mask,
|
||||
cache_k_scale.get(),
|
||||
cache_v_scale.get(),
|
||||
shift_bias,
|
||||
smooth_weight,
|
||||
seq_lens_q,
|
||||
seq_lens_kv,
|
||||
seq_lens_encoder,
|
||||
batch_id_per_token,
|
||||
cu_seqlens_q,
|
||||
block_table,
|
||||
batch_ids,
|
||||
tile_ids_per_batch,
|
||||
num_blocks,
|
||||
max_seq_len,
|
||||
max_dec_len,
|
||||
quant_max_bound,
|
||||
quant_min_bound,
|
||||
in_scale,
|
||||
max_partition_size,
|
||||
encoder_max_partition_size,
|
||||
speculate_max_draft_token_num,
|
||||
is_decoder,
|
||||
stream,
|
||||
out);
|
||||
})})})})})})})
|
||||
MultiQueryAppendC8Attention<T,
|
||||
GROUP_SIZE,
|
||||
HEAD_DIM,
|
||||
BLOCK_SIZE,
|
||||
CAUSAL,
|
||||
BLOCK_SHAPE_Q,
|
||||
NUM_WARP_Q,
|
||||
OutT,
|
||||
ENABLE_PREFILL, IsFP8>(
|
||||
meta_data,
|
||||
qkv,
|
||||
cache_k,
|
||||
cache_v,
|
||||
attn_mask,
|
||||
cache_k_scale.get(),
|
||||
cache_v_scale.get(),
|
||||
shift_bias,
|
||||
smooth_weight,
|
||||
seq_lens_q,
|
||||
seq_lens_kv,
|
||||
seq_lens_encoder,
|
||||
batch_id_per_token,
|
||||
cu_seqlens_q,
|
||||
block_table,
|
||||
batch_ids,
|
||||
tile_ids_per_batch,
|
||||
num_blocks,
|
||||
max_seq_len,
|
||||
max_dec_len,
|
||||
quant_max_bound,
|
||||
quant_min_bound,
|
||||
in_scale,
|
||||
max_partition_size,
|
||||
encoder_max_partition_size,
|
||||
speculate_max_draft_token_num,
|
||||
is_decoder,
|
||||
stream,
|
||||
out);
|
||||
})})})})})})
|
||||
}
|
||||
|
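The smem_size hunks in the section above differ by one extra term that reserves shared memory for per-token K/V scales alongside the 8-bit K/V tiles. The self-contained sketch below reproduces both formulas with assumed example parameters to show how the budgets compare; the numbers are illustrative only, not the kernel's actual configuration.

```cpp
#include <cstdint>
#include <cstdio>

// Rough sketch of the two shared-memory budgets seen in the diff above.
// All sizes below are example values chosen for illustration only.
int main() {
  using T = uint16_t;               // stands in for the 16-bit compute type
  const uint32_t num_warps = 4;
  const uint32_t num_frags_x = 1;
  const uint32_t num_frags_z = 4;   // BLOCK_SIZE / 16 with BLOCK_SIZE = 64
  const uint32_t HEAD_DIM = 128;

  // Q tile in T plus K and V tiles stored as 8-bit cache values.
  const size_t base = num_warps * num_frags_x * 16 * HEAD_DIM * sizeof(T) +
                      num_frags_z * 16 * HEAD_DIM * sizeof(uint8_t) * 2;

  // Variant with per-token dynamic scales: one T-typed scale per K token and
  // one per V token, i.e. an extra num_frags_z * 16 * sizeof(T) * 2 bytes.
  const size_t with_dynamic_scales = base + num_frags_z * 16 * sizeof(T) * 2;

  std::printf("static-scale smem:  %zu bytes\n", base);
  std::printf("dynamic-scale smem: %zu bytes\n", with_dynamic_scales);
}
```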
@@ -384,105 +384,6 @@ __device__ __forceinline__ void produce_v_blockwise_c8(
|
||||
}
|
||||
}
|
||||
|
||||
template<SharedMemFillMode fill_mode,
|
||||
uint32_t block_size,
|
||||
uint32_t num_frags_z,
|
||||
uint32_t NUM_WARP_Q,
|
||||
typename T>
|
||||
__device__ __forceinline__ void produce_kv_dynamic_scale_gmem2smem_async(
|
||||
smem_t kv_scale_smem,
|
||||
const int* block_table_now,
|
||||
const T* cache_kv_scale,
|
||||
const uint32_t kv_idx,
|
||||
const uint32_t kv_num_heads,
|
||||
const uint32_t kv_head_idx,
|
||||
const uint32_t chunk_end
|
||||
) {
|
||||
const uint32_t tx = threadIdx.x, ty = threadIdx.y;
|
||||
const uint32_t tid = ty * 32 + tx;
|
||||
if constexpr (NUM_WARP_Q == 4) {
|
||||
// 4 warps shared block_size
|
||||
int block_id = __ldg(&block_table_now[kv_idx / block_size]);
|
||||
if (block_id < 0) block_id = 0;
|
||||
if (tid < block_size / 8) {
|
||||
const T* cache_k_scale_now = cache_kv_scale + block_id * kv_num_heads * block_size + kv_head_idx * block_size + tid * 8;
|
||||
const int kv_idx_this_thread = kv_idx + tid * 8;
|
||||
kv_scale_smem.load_128b_async<fill_mode>(tid, cache_k_scale_now, kv_idx_this_thread < chunk_end);
|
||||
}
|
||||
} else {
|
||||
// 1 warp 32 tokens
|
||||
if (tid < block_size / 8 * 2) {
|
||||
const uint32_t kv_idx_now = kv_idx + block_size * tid / 8;
|
||||
int block_id = __ldg(&block_table_now[kv_idx_now / block_size]);
|
||||
if (block_id < 0) block_id = 0;
|
||||
const int kv_idx_this_thread = kv_idx + tid * 8;
|
||||
const T* cache_k_scale_now = cache_kv_scale + block_id * kv_num_heads * block_size + kv_head_idx * block_size + tid % 8 * 8;
|
||||
kv_scale_smem.load_128b_async<fill_mode>(tid, cache_k_scale_now, kv_idx_this_thread < chunk_end);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template<uint32_t block_size,
|
||||
uint32_t num_frags_z,
|
||||
uint32_t NUM_WARP_Q,
|
||||
typename T>
|
||||
__device__ __forceinline__ void produce_k_dynamic_scale_smem2reg(
|
||||
T* k_smem_scale,
|
||||
T* cache_k_reg
|
||||
) {
|
||||
const uint32_t tx = threadIdx.x, ty = threadIdx.y;
|
||||
if constexpr (NUM_WARP_Q == 4) {
|
||||
// 4 warps shared block_size
|
||||
const uint32_t row_id = tx / 4;
|
||||
for (uint32_t fz = 0; fz < num_frags_z; fz++) {
|
||||
const uint32_t scale_idx = fz * 16 + row_id;
|
||||
cache_k_reg[fz * 2] = k_smem_scale[scale_idx];
|
||||
cache_k_reg[fz * 2 + 1] = k_smem_scale[scale_idx + 8];
|
||||
}
|
||||
} else {
|
||||
// 1 warp 32 tokens
|
||||
const uint32_t row_id = tx / 4;
|
||||
for (uint32_t fz = 0; fz < num_frags_z; fz++) {
|
||||
const uint32_t scale_idx = ty * 32 + fz * 16 + row_id;
|
||||
cache_k_reg[fz * 2] = k_smem_scale[scale_idx];
|
||||
cache_k_reg[fz * 2 + 1] = k_smem_scale[scale_idx + 8];
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template<uint32_t block_size,
|
||||
uint32_t num_frags_z,
|
||||
uint32_t NUM_WARP_Q,
|
||||
typename T>
|
||||
__device__ __forceinline__ void produce_v_dynamic_scale_smem2reg(
|
||||
T* v_smem_scale,
|
||||
T* cache_v_reg
|
||||
) {
|
||||
const uint32_t tx = threadIdx.x, ty = threadIdx.y;
|
||||
|
||||
if constexpr (NUM_WARP_Q == 4) {
|
||||
// 4 warps shared block_size
|
||||
const uint32_t row_id = tx % 4 * 2;
|
||||
for (uint32_t fz = 0; fz < num_frags_z; fz++) {
|
||||
const uint32_t scale_idx = fz * 16 + row_id;
|
||||
cache_v_reg[fz * 4] = v_smem_scale[scale_idx];
|
||||
cache_v_reg[fz * 4 + 1] = v_smem_scale[scale_idx + 1];
|
||||
cache_v_reg[fz * 4 + 2] = v_smem_scale[scale_idx + 8];
|
||||
cache_v_reg[fz * 4 + 3] = v_smem_scale[scale_idx + 9];
|
||||
}
|
||||
} else {
|
||||
// 1 warp 32 tokens
|
||||
const uint32_t row_id = tx % 4 * 2;
|
||||
for (uint32_t fz = 0; fz < num_frags_z; fz++) {
|
||||
const uint32_t scale_idx = ty * 32 + fz * 16 + row_id;
|
||||
cache_v_reg[fz * 4] = v_smem_scale[scale_idx];
|
||||
cache_v_reg[fz * 4 + 1] = v_smem_scale[scale_idx + 1];
|
||||
cache_v_reg[fz * 4 + 2] = v_smem_scale[scale_idx + 8];
|
||||
cache_v_reg[fz * 4 + 3] = v_smem_scale[scale_idx + 9];
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template <SharedMemFillMode fill_mode,
|
||||
uint32_t num_warps,
|
||||
uint32_t block_size,
|
||||
@@ -915,8 +816,7 @@ template <uint32_t num_frags_x,
|
||||
typename T,
|
||||
typename CacheT,
|
||||
bool is_scale_channel_wise = false,
|
||||
bool IsFP8 = false,
|
||||
bool IsDynamicC8 = false>
|
||||
bool IsFP8=false>
|
||||
__device__ __forceinline__ void compute_qk_c8(smem_t* q_smem,
|
||||
uint32_t* q_smem_offset_r,
|
||||
smem_t* k_smem,
|
||||
@@ -960,27 +860,20 @@ __device__ __forceinline__ void compute_qk_c8(smem_t* q_smem,
|
||||
convert_c8<T,IsFP8>(b_frag_dq_T, b_frag[fy * 2]);
|
||||
convert_c8<T,IsFP8>(b_frag_dq_T + 4, b_frag[fy * 2 + 1]);
|
||||
// scale zp
|
||||
if constexpr (!IsDynamicC8) {
|
||||
if constexpr (is_scale_channel_wise) {
|
||||
const int scale_col = (ky * 2 + fy) * 4;
|
||||
b_frag_dq_T[0] *= cache_k_scale[scale_col];
|
||||
b_frag_dq_T[1] *= cache_k_scale[scale_col + 1];
|
||||
b_frag_dq_T[2] *= cache_k_scale[scale_col + 2];
|
||||
b_frag_dq_T[3] *= cache_k_scale[scale_col + 3];
|
||||
b_frag_dq_T[4] *= cache_k_scale[scale_col];
|
||||
b_frag_dq_T[5] *= cache_k_scale[scale_col + 1];
|
||||
b_frag_dq_T[6] *= cache_k_scale[scale_col + 2];
|
||||
b_frag_dq_T[7] *= cache_k_scale[scale_col + 3];
|
||||
} else {
|
||||
#pragma unroll
|
||||
for (uint32_t b_i = 0; b_i < 8; ++b_i) {
|
||||
b_frag_dq_T[b_i] *= cache_k_scale[0];
|
||||
}
|
||||
}
|
||||
if constexpr (is_scale_channel_wise) {
|
||||
const int scale_col = (ky * 2 + fy) * 4;
|
||||
b_frag_dq_T[0] *= cache_k_scale[scale_col];
|
||||
b_frag_dq_T[1] *= cache_k_scale[scale_col + 1];
|
||||
b_frag_dq_T[2] *= cache_k_scale[scale_col + 2];
|
||||
b_frag_dq_T[3] *= cache_k_scale[scale_col + 3];
|
||||
b_frag_dq_T[4] *= cache_k_scale[scale_col];
|
||||
b_frag_dq_T[5] *= cache_k_scale[scale_col + 1];
|
||||
b_frag_dq_T[6] *= cache_k_scale[scale_col + 2];
|
||||
b_frag_dq_T[7] *= cache_k_scale[scale_col + 3];
|
||||
} else {
|
||||
#pragma unroll
|
||||
for (uint32_t b_i = 0; b_i < 8; ++b_i) {
|
||||
b_frag_dq_T[b_i] *= cache_k_scale[fz * 2 + b_i / 4];
|
||||
b_frag_dq_T[b_i] *= cache_k_scale[0];
|
||||
}
|
||||
}
|
||||
#pragma unroll
|
||||
@@ -1200,9 +1093,7 @@ template <uint32_t num_frags_x,
|
||||
uint32_t block_size,
|
||||
typename T,
|
||||
typename CacheT,
|
||||
bool is_scale_channel_wise = false,
|
||||
bool IsFP8 = false,
|
||||
bool IsDynamicC8 = false>
|
||||
bool is_scale_channel_wise = false, bool IsFP8=false>
|
||||
__device__ __forceinline__ void compute_sfm_v_c8(
|
||||
smem_t* v_smem,
|
||||
uint32_t* v_smem_offset_r,
|
||||
@@ -1244,28 +1135,16 @@ __device__ __forceinline__ void compute_sfm_v_c8(
|
||||
convert_c8<T,IsFP8>(b_frag_dq_T, b_frag[fz * 2]);
|
||||
convert_c8<T,IsFP8>(b_frag_dq_T + 4, b_frag[fz * 2 + 1]);
|
||||
// scale zp
|
||||
if constexpr (!IsDynamicC8) {
|
||||
if constexpr (is_scale_channel_wise) {
|
||||
if constexpr (is_scale_channel_wise) {
|
||||
#pragma unroll
|
||||
for (uint32_t b_i = 0; b_i < 8; ++b_i) {
|
||||
b_frag_dq_T[b_i] *= cache_v_scale[b_i / 4 + fy * 2];
|
||||
}
|
||||
} else {
|
||||
#pragma unroll
|
||||
for (uint32_t b_i = 0; b_i < 8; ++b_i) {
|
||||
b_frag_dq_T[b_i] *= cache_v_scale[0];
|
||||
}
|
||||
for (uint32_t b_i = 0; b_i < 8; ++b_i) {
|
||||
b_frag_dq_T[b_i] *= cache_v_scale[b_i / 4 + fy * 2];
|
||||
}
|
||||
} else {
|
||||
const int scale_col = (kz * 2 + fz) * 4;
|
||||
b_frag_dq_T[0] *= cache_v_scale[scale_col];
|
||||
b_frag_dq_T[1] *= cache_v_scale[scale_col + 1];
|
||||
b_frag_dq_T[2] *= cache_v_scale[scale_col + 2];
|
||||
b_frag_dq_T[3] *= cache_v_scale[scale_col + 3];
|
||||
b_frag_dq_T[4] *= cache_v_scale[scale_col];
|
||||
b_frag_dq_T[5] *= cache_v_scale[scale_col + 1];
|
||||
b_frag_dq_T[6] *= cache_v_scale[scale_col + 2];
|
||||
b_frag_dq_T[7] *= cache_v_scale[scale_col + 3];
|
||||
#pragma unroll
|
||||
for (uint32_t b_i = 0; b_i < 8; ++b_i) {
|
||||
b_frag_dq_T[b_i] *= cache_v_scale[0];
|
||||
}
|
||||
}
|
||||
#pragma unroll
|
||||
for (uint32_t fx = 0; fx < num_frags_x; ++fx) { // m: num_frags_x * 16
|
||||
@@ -1292,9 +1171,7 @@ template <uint32_t num_frags_x,
|
||||
uint32_t block_size,
|
||||
typename T,
|
||||
typename CacheT,
|
||||
bool is_scale_channel_wise = false,
|
||||
bool IsFP8 = false,
|
||||
bool IsDynamicC8 = false>
|
||||
bool is_scale_channel_wise = false, bool IsFP8=false>
|
||||
__device__ __forceinline__ void compute_sfm_v_c8_iter_sq_bvec(
|
||||
smem_t* v_smem,
|
||||
uint32_t* v_smem_offset_r,
|
||||
@@ -1338,28 +1215,16 @@ __device__ __forceinline__ void compute_sfm_v_c8_iter_sq_bvec(
|
||||
convert_c8<T,IsFP8>(b_frag_dq_T, b_frag[fz * 2]);
|
||||
convert_c8<T,IsFP8>(b_frag_dq_T + 4, b_frag[fz * 2 + 1]);
|
||||
// scale zp
|
||||
if constexpr (!IsDynamicC8) {
|
||||
if constexpr (is_scale_channel_wise) {
|
||||
if constexpr (is_scale_channel_wise) {
|
||||
#pragma unroll
|
||||
for (uint32_t b_i = 0; b_i < 8; ++b_i) {
|
||||
b_frag_dq_T[b_i] *= cache_v_scale[b_i / 4 + fy * 2];
|
||||
}
|
||||
} else {
|
||||
#pragma unroll
|
||||
for (uint32_t b_i = 0; b_i < 8; ++b_i) {
|
||||
b_frag_dq_T[b_i] *= cache_v_scale[0];
|
||||
}
|
||||
for (uint32_t b_i = 0; b_i < 8; ++b_i) {
|
||||
b_frag_dq_T[b_i] *= cache_v_scale[b_i / 4 + fy * 2];
|
||||
}
|
||||
} else {
|
||||
const int scale_col = (kz * 2 + fz) * 4;
|
||||
b_frag_dq_T[0] *= cache_v_scale[scale_col];
|
||||
b_frag_dq_T[1] *= cache_v_scale[scale_col + 1];
|
||||
b_frag_dq_T[2] *= cache_v_scale[scale_col + 2];
|
||||
b_frag_dq_T[3] *= cache_v_scale[scale_col + 3];
|
||||
b_frag_dq_T[4] *= cache_v_scale[scale_col];
|
||||
b_frag_dq_T[5] *= cache_v_scale[scale_col + 1];
|
||||
b_frag_dq_T[6] *= cache_v_scale[scale_col + 2];
|
||||
b_frag_dq_T[7] *= cache_v_scale[scale_col + 3];
|
||||
#pragma unroll
|
||||
for (uint32_t b_i = 0; b_i < 8; ++b_i) {
|
||||
b_frag_dq_T[b_i] *= cache_v_scale[0];
|
||||
}
|
||||
}
|
||||
#pragma unroll
|
||||
for (uint32_t fx = 0; fx < num_frags_x; ++fx) { // m: num_frags_x * 16
|
||||
|
@@ -103,7 +103,6 @@ void CascadeAppendAttentionC8Kernel(
const bool causal,
const bool is_decoder,
const bool enable_prefill,
const std::string& cache_quant_type_str,
cudaStream_t& stream,
paddle::Tensor* out);

@@ -265,10 +264,9 @@ void CascadeAppendAttentionKernel(
causal,
is_decoder,
enable_prefill,
cache_quant_type_str,
stream,
out);
} else if (cache_quant_type_str == "cache_fp8" or cache_quant_type_str == "block_wise_fp8") {
} else if (cache_quant_type_str == "cache_fp8") {
CascadeAppendAttentionC8Kernel<T, OutT, true>(meta_data,
qkv,
cache_k,
@@ -301,7 +299,6 @@ void CascadeAppendAttentionKernel(
causal,
is_decoder,
enable_prefill,
cache_quant_type_str,
stream,
out);
} else if (cache_quant_type_str == "cache_int4_zp") {
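The branch on cache_quant_type_str above eventually has to reach kernels whose behaviour is fixed at compile time by template parameters. A common bridge is a dispatch macro that turns a runtime boolean into a constexpr template argument, in the spirit of the DISPATCH_DyCfp8 macro referenced earlier in this diff. The sketch below is a generic stand-in written for illustration, not the project's actual macro.

```cpp
#include <iostream>
#include <string>

// Generic runtime-bool -> compile-time-bool dispatcher (illustrative only).
#define DISPATCH_BOOL(cond, CONST_NAME, ...)      \
  if (cond) {                                     \
    constexpr bool CONST_NAME = true;             \
    __VA_ARGS__;                                  \
  } else {                                        \
    constexpr bool CONST_NAME = false;            \
    __VA_ARGS__;                                  \
  }

template <bool IsDynamicC8>
void run_c8_attention() {
  if constexpr (IsDynamicC8) {
    std::cout << "per-token (block-wise) fp8 scales\n";
  } else {
    std::cout << "static per-head or per-channel scales\n";
  }
}

int main() {
  const std::string cache_quant_type = "block_wise_fp8";
  const bool is_dynamic_cfp8 = cache_quant_type == "block_wise_fp8";
  DISPATCH_BOOL(is_dynamic_cfp8, IsDynamicC8, { run_c8_attention<IsDynamicC8>(); });
}
```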
@@ -120,6 +120,7 @@ __global__ void append_decode_cache_T_rope_qk_norm_kernel(
float row_variance =
max(warp_m2 / head_size, 0.0f);
float row_inv_var = Rsqrt(row_variance + rms_norm_eps);

if (hi < num_heads) { // q
Load<float, VecSize>(&q_norm_weight[threadIdx.x * VecSize], &q_norm_vec);
#pragma unroll
@@ -128,7 +129,6 @@
}
} else { // k
Load<float, VecSize>(&k_norm_weight[threadIdx.x * VecSize], &k_norm_vec);
#pragma unroll
for (int i = 0; i < VecSize; i++) {
out_vec[i] = static_cast<T>(tmp_vec[i] * row_inv_var * k_norm_vec[i]);
}
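For reference, the row_variance and row_inv_var lines above implement an RMS-style normalization of each query or key head: the head vector is scaled by the reciprocal square root of its mean square plus epsilon, then by a learned weight. A scalar CPU sketch of that math, independent of the kernel's vectorized Load/Store helpers and Welford warp reduction, is shown below.

```cpp
#include <cmath>
#include <cstdio>
#include <vector>

// Scalar reference for the RMS norm applied per q/k head in the kernel above.
// head: the rotated head vector; weight: q_norm_weight or k_norm_weight.
std::vector<float> rms_norm(const std::vector<float>& head,
                            const std::vector<float>& weight, float eps) {
  float m2 = 0.0f;                    // sum of squares; the kernel obtains this
  for (float x : head) m2 += x * x;   // from a Welford warp-wide reduction
  const float row_variance = std::fmax(m2 / head.size(), 0.0f);
  const float row_inv_var = 1.0f / std::sqrt(row_variance + eps);
  std::vector<float> out(head.size());
  for (size_t i = 0; i < head.size(); ++i)
    out[i] = head[i] * row_inv_var * weight[i];
  return out;
}

int main() {
  std::vector<float> head = {1.0f, -2.0f, 3.0f, -4.0f};
  std::vector<float> weight = {1.0f, 1.0f, 1.0f, 1.0f};
  for (float v : rms_norm(head, weight, 1e-6f)) std::printf("%f ", v);
  std::printf("\n");
}
```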
@@ -381,6 +381,142 @@ __global__ void append_decode_cache_T_rope_kernel(
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T, int VecSize = 1>
|
||||
__global__ void append_decode_cache_T_neox_partial_rope_kernel(
|
||||
const T* __restrict__ qkv, // [bsz, num_heads + 2 * kv_num_heads,
|
||||
// head_size]
|
||||
T* __restrict__ key_cache, // [num_blocks, kv_num_heads, block_size,
|
||||
// head_size // 2]
|
||||
T* __restrict__ value_cache, // [num_blocks, kv_num_heads, block_size,
|
||||
// head_size // 2]
|
||||
T* __restrict__ qkv_out,
|
||||
const int* __restrict__ block_tables, // [bsz, max_blocks_per_seq]
|
||||
const int* __restrict__ cu_seqlens_q,
|
||||
const int* __restrict__ seq_lens, // [bsz]
|
||||
const int* __restrict__ seq_lens_encoder, // [bsz]
|
||||
const float* __restrict__ cos_emb, // [2, 1, max_model_len, 1, rotary_dim/2]
|
||||
const float* __restrict__ sin_emb, // [2, 1, max_model_len, 1, rotary_dim/2]
|
||||
const int max_seq_len,
|
||||
const int max_blocks_per_seq,
|
||||
const int num_heads,
|
||||
const int head_size,
|
||||
const int rotary_dim,
|
||||
const int block_size,
|
||||
const uint32_t elem_cnt,
|
||||
const int kv_num_heads,
|
||||
const bool rope_3d) {
|
||||
using LoadT = AlignedVector<T, VecSize>;
|
||||
using LoadBiasT = AlignedVector<T, VecSize>;
|
||||
using LoadKVT = AlignedVector<T, VecSize>;
|
||||
constexpr int HalfVecSize = VecSize / 2;
|
||||
using LoadEmbT = AlignedVector<float, VecSize>;
|
||||
|
||||
LoadT left_vec, right_vec;
|
||||
LoadBiasT left_bias_vec, right_bias_vec;
|
||||
LoadKVT left_cache_vec, right_cache_vec;
|
||||
LoadEmbT cos_emb_vec;
|
||||
LoadEmbT sin_emb_vec;
|
||||
|
||||
int64_t global_thread_idx = blockDim.x * blockIdx.x + threadIdx.x;
|
||||
const int half_head_size = head_size / 2;
|
||||
const int half_rotary_dim = rotary_dim / 2;
|
||||
const int64_t hidden_size = (num_heads + 2 * kv_num_heads) * head_size;
|
||||
const int64_t half_hidden_size = hidden_size / 2;
|
||||
// const int64_t offset = 2 * hidden_size;
|
||||
|
||||
for (int32_t linear_index = global_thread_idx * VecSize,
|
||||
step = gridDim.x * blockDim.x * VecSize;
|
||||
linear_index < elem_cnt;
|
||||
linear_index += step) {
|
||||
const int ori_bi = linear_index / half_hidden_size;
|
||||
const int bias = linear_index % half_hidden_size;
|
||||
const int hi = bias / half_head_size; // q + k + v
|
||||
const int h_bias = bias % half_head_size;
|
||||
if (hi < num_heads && h_bias >= half_rotary_dim){
|
||||
continue;
|
||||
}
|
||||
if (seq_lens_encoder[ori_bi] > 0) continue;
|
||||
const int write_seq_id = seq_lens[ori_bi];
|
||||
if (write_seq_id == 0) continue;
|
||||
const int start_token_idx = cu_seqlens_q[ori_bi];
|
||||
|
||||
const int* block_table_now = nullptr;
|
||||
|
||||
block_table_now = block_tables + ori_bi * max_blocks_per_seq;
|
||||
const int block_idx = block_table_now[write_seq_id / block_size];
|
||||
const int block_offset = write_seq_id % block_size;
|
||||
uint32_t ori_idx_left =
|
||||
start_token_idx * hidden_size + hi * head_size + h_bias;
|
||||
uint32_t ori_idx_right = ori_idx_left + half_head_size;
|
||||
if (hi < num_heads){
|
||||
ori_idx_right = ori_idx_left + half_rotary_dim;
|
||||
}else if (hi < num_heads + kv_num_heads){
|
||||
if (h_bias < half_rotary_dim){
|
||||
ori_idx_right = ori_idx_left + half_rotary_dim;
|
||||
}else{
|
||||
ori_idx_left = ori_idx_left + half_rotary_dim;
|
||||
ori_idx_right = ori_idx_left + half_rotary_dim;
|
||||
}
|
||||
}
|
||||
|
||||
Load<T, VecSize>(&qkv[ori_idx_left], &left_vec);
|
||||
Load<T, VecSize>(&qkv[ori_idx_right], &right_vec);
|
||||
|
||||
if (hi < num_heads + kv_num_heads) {
|
||||
// q k rope
|
||||
const uint32_t emb_idx = write_seq_id * half_rotary_dim + h_bias;
|
||||
uint32_t new_emb_idx = rope_3d ? emb_idx + ori_bi * max_seq_len * head_size * 2 : emb_idx;
|
||||
if (h_bias < half_rotary_dim){
|
||||
Load<float, VecSize>(&cos_emb[new_emb_idx], &cos_emb_vec);
|
||||
Load<float, VecSize>(&sin_emb[new_emb_idx], &sin_emb_vec);
|
||||
}
|
||||
}
|
||||
#pragma unroll
|
||||
for (int i = 0; i < VecSize; i++) {
|
||||
// rope
|
||||
float input_left = static_cast<float>(left_vec[i]);
|
||||
float input_right = static_cast<float>(right_vec[i]);
|
||||
if (hi < num_heads + kv_num_heads && h_bias < half_rotary_dim) {
|
||||
const float cos_tmp = cos_emb_vec[i];
|
||||
const float sin_tmp = sin_emb_vec[i];
|
||||
left_bias_vec[i] =
|
||||
static_cast<T>(input_left * cos_tmp - input_right * sin_tmp);
|
||||
right_bias_vec[i] =
|
||||
static_cast<T>(input_right * cos_tmp + input_left * sin_tmp);
|
||||
} else {
|
||||
left_bias_vec[i] = static_cast<T>(input_left);
|
||||
right_bias_vec[i] = static_cast<T>(input_right);
|
||||
}
|
||||
}
|
||||
if (hi < num_heads) {
|
||||
// write q
|
||||
Store<T, VecSize>(left_bias_vec, &qkv_out[ori_idx_left]);
|
||||
Store<T, VecSize>(right_bias_vec, &qkv_out[ori_idx_right]);
|
||||
} else {
|
||||
// write k/v
|
||||
const uint32_t kv_head_idx = (hi - num_heads) % kv_num_heads;
|
||||
uint32_t tgt_idx_left =
|
||||
block_idx * kv_num_heads * block_size * head_size +
|
||||
kv_head_idx * block_size * head_size + block_offset * head_size +
|
||||
h_bias;
|
||||
uint32_t tgt_idx_right = tgt_idx_left + half_head_size;
|
||||
if (hi < num_heads + kv_num_heads) {
|
||||
if (h_bias < half_rotary_dim) {
|
||||
tgt_idx_right = tgt_idx_left + half_rotary_dim;
|
||||
}else{
|
||||
tgt_idx_left = tgt_idx_left + half_rotary_dim;
|
||||
tgt_idx_right = tgt_idx_left + half_rotary_dim;
|
||||
}
|
||||
Store<T, VecSize>(left_bias_vec, &key_cache[tgt_idx_left]);
|
||||
Store<T, VecSize>(right_bias_vec, &key_cache[tgt_idx_right]);
|
||||
} else {
|
||||
Store<T, VecSize>(left_bias_vec, &value_cache[tgt_idx_left]);
|
||||
Store<T, VecSize>(right_bias_vec, &value_cache[tgt_idx_right]);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
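The append_decode_cache_T_neox_partial_rope_kernel shown above applies NeoX-style rotary embedding only to the first rotary_dim elements of each head and copies the remaining dimensions through unchanged. The scalar sketch below mirrors that per-head transform, using the standard RoPE frequency schedule in place of the kernel's precomputed cos_emb/sin_emb tables and ignoring the block-table cache layout; it is an illustration, not the kernel's code.

```cpp
#include <cmath>
#include <cstdio>
#include <vector>

// Scalar reference for partial NeoX-style RoPE: rotate dims [0, rotary_dim),
// copy dims [rotary_dim, head_size) unchanged. Pairs are (i, i + rotary_dim/2).
std::vector<float> partial_neox_rope(const std::vector<float>& head,
                                     int rotary_dim, int position,
                                     float base = 10000.0f) {
  std::vector<float> out = head;                 // pass-through for the tail
  const int half = rotary_dim / 2;
  for (int i = 0; i < half; ++i) {
    // Standard RoPE frequency; the kernel reads precomputed cos/sin instead.
    const float freq = 1.0f / std::pow(base, 2.0f * i / rotary_dim);
    const float c = std::cos(position * freq), s = std::sin(position * freq);
    const float left = head[i], right = head[i + half];
    out[i] = left * c - right * s;               // matches left_bias_vec above
    out[i + half] = right * c + left * s;        // matches right_bias_vec above
  }
  return out;
}

int main() {
  std::vector<float> head(8, 1.0f);              // head_size = 8, rotary_dim = 4
  auto rotated = partial_neox_rope(head, 4, /*position=*/3);
  for (float v : rotated) std::printf("%f ", v);
  std::printf("\n");
}
```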
template <typename T, int VecSize = 1>
|
||||
__global__ void append_decode_cache_T_neox_rope_kernel(
|
||||
const T* __restrict__ qkv, // [bsz, num_heads + 2 * kv_num_heads,
|
||||
@@ -629,294 +765,6 @@ __global__ void append_decode_cache_T_neox_rope_kernel(
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T, int VecSize = 4, int RoundType = 0, int HeadDim = 128, bool is_scale_channel_wise=false, bool IsFP8=true>
|
||||
__global__ void append_decode_cache_int8_rope_qk_norm_kernel(
|
||||
const T* __restrict__ quant_qkv, // [bsz, num_heads + 2 * kv_num_heads,
|
||||
// head_size]
|
||||
uint8_t* __restrict__ key_cache, // [num_blocks, kv_num_heads,
|
||||
// block_size, head_size // 2]
|
||||
uint8_t* __restrict__ value_cache, // [num_blocks, kv_num_heads,
|
||||
// block_size, head_size // 2]
|
||||
T* __restrict__ qkv_out,
|
||||
const int* __restrict__ block_tables, // [bsz, max_blocks_per_seq]
|
||||
const int* __restrict__ batch_id_per_token, // [num_tokens]
|
||||
const int* __restrict__ cu_seqlens_q,
|
||||
const int* __restrict__ seq_lens, // [bsz]
|
||||
const int* __restrict__ seq_lens_encoder, // [bsz]
|
||||
const float* __restrict__ cos_emb,
|
||||
const float* __restrict__ sin_emb,
|
||||
T* __restrict__ cache_k_scale,
|
||||
T* __restrict__ cache_v_scale,
|
||||
const float* q_norm_weight,
|
||||
const float* k_norm_weight,
|
||||
const int max_seq_len,
|
||||
const int max_blocks_per_seq,
|
||||
const int num_heads,
|
||||
const int block_size,
|
||||
const float max_bound,
|
||||
const float min_bound,
|
||||
const int kv_num_heads,
|
||||
const bool rope_3d,
|
||||
const float rms_norm_eps) {
|
||||
static_assert(HeadDim == 128, "just support HeadDim be 128 now!");
|
||||
static_assert(VecSize == 4, "just support VecSize be 4 now, 32 * 4!");
|
||||
constexpr int NUM_WARPS = 4;
|
||||
const int tid = threadIdx.x;
|
||||
const int wid = tid / 32;
|
||||
const int lane_id = tid % 32;
|
||||
const int bid = blockIdx.x, head_idx = blockIdx.y * NUM_WARPS + wid;
|
||||
int q_head_idx, k_head_idx, v_idx;
|
||||
const int64_t hidden_size = (num_heads + 2 * kv_num_heads) * HeadDim;
|
||||
constexpr int half_head_size = HeadDim / 2;
|
||||
const int start_token_idx = cu_seqlens_q[bid];
|
||||
if (seq_lens_encoder[bid] > 0) return;
|
||||
const int write_seq_id = seq_lens[bid];
|
||||
if (write_seq_id == 0) return;
|
||||
const int* block_table_now = nullptr;
|
||||
|
||||
block_table_now = block_tables + bid * max_blocks_per_seq;
|
||||
const int block_idx = __ldg(&block_table_now[write_seq_id / block_size]);
|
||||
const int block_offset = write_seq_id % block_size;
|
||||
|
||||
int cache_offset;
|
||||
if (head_idx < num_heads) {
|
||||
cache_offset = 0;
|
||||
} else if (head_idx < num_heads + 2 * kv_num_heads) {
|
||||
cache_offset = block_idx * kv_num_heads * block_size + (head_idx - num_heads) % kv_num_heads * block_size + block_offset;
|
||||
}
|
||||
T *cache_k_scale_now = cache_k_scale + cache_offset;
|
||||
T *cache_v_scale_now = cache_v_scale + cache_offset;
|
||||
|
||||
float thread_m2 = 0.0f;
|
||||
float warp_m2 = 0.0f;
|
||||
|
||||
if (head_idx < num_heads) {
|
||||
// q
|
||||
using LoadT = AlignedVector<T, VecSize>;
|
||||
using LoadBiasT = AlignedVector<T, VecSize>;
|
||||
using LoadOutScaleT = AlignedVector<float, VecSize>;
|
||||
constexpr int HalfVecSize = VecSize / 2;
|
||||
using LoadEmbT = AlignedVector<float, HalfVecSize>;
|
||||
|
||||
LoadT src_vec;
|
||||
LoadBiasT out_vec;
|
||||
LoadEmbT cos_emb_vec;
|
||||
LoadEmbT sin_emb_vec;
|
||||
const T* qkv_now = quant_qkv + start_token_idx * hidden_size;
|
||||
T* qkv_out_now = qkv_out + start_token_idx * hidden_size;
|
||||
#pragma unroll
|
||||
for (uint32_t head_bias = lane_id * VecSize; head_bias < HeadDim;
|
||||
head_bias += 32 * VecSize) {
|
||||
const int bias_idx = head_idx * HeadDim + head_bias;
|
||||
Load<T, VecSize>(&qkv_now[bias_idx], &src_vec);
|
||||
// q rope
|
||||
const uint32_t emb_idx = write_seq_id * half_head_size + head_bias / 2;
|
||||
const uint32_t new_emb_idx = rope_3d ? emb_idx + bid * max_seq_len * HeadDim : emb_idx;
|
||||
Load<float, HalfVecSize>(&cos_emb[new_emb_idx], &cos_emb_vec);
|
||||
Load<float, HalfVecSize>(&sin_emb[new_emb_idx], &sin_emb_vec);
|
||||
#pragma unroll
|
||||
for (int i = 0; i < HalfVecSize; i++) {
|
||||
// dequant + add_bias + rope
|
||||
float input_left = static_cast<float>(src_vec[2 * i]);
|
||||
float input_right = static_cast<float>(src_vec[2 * i + 1]);
|
||||
|
||||
const float cos_tmp = cos_emb_vec[i];
|
||||
const float sin_tmp = sin_emb_vec[i];
|
||||
float tmp1 = input_left * cos_tmp - input_right * sin_tmp;
|
||||
float tmp2 = input_right * cos_tmp + input_left * sin_tmp;
|
||||
thread_m2 += tmp1 * tmp1 + tmp2 * tmp2;
|
||||
out_vec[2 * i] =
|
||||
static_cast<T>(tmp1);
|
||||
out_vec[2 * i + 1] =
|
||||
static_cast<T>(tmp2);
|
||||
}
|
||||
// qk norm
|
||||
if (q_norm_weight) {
|
||||
WelfordWarpAllReduce<float, 32>(thread_m2, &warp_m2);
|
||||
float row_variance =
|
||||
max(warp_m2 / HeadDim, 0.0f);
|
||||
float row_inv_var = Rsqrt(row_variance + rms_norm_eps);
|
||||
LoadOutScaleT q_norm_vec;
|
||||
Load<float, VecSize>(&q_norm_weight[lane_id * VecSize], &q_norm_vec);
|
||||
#pragma unroll
|
||||
for (int i = 0; i < VecSize; i++) {
|
||||
out_vec[i] = static_cast<T>(static_cast<float>(out_vec[i]) * row_inv_var * q_norm_vec[i]);
|
||||
}
|
||||
}
|
||||
Store<T, VecSize>(out_vec, &qkv_out_now[bias_idx]);
|
||||
}
|
||||
} else if (head_idx < num_heads + 2 * kv_num_heads) {
|
||||
// k
|
||||
constexpr int KV_VEC_SIZE = 16 / sizeof(uint8_t); // 16
|
||||
using LoadPadKVT = AlignedVector<uint8_t, KV_VEC_SIZE>;
|
||||
const uint32_t kv_head_idx = (head_idx - num_heads) % kv_num_heads;
|
||||
if (block_offset == 0) {
|
||||
// pad zero for this kv_head_idx for this block
|
||||
LoadPadKVT pad_cache_vec;
|
||||
*(reinterpret_cast<uint4*>(pad_cache_vec.val)) = make_uint4(0, 0, 0, 0);
|
||||
if (head_idx < num_heads + kv_num_heads) {
|
||||
constexpr int num_vecs_per_head_dim = HeadDim / KV_VEC_SIZE;
|
||||
constexpr int num_token_each_time = 32 / num_vecs_per_head_dim;
|
||||
const uint32_t tgt_idx =
|
||||
(block_idx * kv_num_heads + kv_head_idx) * block_size * HeadDim +
|
||||
lane_id % num_vecs_per_head_dim * KV_VEC_SIZE;
|
||||
for (int block_i = lane_id / num_vecs_per_head_dim;
|
||||
block_i < block_size;
|
||||
block_i += num_token_each_time) {
|
||||
Store<uint8_t, KV_VEC_SIZE>(pad_cache_vec,
|
||||
&key_cache[tgt_idx + block_i * HeadDim]);
|
||||
}
|
||||
} else {
|
||||
const int num_vecs_per_head_dim = block_size / KV_VEC_SIZE;
|
||||
const int num_token_each_time = 32 / num_vecs_per_head_dim;
|
||||
const uint32_t tgt_idx =
|
||||
(block_idx * kv_num_heads + kv_head_idx) * HeadDim * block_size +
|
||||
lane_id % num_vecs_per_head_dim * KV_VEC_SIZE;
|
||||
for (int block_i = lane_id / num_vecs_per_head_dim; block_i < HeadDim;
|
||||
block_i += num_token_each_time) {
|
||||
Store<uint8_t, KV_VEC_SIZE>(
|
||||
pad_cache_vec, &value_cache[tgt_idx + block_i * block_size]);
|
||||
}
|
||||
}
|
||||
__syncwarp();
|
||||
}
|
||||
|
||||
constexpr int K_VEC_SIZE = 4;
|
||||
constexpr int HALF_K_VEC_SIZE = 2;
|
||||
using LoadKVResT = AlignedVector<uint8_t, K_VEC_SIZE>;
|
||||
using LoadKVT = AlignedVector<uint8_t, HALF_K_VEC_SIZE>;
|
||||
using LoadT = AlignedVector<T, HALF_K_VEC_SIZE>;
|
||||
using LoadBiasT = AlignedVector<T, HALF_K_VEC_SIZE>;
|
||||
using LoadOutScaleT = AlignedVector<float, HALF_K_VEC_SIZE>;
|
||||
using LoadEmbT = AlignedVector<float, 1>;
|
||||
LoadKVResT cache_vec;
|
||||
LoadT src_vec1, src_vec2;
|
||||
LoadBiasT out_vec1, out_vec2;
|
||||
LoadEmbT cos_emb_vec1, cos_emb_vec2;
|
||||
LoadEmbT sin_emb_vec1, sin_emb_vec2;
|
||||
|
||||
const T* qkv_now = quant_qkv + start_token_idx * hidden_size;
|
||||
const int head_bias = lane_id / 4 * 16 + lane_id % 4 * 2;
|
||||
const int bias_idx = head_idx * HeadDim + head_bias;
|
||||
Load<T, HALF_K_VEC_SIZE>(&qkv_now[bias_idx], &src_vec1);
|
||||
Load<T, HALF_K_VEC_SIZE>(&qkv_now[bias_idx + 8], &src_vec2);
|
||||
T scale = T(1.0f);
|
||||
const int k_head_idx = head_idx - num_heads;
|
||||
const int v_head_idx = head_idx - num_heads - kv_num_heads;
|
||||
if (head_idx < num_heads + kv_num_heads) {
|
||||
const uint32_t emb_idx = write_seq_id * half_head_size + head_bias / 2;
|
||||
const uint32_t new_emb_idx = rope_3d ? emb_idx + bid * max_seq_len * HeadDim : emb_idx;
|
||||
Load<float, 1>(&cos_emb[new_emb_idx], &cos_emb_vec1);
|
||||
Load<float, 1>(&cos_emb[new_emb_idx + 4], &cos_emb_vec2);
|
||||
Load<float, 1>(&sin_emb[new_emb_idx], &sin_emb_vec1);
|
||||
Load<float, 1>(&sin_emb[new_emb_idx + 4], &sin_emb_vec2);
|
||||
}
|
||||
|
||||
float input_left = static_cast<float>(src_vec1[0]);
|
||||
float input_right = static_cast<float>(src_vec1[1]);
|
||||
if (head_idx < num_heads + kv_num_heads) {
|
||||
float cos_tmp = cos_emb_vec1[0];
|
||||
float sin_tmp = sin_emb_vec1[0];
|
||||
float tmp1 = input_left * cos_tmp - input_right * sin_tmp;
|
||||
float tmp2 = input_right * cos_tmp + input_left * sin_tmp;
|
||||
thread_m2 += tmp1 * tmp1 + tmp2 * tmp2;
|
||||
out_vec1[0] =
|
||||
static_cast<T>(tmp1);
|
||||
out_vec1[1] =
|
||||
static_cast<T>(tmp2);
|
||||
} else {
|
||||
out_vec1[0] = src_vec1[0];
|
||||
out_vec1[1] = src_vec1[1];
|
||||
}
|
||||
|
||||
// rope
|
||||
input_left = static_cast<float>(src_vec2[0]);
|
||||
input_right = static_cast<float>(src_vec2[1]);
|
||||
if (head_idx < num_heads + kv_num_heads) {
|
||||
float cos_tmp = cos_emb_vec2[0];
|
||||
float sin_tmp = sin_emb_vec2[0];
|
||||
float tmp1 = input_left * cos_tmp - input_right * sin_tmp;
|
||||
float tmp2 = input_right * cos_tmp + input_left * sin_tmp;
|
||||
thread_m2 += tmp1 * tmp1 + tmp2 * tmp2;
|
||||
out_vec2[0] = static_cast<T>(tmp1);
|
||||
out_vec2[1] = static_cast<T>(tmp2);
|
||||
} else {
|
||||
out_vec2[0] = src_vec2[0];
|
||||
out_vec2[1] = src_vec2[1];
|
||||
}
|
||||
if (k_norm_weight) {
|
||||
if (head_idx < num_heads + kv_num_heads) {
|
||||
LoadOutScaleT k_norm_vec1, k_norm_vec2;
|
||||
Load<float, HALF_K_VEC_SIZE>(&k_norm_weight[head_bias], &k_norm_vec1);
|
||||
Load<float, HALF_K_VEC_SIZE>(&k_norm_weight[head_bias + 8], &k_norm_vec2);
|
||||
// qk norm
|
||||
WelfordWarpAllReduce<float, 32>(thread_m2, &warp_m2);
|
||||
float row_variance =
|
||||
max(warp_m2 / HeadDim, 0.0f);
|
||||
float row_inv_var = Rsqrt(row_variance + rms_norm_eps);
|
||||
|
||||
for (int i = 0; i < HALF_K_VEC_SIZE; i++) {
|
||||
out_vec1[i] = static_cast<T>(static_cast<float>(out_vec1[i]) * row_inv_var * k_norm_vec1[i]);
|
||||
out_vec2[i] = static_cast<T>(static_cast<float>(out_vec2[i]) * row_inv_var * k_norm_vec2[i]);
|
||||
}
|
||||
}
|
||||
}
|
||||
    // reduce max, 1 head per warp
    T local_max = -INFINITY;
#pragma unroll
    for (int i = 0; i < HALF_K_VEC_SIZE; i++) {
      local_max = __hmax(local_max, __habs(out_vec1[i]));
      local_max = __hmax(local_max, __habs(out_vec2[i]));
    }
#pragma unroll
    for (int m_offset = 16; m_offset > 0; m_offset /= 2) {
      local_max = __hmax(local_max, __shfl_xor_sync(0xffffffff, local_max, m_offset));
    }

    scale = __hdiv(448, local_max);

    if (lane_id == 0) {
      if (head_idx < num_heads + kv_num_heads) {
        cache_k_scale_now[0] = __hdiv(1, scale);
      } else {
        cache_v_scale_now[0] = __hdiv(1, scale);
      }
    }

#pragma unroll
    for (uint32_t i = 0; i < HALF_K_VEC_SIZE; i++) {
      cache_vec[i] = QuantToC8<T, true, IsFP8, RoundType>(
          scale, out_vec1[i], max_bound, min_bound);
      cache_vec[i + HALF_K_VEC_SIZE] = QuantToC8<T, true, IsFP8, RoundType>(
          scale, out_vec2[i], max_bound, min_bound);
    }
    if (head_idx < num_heads + kv_num_heads) {
      const int start_block_16 =
          block_offset / 16 * 16 + block_offset % 8 + lane_id / 4 % 2 * 8;
      const uint32_t tgt_cache_idx =
          block_idx * kv_num_heads * block_size * HeadDim +
          kv_head_idx * block_size * HeadDim + start_block_16 * HeadDim +
          lane_id / 4 / 2 * 32 + (block_offset % 16) / 8 * 16 + lane_id % 4 * 4;
      Store<uint8_t, K_VEC_SIZE>(cache_vec, &key_cache[tgt_cache_idx]);
    } else {
      const uint32_t base_tgt_cache_idx =
          block_idx * kv_num_heads * HeadDim * block_size +
          kv_head_idx * HeadDim * block_size +
          (lane_id / 4 * 16 + lane_id % 4 * 2) * block_size +
          block_offset / 16 % 2 * 8 * block_size + block_offset / 16 / 2 * 32;
      const uint32_t tgt_cache_idx1 = base_tgt_cache_idx +
                                      block_offset % 8 / 2 * 4     // per 4
                                      + block_offset % 16 / 8 * 2  // per 2
                                      + block_offset % 2;          // per 1
      const uint32_t tgt_cache_idx2 = tgt_cache_idx1 + block_size;
      const uint32_t tgt_cache_idx3 = tgt_cache_idx1 + 16;
      const uint32_t tgt_cache_idx4 = tgt_cache_idx3 + block_size;
      value_cache[tgt_cache_idx1] = cache_vec[0];
      value_cache[tgt_cache_idx2] = cache_vec[1];
      value_cache[tgt_cache_idx3] = cache_vec[2];
      value_cache[tgt_cache_idx4] = cache_vec[3];
    }
  }
}
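The tail of this kernel derives a per-token quantization scale: the warp reduces the absolute maximum of the freshly rotated k or v row, quantizes with scale = 448 / max (448 being the largest finite e4m3 value), and stores 1 / scale for dequantization at attention time. A scalar host-side sketch of the same arithmetic, illustrative only and using float in place of the half-precision intrinsics:

// Illustrative only: per-row FP8-style quantization mirroring the kernel's
// scale = 448 / amax and stored dequant scale = 1 / scale.
#include <algorithm>
#include <cmath>
#include <vector>

void quantize_row_fp8_like(const std::vector<float>& row,
                           std::vector<float>& quantized,  // stands in for the e4m3 bytes
                           float& dequant_scale) {
  float amax = 0.0f;
  for (float v : row) amax = std::max(amax, std::fabs(v));
  const float scale = amax > 0.0f ? 448.0f / amax : 0.0f;
  dequant_scale = scale > 0.0f ? 1.0f / scale : 0.0f;
  quantized.resize(row.size());
  for (size_t i = 0; i < row.size(); ++i) {
    // The real kernel converts via QuantToC8; here we only scale and clamp.
    quantized[i] = std::max(-448.0f, std::min(448.0f, row[i] * scale));
  }
}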
|
||||
|
||||
template <typename T, int VecSize = 4, int RoundType = 0, int HeadDim = 128, bool is_scale_channel_wise=false, bool IsFP8=false>
|
||||
__global__ void append_decode_cache_int8_rope_kernel(
|
||||
const T* __restrict__ quant_qkv, // [bsz, num_heads + 2 * kv_num_heads,
|
||||
|
@@ -97,6 +97,7 @@ void append_decode_cache_rope(const QKV_TYPE* qkv,
                              const int num_heads,
                              const int kv_num_heads,
                              const int dim_head,
                              const int rotary_dim,
                              const int block_size,
                              const int bsz,
                              const cudaStream_t& stream,
@@ -137,7 +138,29 @@ void append_decode_cache_rope(const QKV_TYPE* qkv,
            kv_num_heads,
            rope_3d);
    } else {
      append_decode_cache_T_neox_rope_kernel<T, PackSize>
      if (rotary_dim < dim_head) {
        append_decode_cache_T_neox_partial_rope_kernel<T, PackSize>
            <<<grid_size, blocksize, 0, stream>>>(reinterpret_cast<const T*>(qkv),
                                                  key_cache,
                                                  value_cache,
                                                  qkv_out,
                                                  block_tables,
                                                  cu_seqlens_q,
                                                  seq_lens,
                                                  seq_lens_encoder,
                                                  cos_emb,
                                                  sin_emb,
                                                  max_seq_len,
                                                  max_blocks_per_seq,
                                                  num_heads,
                                                  dim_head,
                                                  rotary_dim,
                                                  block_size,
                                                  elem_nums,
                                                  kv_num_heads,
                                                  rope_3d);
      } else {
        append_decode_cache_T_neox_rope_kernel<T, PackSize>
            <<<grid_size, blocksize, 0, stream>>>(reinterpret_cast<const T*>(qkv),
                                                  key_cache,
                                                  value_cache,
@@ -157,6 +180,7 @@ void append_decode_cache_rope(const QKV_TYPE* qkv,
                                                  elem_nums,
                                                  kv_num_heads,
                                                  rope_3d);
      }
    }
  } else {
    if (qkv_out_scales) {
@@ -534,11 +558,20 @@ void DecoderWriteCacheWithRoPEKernel(
  const float* cos_emb =
      rotary_embs ? rotary_embs.get().data<float>() : nullptr;
  const float* sin_emb;
  int rotary_dim = dim_head;
  if (rotary_embs) {
    sin_emb =
        use_neox_rotary_style
            ? rotary_embs.get().data<float>() + max_seq_len * dim_head
            : rotary_embs.get().data<float>() + max_seq_len * dim_head / 2;
    rotary_dim = rotary_embs.get().dims()[rotary_embs.get().dims().size() - 1] * 2;
    if (rotary_dim < dim_head) {
      if (!use_neox_rotary_style || qkv_out_scales || q_norm_weight ||
          k_norm_weight || cache_quant_type_str != "none") {
        PADDLE_THROW(phi::errors::Fatal(
            "partial_rotary_factor < 1.0 only supports neox_rotary_style=True, "
            "qkv_out_scales is None, q_norm_weight/k_norm_weight is None, and "
            "cache_quant_type_str is 'none'."));
      }
      sin_emb = rotary_embs.get().data<float>() + max_seq_len * rotary_dim / 2;
    }
  }
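In the partial-rotary branch above, the rotary table's last dimension is rotary_dim / 2, so the sin half of the cos/sin table starts at max_seq_len * rotary_dim / 2 instead of the full-head offset. The following is a small sketch of that pointer split, assuming the [2, max_seq_len, rotary_dim / 2] layout implied by the offsets above; the names are hypothetical:

// Illustrative only: cos/sin base pointers for a partial rotary table laid
// out as cos rows followed by sin rows, matching the offsets used above.
#include <cstdint>

struct RotaryPointers {
  const float* cos_ptr;
  const float* sin_ptr;
};

inline RotaryPointers split_partial_rotary_table(const float* rotary_embs,
                                                 int max_seq_len,
                                                 int rotary_dim) {
  RotaryPointers p;
  p.cos_ptr = rotary_embs;
  p.sin_ptr = rotary_embs +
              static_cast<int64_t>(max_seq_len) * (rotary_dim / 2);
  return p;
}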
|
||||
|
||||
if (q_norm_weight && k_norm_weight) {
|
||||
@@ -572,40 +605,9 @@ void DecoderWriteCacheWithRoPEKernel(
|
||||
q_norm_weight ? q_norm_weight.get().data<float>() : nullptr,
|
||||
k_norm_weight ? k_norm_weight.get().data<float>() : nullptr,
|
||||
rms_norm_eps);
|
||||
} else if (cache_quant_type_str == "block_wise_fp8") {
|
||||
constexpr int num_warps = 4;
|
||||
const int all_warps =
|
||||
((num_heads + 2 * kv_num_heads) + num_warps - 1) / num_warps * num_warps;
|
||||
dim3 grids(bsz, all_warps / num_warps);
|
||||
append_decode_cache_int8_rope_qk_norm_kernel<DataType_, 4, 0, 128, false, true>
|
||||
<<<grids, num_warps * 32, 0, stream>>>(
|
||||
reinterpret_cast<const DataType_*>(qkv_ptr),
|
||||
key_cache_out->data<uint8_t>(),
|
||||
value_cache_out->data<uint8_t>(),
|
||||
reinterpret_cast<DataType_*>(qkv_out->data<T>()),
|
||||
block_tables.data<int>(),
|
||||
batch_id_per_token.data<int>(),
|
||||
cu_seqlens_q.data<int>(),
|
||||
seq_lens.data<int>(),
|
||||
seq_lens_encoder.data<int>(),
|
||||
cos_emb,
|
||||
sin_emb,
|
||||
const_cast<DataType_*>(reinterpret_cast<const DataType_*>(cache_k_scale.get().data<T>())),
|
||||
const_cast<DataType_*>(reinterpret_cast<const DataType_*>((cache_v_scale.get().data<T>()))),
|
||||
q_norm_weight.get().data<float>(),
|
||||
k_norm_weight.get().data<float>(),
|
||||
max_seq_len,
|
||||
max_blocks_per_seq,
|
||||
num_heads,
|
||||
block_size,
|
||||
127.0f,
|
||||
-127.0f,
|
||||
kv_num_heads,
|
||||
rope_3d,
|
||||
rms_norm_eps);
|
||||
} else {
|
||||
PD_THROW(
|
||||
"append_decode_cache_rope_qk_norm just supports cache_quant_type none/block_wise_fp8");
|
||||
"append_decode_cache_rope_qk_norm not support cachekv quant yet");
|
||||
}
|
||||
} else {
|
||||
if (cache_quant_type_str == "none") {
|
||||
@@ -630,6 +632,7 @@ void DecoderWriteCacheWithRoPEKernel(
|
||||
num_heads,
|
||||
kv_num_heads,
|
||||
dim_head,
|
||||
rotary_dim,
|
||||
block_size,
|
||||
bsz,
|
||||
stream,
|
||||
@@ -740,37 +743,6 @@ void DecoderWriteCacheWithRoPEKernel(
|
||||
stream,
|
||||
use_neox_rotary_style,
|
||||
rope_3d);
|
||||
} else if (cache_quant_type_str == "block_wise_fp8") {
|
||||
constexpr int num_warps = 4;
|
||||
const int all_warps =
|
||||
((num_heads + 2 * kv_num_heads) + num_warps - 1) / num_warps * num_warps;
|
||||
dim3 grids(bsz, all_warps / num_warps);
|
||||
append_decode_cache_int8_rope_qk_norm_kernel<DataType_, 4, 0, 128, false, true>
|
||||
<<<grids, num_warps * 32, 0, stream>>>(
|
||||
reinterpret_cast<const DataType_*>(qkv_ptr),
|
||||
key_cache_out->data<uint8_t>(),
|
||||
value_cache_out->data<uint8_t>(),
|
||||
reinterpret_cast<DataType_*>(qkv_out->data<T>()),
|
||||
block_tables.data<int>(),
|
||||
batch_id_per_token.data<int>(),
|
||||
cu_seqlens_q.data<int>(),
|
||||
seq_lens.data<int>(),
|
||||
seq_lens_encoder.data<int>(),
|
||||
cos_emb,
|
||||
sin_emb,
|
||||
const_cast<DataType_*>(reinterpret_cast<const DataType_*>(cache_k_scale.get().data<T>())),
|
||||
const_cast<DataType_*>(reinterpret_cast<const DataType_*>((cache_v_scale.get().data<T>()))),
|
||||
nullptr,
|
||||
nullptr,
|
||||
max_seq_len,
|
||||
max_blocks_per_seq,
|
||||
num_heads,
|
||||
block_size,
|
||||
127.0f,
|
||||
-127.0f,
|
||||
kv_num_heads,
|
||||
rope_3d,
|
||||
rms_norm_eps);
|
||||
} else if (cache_quant_type_str == "cache_int4_zp") {
|
||||
append_decode_cache_int4_rope(
|
||||
reinterpret_cast<const QKV_TYPE*>(qkv_ptr),
|
||||
|
@@ -900,6 +900,74 @@ __global__ void GQANeoxVariableLengthRotaryKernel(
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T, int VecSize = 1>
__global__ void GQANeoxVariableLengthPartialRotaryKernel(
    const T *qkv,
    const float *cos_emb,
    const float *sin_emb,
    const int *batch_id_per_token,
    const int *cu_seqlens_q,
    const int *seq_lens,
    const int *seq_lens_decoder,
    const float *qkv_out_scales,
    const T *qkv_biases,
    T *qkv_out,
    const int64_t elem_cnt,
    const int q_num_head,
    const int kv_num_head,
    const int seq_len,
    const int head_dim,
    const int rotary_dim,
    const bool rope_3d) {
  using LoadT = AlignedVector<T, VecSize>;
  using LoadEmbT = AlignedVector<float, VecSize>;
  LoadT left_vec;
  LoadT right_vec;
  LoadEmbT cos_emb_vec;
  LoadEmbT sin_emb_vec;
  int64_t global_thread_idx = blockDim.x * blockIdx.x + threadIdx.x;
  const int rotary_dim_half = rotary_dim / 2;
  const int offset = (q_num_head + kv_num_head) * rotary_dim_half;
  for (int64_t linear_index = global_thread_idx * VecSize,
               step = gridDim.x * blockDim.x * VecSize;
       linear_index < elem_cnt;
       linear_index += step) {
    const int token_idx = linear_index / offset;
    const int ori_bi = batch_id_per_token[token_idx];
    if (seq_lens && seq_lens[ori_bi] == 0) continue;
    const int bias = linear_index % offset;
    const int hi = bias / rotary_dim_half;
    const int h_bias = bias % rotary_dim_half;

    const int ori_seq_id =
        (token_idx - cu_seqlens_q[ori_bi]) + seq_lens_decoder[ori_bi];

    const int emb_idx = ori_seq_id * rotary_dim_half + h_bias;
    int64_t new_emb_idx =
        rope_3d ? emb_idx + ori_bi * head_dim * seq_len * 2 : emb_idx;
    const int base_idx_left =
        token_idx * (q_num_head + 2 * kv_num_head) * head_dim + hi * head_dim +
        h_bias;
    const int base_idx_right = base_idx_left + rotary_dim_half;

    Load<T, VecSize>(&qkv[base_idx_left], &left_vec);
    Load<T, VecSize>(&qkv[base_idx_right], &right_vec);
    Load<float, VecSize>(&cos_emb[new_emb_idx], &cos_emb_vec);
    Load<float, VecSize>(&sin_emb[new_emb_idx], &sin_emb_vec);
#pragma unroll
    for (int i = 0; i < VecSize; i++) {
      const float input_left = static_cast<float>(left_vec[i]);
      const float input_right = static_cast<float>(right_vec[i]);
      const float cos_tmp = cos_emb_vec[i];
      const float sin_tmp = sin_emb_vec[i];
      left_vec[i] =
          static_cast<T>(input_left * cos_tmp - input_right * sin_tmp);
      right_vec[i] =
          static_cast<T>(input_right * cos_tmp + input_left * sin_tmp);
    }
    Store<T, VecSize>(left_vec, &qkv_out[base_idx_left]);
    Store<T, VecSize>(right_vec, &qkv_out[base_idx_right]);
  }
}
|
||||
|
||||
template <typename T, int VecSize = 1>
|
||||
__global__ void cache_kernel(
|
||||
const T *__restrict__ qkv, // [num_tokens, num_heads + 2 * kv_num_heads,
|
||||
@@ -1232,411 +1300,6 @@ __global__ void append_write_cache_kv_c8_qkv(
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T,
|
||||
uint32_t num_frags_y,
|
||||
uint32_t num_frags_z,
|
||||
uint32_t HEAD_DIM,
|
||||
uint32_t BLOCK_SIZE,
|
||||
uint32_t NUM_WARPS,
|
||||
bool is_need_kv_quant,
|
||||
bool IsFP8 = true>
|
||||
__global__ void append_write_cache_kv_c8_qkv_dynamic(
|
||||
uint8_t *__restrict__ cache_k,
|
||||
uint8_t *__restrict__ cache_v,
|
||||
const T *__restrict__ qkv_input,
|
||||
T *__restrict__ cache_k_scales, // [block_num, num_heads, block_size]
|
||||
T *__restrict__ cache_v_scales, // [block_num, num_heads, block_size]
|
||||
const int *__restrict__ batch_ids,
|
||||
const int *__restrict__ tile_ids,
|
||||
const int *__restrict__ seq_lens_this_time,
|
||||
const int *__restrict__ seq_lens_decoder,
|
||||
const int *__restrict__ batch_id_per_token,
|
||||
const int *__restrict__ cu_seqlens_q,
|
||||
const int *__restrict__ block_tables,
|
||||
const int max_seq_len,
|
||||
const int max_blocks_per_seq,
|
||||
const int num_heads,
|
||||
const int kv_num_heads) {
|
||||
constexpr uint32_t num_vecs_per_head = HEAD_DIM / num_elems_per_128b<T>();
|
||||
constexpr uint32_t pad_len = BLOCK_SIZE;
|
||||
const uint32_t btid = blockIdx.x, kv_head_idx = blockIdx.z;
|
||||
const T cache_k_scale = cache_k_scales[kv_head_idx];
|
||||
const T cache_v_scale = cache_v_scales[kv_head_idx];
|
||||
const uint32_t tid = threadIdx.x, wid = threadIdx.y;
|
||||
const uint32_t batch_id = batch_ids[btid];
|
||||
const uint32_t tile_id = tile_ids[btid];
|
||||
const uint32_t seq_len_this_time = seq_lens_this_time[batch_id];
|
||||
if (seq_len_this_time <= 0) {
|
||||
return;
|
||||
}
|
||||
const int *block_table_now = nullptr;
|
||||
|
||||
block_table_now = block_tables + batch_id * max_blocks_per_seq;
|
||||
|
||||
const uint32_t num_rows_per_block =
|
||||
NUM_WARPS * num_frags_z * 16; // BLOCK_SIZE
|
||||
const uint32_t start_len = seq_lens_decoder[batch_id];
|
||||
const uint32_t bf_pad_len = start_len % pad_len;
|
||||
const uint32_t start_len_pad = start_len - bf_pad_len;
|
||||
const uint32_t end_len = start_len + seq_len_this_time;
|
||||
|
||||
const uint32_t tile_start = start_len_pad + tile_id * num_rows_per_block;
|
||||
int block_id = __ldg(&block_table_now[tile_start / BLOCK_SIZE]);
|
||||
uint32_t chunk_start = tile_start + wid * num_frags_z * 16 + tid / 8;
|
||||
|
||||
const uint32_t start_token_idx = cu_seqlens_q[batch_id];
|
||||
const uint32_t kv_batch_stride = (num_heads + 2 * kv_num_heads) * HEAD_DIM;
|
||||
const uint32_t kv_h_stride = HEAD_DIM;
|
||||
__shared__ T k_smem_ori[num_rows_per_block * HEAD_DIM];
|
||||
__shared__ T v_smem_ori[num_rows_per_block * HEAD_DIM];
|
||||
__shared__ T v_scale_smem[BLOCK_SIZE];
|
||||
if (tile_start >= start_len) {
|
||||
constexpr int KV_VEC_SIZE = 16 / sizeof(uint8_t); // 16
|
||||
using LoadPadKVT = AlignedVector<uint8_t, KV_VEC_SIZE>;
|
||||
// pad zero for this kv_head_idx for this block
|
||||
LoadPadKVT pad_cache_vec;
|
||||
*(reinterpret_cast<uint4*>(pad_cache_vec.val)) = make_uint4(0, 0, 0, 0);
|
||||
// reset k
|
||||
constexpr int num_vecs_per_head_k = HEAD_DIM / KV_VEC_SIZE;
|
||||
constexpr int num_token_each_time_k = 32 / num_vecs_per_head_k;
|
||||
uint32_t tgt_idx =
|
||||
(block_id * kv_num_heads + kv_head_idx) * BLOCK_SIZE * HEAD_DIM +
|
||||
tid % num_vecs_per_head_k * KV_VEC_SIZE;
|
||||
for (int block_i = tid / num_vecs_per_head_k;
|
||||
block_i < BLOCK_SIZE;
|
||||
block_i += num_token_each_time_k) {
|
||||
Store<uint8_t, KV_VEC_SIZE>(pad_cache_vec,
|
||||
&cache_k[tgt_idx + block_i * HEAD_DIM]);
|
||||
}
|
||||
|
||||
// reset v
|
||||
const int num_vecs_per_head_v = BLOCK_SIZE / KV_VEC_SIZE;
|
||||
const int num_token_each_time_v = 32 / num_vecs_per_head_v;
|
||||
tgt_idx =
|
||||
(block_id * kv_num_heads + kv_head_idx) * HEAD_DIM * BLOCK_SIZE +
|
||||
tid % num_vecs_per_head_v * KV_VEC_SIZE;
|
||||
for (int block_i = tid / num_vecs_per_head_v; block_i < HEAD_DIM;
|
||||
block_i += num_token_each_time_v) {
|
||||
Store<uint8_t, KV_VEC_SIZE>(
|
||||
pad_cache_vec, &cache_v[tgt_idx + block_i * BLOCK_SIZE]);
|
||||
}
|
||||
}
|
||||
smem_t k_smem(k_smem_ori);
|
||||
smem_t v_smem(v_smem_ori);
|
||||
|
||||
uint32_t kv_smem_offset_w = smem_t::get_permuted_offset<num_vecs_per_head>(
|
||||
wid * num_frags_z * 16 + tid / 8, tid % 8); // 4 * 8 per warp
|
||||
|
||||
/*
|
||||
0 | 1
|
||||
2 | 3
|
||||
*/
|
||||
uint32_t k_smem_offset_r = smem_t::get_permuted_offset<num_vecs_per_head>(
|
||||
wid * num_frags_z * 16 + 8 * (tid / 16) + tid % 8, (tid % 16) / 8);
|
||||
|
||||
constexpr uint32_t num_frags_v = num_frags_y / NUM_WARPS;
|
||||
/*
|
||||
0 | 2
|
||||
1 | 3
|
||||
*/
|
||||
uint32_t v_smem_offset_r = smem_t::get_permuted_offset<num_vecs_per_head>(
|
||||
tid % 16, wid * num_frags_v * 2 + tid / 16);
|
||||
|
||||
// load kv gmem to smem
|
||||
const uint32_t real_start_token_idx = start_token_idx - bf_pad_len +
|
||||
tile_id * num_rows_per_block +
|
||||
wid * num_frags_z * 16 + tid / 8;
|
||||
uint32_t k_read_idx = real_start_token_idx * kv_batch_stride +
|
||||
(num_heads + kv_head_idx) * kv_h_stride +
|
||||
tid % 8 * num_elems_per_128b<T>();
|
||||
uint32_t v_read_idx = real_start_token_idx * kv_batch_stride +
|
||||
(num_heads + kv_num_heads + kv_head_idx) * kv_h_stride +
|
||||
tid % 8 * num_elems_per_128b<T>();
|
||||
#pragma unroll
|
||||
for (uint32_t fz = 0; fz < num_frags_z; ++fz) {
|
||||
#pragma unroll
|
||||
for (uint32_t j = 0; j < 4; ++j) {
|
||||
#pragma unroll
|
||||
for (uint32_t fy = 0; fy < num_frags_y / 4;
|
||||
++fy) { // (num_frags_y * 16) / (8 * num_elems_per_128b<T>())
|
||||
if (chunk_start >= start_len && chunk_start < end_len) {
|
||||
k_smem.load_128b_async<SharedMemFillMode::kNoFill>(
|
||||
kv_smem_offset_w, qkv_input + k_read_idx, chunk_start < end_len);
|
||||
v_smem.load_128b_async<SharedMemFillMode::kNoFill>(
|
||||
kv_smem_offset_w, qkv_input + v_read_idx, chunk_start < end_len);
|
||||
}
|
||||
kv_smem_offset_w =
|
||||
k_smem.advance_offset_by_column<8>(kv_smem_offset_w, fy);
|
||||
k_read_idx += 8 * num_elems_per_128b<T>();
|
||||
v_read_idx += 8 * num_elems_per_128b<T>();
|
||||
}
|
||||
kv_smem_offset_w =
|
||||
k_smem.advance_offset_by_row<4, num_vecs_per_head>(kv_smem_offset_w) -
|
||||
2 * num_frags_y;
|
||||
chunk_start += 4;
|
||||
k_read_idx +=
|
||||
4 * kv_batch_stride - 2 * num_frags_y * num_elems_per_128b<T>();
|
||||
v_read_idx +=
|
||||
4 * kv_batch_stride - 2 * num_frags_y * num_elems_per_128b<T>();
|
||||
}
|
||||
}
|
||||
commit_group();
|
||||
wait_group<0>();
|
||||
__syncthreads();
|
||||
|
||||
// reduce scale
|
||||
// 16 rows per warp
|
||||
uint32_t kv_reduce_frag[4];
|
||||
T *kv_reduce_frag_T = reinterpret_cast<T*>(kv_reduce_frag);
|
||||
|
||||
T k_local_max_value[num_frags_z * 2];
|
||||
T v_local_max_value[num_frags_z * 2];
|
||||
#pragma unroll
|
||||
for (int i = 0; i < num_frags_z * 2; i++) {
|
||||
k_local_max_value[i] = -INFINITY;
|
||||
}
|
||||
#pragma unroll
|
||||
for (int i = 0; i < num_frags_z * 2; i++) {
|
||||
v_local_max_value[i] = -INFINITY;
|
||||
}
|
||||
const int num_kv_heads = gridDim.z;
|
||||
const int scale_offset = block_id * num_kv_heads * BLOCK_SIZE + kv_head_idx * BLOCK_SIZE;
|
||||
T *cache_k_scale_now = cache_k_scales + scale_offset;
|
||||
T *cache_v_scale_now = cache_v_scales + scale_offset;
|
||||
// k scale
|
||||
#pragma unroll
|
||||
for (uint32_t fz = 0; fz < num_frags_z; ++fz) {
|
||||
#pragma unroll
|
||||
for (uint32_t fy = 0; fy < num_frags_y; ++fy) {
|
||||
// reduce per thread, 4 threads each row
|
||||
k_smem.ldmatrix_m8n8x4(k_smem_offset_r, kv_reduce_frag);
|
||||
#pragma unroll
|
||||
for (int i = 0; i < 4; i++) {
|
||||
k_local_max_value[fz * 2] = __hmax(__habs(kv_reduce_frag_T[i]), k_local_max_value[fz * 2]);
|
||||
}
|
||||
#pragma unroll
|
||||
for (int i = 0; i < 4; i++) {
|
||||
k_local_max_value[fz * 2 + 1] = __hmax(__habs(kv_reduce_frag_T[i + 4]), k_local_max_value[fz * 2 + 1]);
|
||||
}
|
||||
k_smem_offset_r = k_smem.advance_offset_by_column<2>(k_smem_offset_r, fy);
|
||||
}
|
||||
// reduce per row
|
||||
for (int i = 0; i < 2; i++) {
|
||||
T local_max_value = __habs(k_local_max_value[fz * 2 + i]);
|
||||
local_max_value = __hmax(local_max_value, __shfl_xor_sync(0xffffffff, local_max_value, 2));
|
||||
local_max_value = __hmax(local_max_value, __shfl_xor_sync(0xffffffff, local_max_value, 1));
|
||||
// used for quant
|
||||
k_local_max_value[fz * 2 + i] = __hdiv(448, local_max_value);
|
||||
}
|
||||
// store
|
||||
if (tid % 4 == 0) {
|
||||
const int offset_now = wid * num_frags_z * 16 + tid / 4;
|
||||
// used for dequant
|
||||
if (tile_start + offset_now >= start_len) {
|
||||
if (tile_start + offset_now < end_len) {
|
||||
cache_k_scale_now[offset_now] = __hdiv(1, k_local_max_value[fz * 2]);
|
||||
} else {
|
||||
cache_k_scale_now[offset_now] = 0;
|
||||
}
|
||||
}
|
||||
if (tile_start + offset_now + 8 >= start_len) {
|
||||
if (tile_start + offset_now + 8 < end_len) {
|
||||
cache_k_scale_now[offset_now + 8] = __hdiv(1, k_local_max_value[fz * 2 + 1]);
|
||||
} else {
|
||||
cache_k_scale_now[offset_now + 8] = 0;
|
||||
}
|
||||
}
|
||||
}
|
||||
__syncthreads();
|
||||
k_smem_offset_r -= 2 * num_frags_y; // num_frags_z = 1
|
||||
}
|
||||
// v scale
|
||||
#pragma unroll
|
||||
for (uint32_t fz = 0; fz < num_frags_z; ++fz) {
|
||||
#pragma unroll
|
||||
for (uint32_t fy = 0; fy < num_frags_y; ++fy) {
|
||||
// reduce per thread, 4 threads each row
|
||||
v_smem.ldmatrix_m8n8x4(k_smem_offset_r, kv_reduce_frag);
|
||||
#pragma unroll
|
||||
for (int i = 0; i < 4; i++) {
|
||||
v_local_max_value[fz * 2] = __hmax(__habs(kv_reduce_frag_T[i]), v_local_max_value[fz * 2]);
|
||||
}
|
||||
#pragma unroll
|
||||
for (int i = 0; i < 4; i++) {
|
||||
v_local_max_value[fz * 2 + 1] = __hmax(__habs(kv_reduce_frag_T[i + 4]), v_local_max_value[fz * 2 + 1]);
|
||||
}
|
||||
k_smem_offset_r = v_smem.advance_offset_by_column<2>(k_smem_offset_r, fy);
|
||||
}
|
||||
// reduce per row
|
||||
for (int i = 0; i < 2; i++) {
|
||||
T local_max_value = __habs(v_local_max_value[fz * 2 + i]);
|
||||
local_max_value = __hmax(local_max_value, __shfl_xor_sync(0xffffffff, local_max_value, 2));
|
||||
local_max_value = __hmax(local_max_value, __shfl_xor_sync(0xffffffff, local_max_value, 1));
|
||||
v_local_max_value[fz * 2 + i] = __hdiv(448, local_max_value);
|
||||
}
|
||||
// store
|
||||
if (tid % 4 == 0) {
|
||||
const int offset_now = wid * num_frags_z * 16 + tid / 4;
|
||||
// used for dequant
|
||||
if (tile_start + offset_now >= start_len) {
|
||||
if (tile_start + offset_now < end_len) {
|
||||
cache_v_scale_now[offset_now] = __hdiv(1, v_local_max_value[fz * 2]);
|
||||
v_scale_smem[offset_now] = v_local_max_value[fz * 2];
|
||||
} else {
|
||||
cache_v_scale_now[offset_now] = 0;
|
||||
v_scale_smem[offset_now] = 0;
|
||||
}
|
||||
}
|
||||
if (tile_start + offset_now + 8 >= start_len) {
|
||||
if (tile_start + offset_now + 8 < end_len) {
|
||||
cache_v_scale_now[offset_now + 8] = __hdiv(1, v_local_max_value[fz * 2 + 1]);
|
||||
v_scale_smem[offset_now + 8] = v_local_max_value[fz * 2 + 1];
|
||||
} else {
|
||||
cache_v_scale_now[offset_now + 8] = 0;
|
||||
v_scale_smem[offset_now + 8] = 0;
|
||||
}
|
||||
}
|
||||
}
|
||||
__syncthreads();
|
||||
k_smem_offset_r -= 2 * num_frags_y; // num_frags_z = 1
|
||||
}
|
||||
__syncthreads();
|
||||
|
||||
// mask, quant, store
|
||||
using LoadKVT = AlignedVector<uint8_t, 4>;
|
||||
LoadKVT cache_vec1;
|
||||
LoadKVT cache_vec2;
|
||||
|
||||
uint32_t chunk_start_k = tile_start + wid * num_frags_z * 16 + tid / 4;
|
||||
uint32_t kv_frag[4];
|
||||
const uint32_t write_n_stride = kv_num_heads * BLOCK_SIZE * HEAD_DIM;
|
||||
const uint32_t write_h_stride = BLOCK_SIZE * HEAD_DIM;
|
||||
const uint32_t write_b_stride = HEAD_DIM;
|
||||
const uint32_t write_d_stride = BLOCK_SIZE;
|
||||
uint32_t k_write_idx = block_id * write_n_stride +
|
||||
kv_head_idx * write_h_stride +
|
||||
(wid * num_frags_z * 16 + tid / 4) * write_b_stride +
|
||||
tid % 4 * 4; // 4 * int8 = 8 * int4 = 32bit
|
||||
#pragma unroll
|
||||
for (uint32_t fz = 0; fz < num_frags_z; ++fz) {
|
||||
uint32_t k_write_idx_now_z = k_write_idx + fz * 16 * write_b_stride;
|
||||
#pragma unroll
|
||||
for (uint32_t fy = 0; fy < num_frags_y; ++fy) {
|
||||
uint32_t k_write_idx_now = k_write_idx_now_z +
|
||||
fy % 2 * 8 * write_b_stride +
|
||||
fy / 2 * 32; // + fy % 2 * 16;
|
||||
// load
|
||||
k_smem.ldmatrix_m8n8x4(k_smem_offset_r, kv_frag);
|
||||
// quant
|
||||
T *k_frag_T = reinterpret_cast<T *>(kv_frag);
|
||||
if (bf_pad_len != 0) {
|
||||
Load<uint8_t, 4>(cache_k + k_write_idx_now, &cache_vec1);
|
||||
Load<uint8_t, 4>(cache_k + k_write_idx_now + 16, &cache_vec2);
|
||||
}
|
||||
#pragma unroll
|
||||
for (uint32_t v_id = 0; v_id < 8; ++v_id) {
|
||||
uint8_t uint_quant_value;
|
||||
if (chunk_start_k + (v_id / 4) * 8 >= start_len &&
|
||||
chunk_start_k + (v_id / 4) * 8 < end_len) {
|
||||
uint_quant_value = QuantToC8<T, is_need_kv_quant, IsFP8>(k_local_max_value[fz * 2 + v_id / 4], k_frag_T[v_id], 127.0f, -127.0f);
|
||||
} else {
|
||||
uint_quant_value = 0;
|
||||
}
|
||||
if (bf_pad_len != 0) {
|
||||
if (v_id < 4) {
|
||||
cache_vec1[v_id] |= uint_quant_value;
|
||||
} else {
|
||||
cache_vec2[v_id % 4] |= uint_quant_value;
|
||||
}
|
||||
} else {
|
||||
if (v_id < 4) {
|
||||
cache_vec1[v_id] = uint_quant_value;
|
||||
} else {
|
||||
cache_vec2[v_id - 4] = uint_quant_value;
|
||||
}
|
||||
}
|
||||
}
|
||||
// store
|
||||
Store<uint8_t, 4>(cache_vec1, cache_k + k_write_idx_now);
|
||||
Store<uint8_t, 4>(cache_vec2, cache_k + k_write_idx_now + 16);
|
||||
k_smem_offset_r = k_smem.advance_offset_by_column<2>(k_smem_offset_r, fy);
|
||||
}
|
||||
k_smem_offset_r =
|
||||
k_smem.advance_offset_by_row<16, num_vecs_per_head>(k_smem_offset_r) -
|
||||
2 * num_frags_y;
|
||||
chunk_start_k += 16;
|
||||
}
|
||||
|
||||
uint32_t chunk_start_v = tile_start + tid % 4 * 2;
|
||||
uint32_t v_write_idx = block_id * write_n_stride +
|
||||
kv_head_idx * write_h_stride +
|
||||
(wid * num_frags_v * 16 + tid / 4) * write_d_stride +
|
||||
tid % 4 * 4; // 4 * int8 = 8 * int4 = 32bit
|
||||
const uint32_t num_frags_z_v = num_frags_z * NUM_WARPS;
|
||||
T v_scales[num_frags_z_v * 4];
|
||||
for (int v_i = 0; v_i < num_frags_z_v; v_i++) {
|
||||
const int offset = v_i * 16;
|
||||
const int t_offset = tid % 4 * 2;
|
||||
v_scales[v_i * 4] = v_scale_smem[offset + t_offset];
|
||||
v_scales[v_i * 4 + 1] = v_scale_smem[offset + t_offset + 1];
|
||||
v_scales[v_i * 4 + 2] = v_scale_smem[offset + t_offset + 8];
|
||||
v_scales[v_i * 4 + 3] = v_scale_smem[offset + t_offset + 9];
|
||||
}
|
||||
|
||||
#pragma unroll
|
||||
for (uint32_t fy = 0; fy < num_frags_v; ++fy) {
|
||||
uint32_t v_write_idx_now_v = v_write_idx + fy * 16 * write_d_stride;
|
||||
#pragma unroll
|
||||
for (uint32_t fz = 0; fz < num_frags_z_v; ++fz) {
|
||||
uint32_t v_write_idx_now = v_write_idx_now_v +
|
||||
fz % 2 * 8 * write_d_stride +
|
||||
fz / 2 * 32; // + fz % 2 * 16;
|
||||
// load
|
||||
v_smem.ldmatrix_m8n8x4_trans(v_smem_offset_r, kv_frag);
|
||||
// quant
|
||||
T *v_frag_T = reinterpret_cast<T *>(kv_frag);
|
||||
if (bf_pad_len != 0) {
|
||||
Load<uint8_t, 4>(cache_v + v_write_idx_now, &cache_vec1);
|
||||
Load<uint8_t, 4>(cache_v + v_write_idx_now + 16, &cache_vec2);
|
||||
}
|
||||
#pragma unroll
|
||||
for (uint32_t v_id = 0; v_id < 8; ++v_id) {
|
||||
uint8_t uint_quant_value;
|
||||
if (chunk_start_v + v_id % 2 + (v_id % 4) / 2 * 8 >= start_len &&
|
||||
chunk_start_v + v_id % 2 + (v_id % 4) / 2 * 8 < end_len) {
|
||||
uint_quant_value = QuantToC8<T, is_need_kv_quant, IsFP8>(v_scales[fz * 4 + v_id % 4], v_frag_T[v_id], 127.0f, -127.0f);
|
||||
// store now
|
||||
} else {
|
||||
uint_quant_value = 0;
|
||||
}
|
||||
if (bf_pad_len != 0) {
|
||||
if (v_id < 4) {
|
||||
cache_vec1[v_id] |= uint_quant_value;
|
||||
} else {
|
||||
cache_vec2[v_id % 4] |= uint_quant_value;
|
||||
}
|
||||
} else {
|
||||
if (v_id < 4) {
|
||||
cache_vec1[v_id] = uint_quant_value;
|
||||
} else {
|
||||
cache_vec2[v_id % 4] = uint_quant_value;
|
||||
}
|
||||
}
|
||||
}
|
||||
// store
|
||||
Store<uint8_t, 4>(cache_vec1, cache_v + v_write_idx_now);
|
||||
Store<uint8_t, 4>(cache_vec2, cache_v + v_write_idx_now + 16);
|
||||
chunk_start_v += 16;
|
||||
v_smem_offset_r =
|
||||
k_smem.advance_offset_by_row<16, num_vecs_per_head>(v_smem_offset_r);
|
||||
}
|
||||
v_smem_offset_r = k_smem.advance_offset_by_column<2>(
|
||||
v_smem_offset_r, wid * num_frags_v + fy) -
|
||||
16 * num_frags_z_v * num_vecs_per_head;
|
||||
chunk_start_v -= 16 * num_frags_z_v;
|
||||
}
|
||||
}
|
||||
|
||||
// Write Cache KV in Append
|
||||
template <typename T,
|
||||
uint32_t num_frags_y,
|
||||
@@ -2160,6 +1823,7 @@ void gqa_rotary_qk_variable(
|
||||
const int seq_len,
|
||||
const int input_output_len,
|
||||
const int dim_head,
|
||||
const int rotary_dim,
|
||||
const cudaStream_t &stream,
|
||||
bool use_neox_style = false,
|
||||
bool rope_3d = false) {
|
||||
@@ -2240,7 +1904,38 @@ void gqa_rotary_qk_variable(
          dim_head,
          rope_3d);
    } else {
      GQANeoxVariableLengthRotaryKernel<T, PackSize>
      if (rotary_dim < dim_head) {
        PD_CHECK((rotary_dim / 2) % PackSize == 0);
        elem_nums =
            qkv_out_scales
                ? token_num * (num_heads + 2 * kv_num_heads) * rotary_dim
                : token_num * (num_heads + kv_num_heads) * rotary_dim;  // for all q k v
        if (use_neox_style) {
          elem_nums /= 2;
        }
        const int pack_num_new = elem_nums / PackSize;
        GetNumBlocks<128>(pack_num_new, &grid_size);
        GQANeoxVariableLengthPartialRotaryKernel<T, PackSize>
            <<<grid_size, blocksize, 0, stream>>>(
                reinterpret_cast<const T *>(qkv_input),
                cos_emb,
                rotary_emb + input_output_len * rotary_dim / 2,
                batch_id_per_token,
                cu_seqlens_q,
                seq_lens,
                seq_lens_decoder,
                qkv_out_scales,
                qkv_bias,
                qkv_out,
                elem_nums,
                num_heads,
                kv_num_heads,
                seq_len,
                dim_head,
                rotary_dim,
                rope_3d);
      } else {
        GQANeoxVariableLengthRotaryKernel<T, PackSize>
            <<<grid_size, blocksize, 0, stream>>>(
                reinterpret_cast<const T *>(qkv_input),
                cos_emb,
@@ -2258,6 +1953,7 @@ void gqa_rotary_qk_variable(
                seq_len,
                dim_head,
                rope_3d);
      }
    }
  }
}
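In the partial branch above, elem_nums counts only the rotated slice of each token, (num_heads + kv_num_heads) * rotary_dim (or with 2 * kv_num_heads when out-scales are present), halved for NeoX so each thread owns a (left, right) pair, and the grid is re-derived from that count. A standalone sketch of that element count, mirroring the dispatch logic above:

// Illustrative only: element count fed to the partial-rotary NeoX kernel.
#include <cstdint>

inline int64_t partial_rope_elem_nums(int64_t token_num,
                                      int num_heads,
                                      int kv_num_heads,
                                      int rotary_dim,
                                      bool has_out_scales,
                                      bool use_neox_style) {
  int64_t elems = has_out_scales
                      ? token_num * (num_heads + 2 * kv_num_heads) * rotary_dim
                      : token_num * (num_heads + kv_num_heads) * rotary_dim;
  if (use_neox_style) elems /= 2;  // each thread rotates a (left, right) pair
  return elems;
}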
|
||||
@@ -2411,11 +2107,10 @@ void CascadeAppendWriteCacheKVC8QKV(
|
||||
int num_blocks_x_cpu,
|
||||
int max_seq_len,
|
||||
bool is_scale_channel_wise,
|
||||
const std::string& cache_quant_type,
|
||||
const bool is_fp8,
|
||||
cudaStream_t &stream,
|
||||
paddle::Tensor *cache_k_out,
|
||||
paddle::Tensor *cache_v_out) {
|
||||
using NV_TYPE = typename cascade_attn_type_traits<T>::type;
|
||||
auto max_blocks_per_seq = meta_data.max_blocks_per_seq;
|
||||
auto num_tokens = meta_data.token_nums;
|
||||
auto num_heads = meta_data.q_num_heads;
|
||||
@@ -2433,77 +2128,49 @@ void CascadeAppendWriteCacheKVC8QKV(
|
||||
dim3 blocks(32, num_warps);
|
||||
|
||||
const uint32_t smem_size = (BLOCK_SIZE * HEAD_DIM) * sizeof(T) * 2;
|
||||
if (cache_quant_type != "block_wise_fp8") {
|
||||
auto kernel_fn = append_write_cache_kv_c8_qkv<T,
|
||||
num_frags_y,
|
||||
num_frags_z,
|
||||
HEAD_DIM,
|
||||
BLOCK_SIZE,
|
||||
num_warps,
|
||||
true, false>;
|
||||
if (cache_quant_type == "cache_fp8") {
|
||||
kernel_fn = append_write_cache_kv_c8_qkv<T,
|
||||
num_frags_y,
|
||||
num_frags_z,
|
||||
HEAD_DIM,
|
||||
BLOCK_SIZE,
|
||||
num_warps,
|
||||
true, true>;
|
||||
}
|
||||
if (is_scale_channel_wise) {
|
||||
kernel_fn = append_write_cache_kv_c8_qkv<T,
|
||||
num_frags_y,
|
||||
num_frags_z,
|
||||
HEAD_DIM,
|
||||
BLOCK_SIZE,
|
||||
num_warps,
|
||||
false>;
|
||||
}
|
||||
cudaFuncSetAttribute(
|
||||
kernel_fn, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size);
|
||||
kernel_fn<<<grids, blocks, 0, stream>>>(cache_k_out->data<uint8_t>(),
|
||||
cache_v_out->data<uint8_t>(),
|
||||
qkv.data<T>(),
|
||||
cache_k_scale.data<T>(),
|
||||
cache_v_scale.data<T>(),
|
||||
batch_ids.data<int>(),
|
||||
tile_ids_per_batch.data<int>(),
|
||||
seq_lens_this_time.data<int>(),
|
||||
seq_lens_decoder.data<int>(),
|
||||
batch_id_per_token.data<int>(),
|
||||
cu_seqlens_q.data<int>(),
|
||||
block_table.data<int>(),
|
||||
max_seq_len,
|
||||
max_blocks_per_seq,
|
||||
num_heads,
|
||||
kv_num_heads);
|
||||
} else {
|
||||
auto kernel_fn = append_write_cache_kv_c8_qkv_dynamic<NV_TYPE,
|
||||
num_frags_y,
|
||||
num_frags_z,
|
||||
HEAD_DIM,
|
||||
BLOCK_SIZE,
|
||||
num_warps,
|
||||
true, true>;
|
||||
cudaFuncSetAttribute(
|
||||
kernel_fn, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size);
|
||||
kernel_fn<<<grids, blocks, 0, stream>>>(cache_k_out->data<uint8_t>(),
|
||||
cache_v_out->data<uint8_t>(),
|
||||
reinterpret_cast<const NV_TYPE*>(qkv.data<T>()),
|
||||
const_cast<NV_TYPE*>(reinterpret_cast<const NV_TYPE*>(cache_k_scale.data<T>())),
|
||||
const_cast<NV_TYPE*>(reinterpret_cast<const NV_TYPE*>(cache_v_scale.data<T>())),
|
||||
batch_ids.data<int>(),
|
||||
tile_ids_per_batch.data<int>(),
|
||||
seq_lens_this_time.data<int>(),
|
||||
seq_lens_decoder.data<int>(),
|
||||
batch_id_per_token.data<int>(),
|
||||
cu_seqlens_q.data<int>(),
|
||||
block_table.data<int>(),
|
||||
max_seq_len,
|
||||
max_blocks_per_seq,
|
||||
num_heads,
|
||||
kv_num_heads);
|
||||
auto kernel_fn = append_write_cache_kv_c8_qkv<T,
|
||||
num_frags_y,
|
||||
num_frags_z,
|
||||
HEAD_DIM,
|
||||
BLOCK_SIZE,
|
||||
num_warps,
|
||||
true, false>;
|
||||
if (is_fp8) {
|
||||
kernel_fn = append_write_cache_kv_c8_qkv<T,
|
||||
num_frags_y,
|
||||
num_frags_z,
|
||||
HEAD_DIM,
|
||||
BLOCK_SIZE,
|
||||
num_warps,
|
||||
true, true>;
|
||||
}
|
||||
if (is_scale_channel_wise) {
|
||||
kernel_fn = append_write_cache_kv_c8_qkv<T,
|
||||
num_frags_y,
|
||||
num_frags_z,
|
||||
HEAD_DIM,
|
||||
BLOCK_SIZE,
|
||||
num_warps,
|
||||
false>;
|
||||
}
|
||||
cudaFuncSetAttribute(
|
||||
kernel_fn, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size);
|
||||
kernel_fn<<<grids, blocks, 0, stream>>>(cache_k_out->data<uint8_t>(),
|
||||
cache_v_out->data<uint8_t>(),
|
||||
qkv.data<T>(),
|
||||
cache_k_scale.data<T>(),
|
||||
cache_v_scale.data<T>(),
|
||||
batch_ids.data<int>(),
|
||||
tile_ids_per_batch.data<int>(),
|
||||
seq_lens_this_time.data<int>(),
|
||||
seq_lens_decoder.data<int>(),
|
||||
batch_id_per_token.data<int>(),
|
||||
cu_seqlens_q.data<int>(),
|
||||
block_table.data<int>(),
|
||||
max_seq_len,
|
||||
max_blocks_per_seq,
|
||||
num_heads,
|
||||
kv_num_heads);
|
||||
}
|
||||
|
||||
template <typename T, uint32_t HEAD_DIM, uint32_t BLOCK_SIZE>
|
||||
|
@@ -55,9 +55,19 @@ void EncoderWriteCacheWithRopeKernel(
  auto kv_num_heads = meta_data.kv_num_heads;
  auto head_dim = meta_data.head_dims;
  bool is_scale_channel_wise = false;
  int rotary_dim = head_dim;
  if (cache_k_scale && cache_k_scale.get().dims()[0] == head_dim * kv_num_heads) {
    is_scale_channel_wise = true;
  }
  if (rotary_embs) {
    rotary_dim = rotary_embs.get().dims()[rotary_embs.get().dims().size() - 1] * 2;
    if (rotary_dim < head_dim) {
      if (!use_neox_style || q_norm_weight || k_norm_weight ||
          num_heads == kv_num_heads || is_scale_channel_wise) {
        PADDLE_THROW(phi::errors::Fatal(
            "partial_rotary_factor < 1.0 only supports use_neox_rotary_style=True, "
            "q_norm_weight/k_norm_weight is None, GQA, and "
            "is_scale_channel_wise=false."));
      }
    }
  }
|
||||
|
||||
if (q_norm_weight && k_norm_weight) {
|
||||
if (num_heads != kv_num_heads && !is_scale_channel_wise && !use_neox_style) {
|
||||
@@ -125,6 +135,7 @@ void EncoderWriteCacheWithRopeKernel(
|
||||
max_seq_len,
|
||||
rope_3d ? rotary_embs.get().dims()[3] : rotary_embs.get().dims()[2],
|
||||
head_dim,
|
||||
rotary_dim,
|
||||
stream,
|
||||
use_neox_style,
|
||||
rope_3d);
|
||||
@@ -167,7 +178,7 @@ void EncoderWriteCacheWithRopeKernel(
|
||||
stream,
|
||||
key_cache_out,
|
||||
value_cache_out);
|
||||
} else if (cache_quant_type_str == "cache_int8" or cache_quant_type_str == "cache_fp8" or cache_quant_type_str == "block_wise_fp8") {
|
||||
} else if (cache_quant_type_str == "cache_int8" or cache_quant_type_str == "cache_fp8") {
|
||||
DISPATCH_HEAD_DIM(
|
||||
head_dim, HEAD_DIM, {DISPATCH_BLOCK_SIZE(block_size, BLOCK_SIZE, {
|
||||
CascadeAppendWriteCacheKVC8QKV<T, HEAD_DIM, BLOCK_SIZE>(
|
||||
@@ -187,7 +198,7 @@ void EncoderWriteCacheWithRopeKernel(
|
||||
num_blocks,
|
||||
max_seq_len,
|
||||
is_scale_channel_wise,
|
||||
cache_quant_type_str,
|
||||
cache_quant_type_str == "cache_fp8",
|
||||
stream,
|
||||
key_cache_out,
|
||||
value_cache_out);
|
||||
|
@@ -1000,7 +1000,7 @@ std::vector<paddle::Tensor> GQARopeWriteCacheKernel(
|
||||
stream,
|
||||
const_cast<paddle::Tensor*>(&key_cache),
|
||||
const_cast<paddle::Tensor*>(&value_cache));
|
||||
} else if (cache_quant_type == "cache_int8" || cache_quant_type == "cache_fp8" || cache_quant_type == "block_wise_fp8") {
|
||||
} else if (cache_quant_type == "cache_int8" || cache_quant_type == "cache_fp8") {
|
||||
CascadeAppendWriteCacheKVC8QKV<data_t, 128, 64>(
|
||||
meta_data,
|
||||
*const_cast<paddle::Tensor*>(&key_cache),
|
||||
@@ -1018,7 +1018,7 @@ std::vector<paddle::Tensor> GQARopeWriteCacheKernel(
|
||||
kv_num_blocks_data,
|
||||
max_seq_len,
|
||||
false, // is_scale_channel_wise
|
||||
cache_quant_type,
|
||||
cache_quant_type == "cache_fp8", // is_fp8
|
||||
stream,
|
||||
const_cast<paddle::Tensor*>(&key_cache),
|
||||
const_cast<paddle::Tensor*>(&value_cache));
|
||||
|
@@ -18,168 +18,6 @@
|
||||
#include "mma_tensor_op.cuh"
|
||||
#include "utils.cuh"
|
||||
|
||||
template <typename T, int VecSize = 1, typename InT = T>
|
||||
__global__ void append_speculate_cache_T_rope_qk_norm_kernel(
|
||||
const InT* __restrict__ qkv, // [token_num, num_heads + 2 * gqa_group_size,
|
||||
// head_size]
|
||||
T* __restrict__ key_cache, // [num_blocks, gqa_group_size, block_size,
|
||||
// head_size // 2]
|
||||
T* __restrict__ value_cache, // [num_blocks, gqa_group_size, block_size,
|
||||
// head_size // 2]
|
||||
T* __restrict__ q_out,
|
||||
const int* __restrict__ block_tables, // [bsz, max_blocks_per_seq]
|
||||
const int* __restrict__ batch_id_per_token, // [num_tokens]
|
||||
const int* __restrict__ cu_seqlens_q,
|
||||
const int* __restrict__ seq_lens_decoder, // [bsz]
|
||||
const float* __restrict__ cos_emb,
|
||||
const float* __restrict__ sin_emb,
|
||||
const float*
|
||||
qkv_out_scales, // [(num_heads + 2 * gqa_group_size) * head_size]
|
||||
const T* qkv_biases, // [num_head + 2 * gqa_group_size, dim_head]
|
||||
const int max_seq_len,
|
||||
const int max_blocks_per_seq,
|
||||
const int num_heads,
|
||||
const int output_inner_dim,
|
||||
const int head_size,
|
||||
const int block_size,
|
||||
const int elem_cnt,
|
||||
const int gqa_group_size,
|
||||
const float* q_norm_weight,
|
||||
const float* k_norm_weight,
|
||||
const float rms_norm_eps,
|
||||
const bool rope_3d) {
|
||||
using LoadT = AlignedVector<T, VecSize>;
|
||||
using LoadFloat = AlignedVector<float, VecSize>;
|
||||
using LoadInT = AlignedVector<InT, VecSize>;
|
||||
constexpr int HalfVecSize = VecSize / 2;
|
||||
using LoadEmbT = AlignedVector<float, HalfVecSize>;
|
||||
LoadInT src_vec;
|
||||
LoadFloat scale_vec;
|
||||
LoadT bias_vec;
|
||||
LoadEmbT cos_emb_vec;
|
||||
LoadEmbT sin_emb_vec;
|
||||
LoadFloat tmp_vec;
|
||||
LoadFloat q_norm_vec;
|
||||
LoadFloat k_norm_vec;
|
||||
|
||||
int64_t global_warp_idx = blockDim.y * blockIdx.x + threadIdx.y;
|
||||
int64_t all_warp_num = gridDim.x * blockDim.y;
|
||||
int64_t all_head_dim = elem_cnt / head_size;
|
||||
|
||||
const int64_t hidden_size = (num_heads + 2 * gqa_group_size) * head_size;
|
||||
const int half_head_size = head_size / 2;
|
||||
for (int global_hi = global_warp_idx; global_hi < all_head_dim; global_hi += all_warp_num) {
|
||||
int64_t linear_index = global_hi * head_size + threadIdx.x * VecSize;
|
||||
const int token_id = linear_index / hidden_size;
|
||||
const int ori_bi = batch_id_per_token[token_id];
|
||||
if (seq_lens_decoder[ori_bi] == 0) continue;
|
||||
const int bias = linear_index % hidden_size;
|
||||
const int hi = bias / head_size; // q + k + v
|
||||
const int h_bias = bias % head_size;
|
||||
const int start_token_idx = cu_seqlens_q[ori_bi];
|
||||
const int write_seq_id =
|
||||
seq_lens_decoder[ori_bi] + token_id - start_token_idx;
|
||||
if (write_seq_id == 0) continue;
|
||||
|
||||
const int* block_table_now = block_tables + ori_bi * max_blocks_per_seq;
|
||||
const int block_idx = block_table_now[write_seq_id / block_size];
|
||||
if (block_idx < 0) {
|
||||
printf(
|
||||
"Fatal Error!!!, block idx %d when write_seq_id is %d\n some key var "
|
||||
"%d %d %d %d\n",
|
||||
block_idx,
|
||||
write_seq_id,
|
||||
ori_bi,
|
||||
seq_lens_decoder[ori_bi],
|
||||
token_id,
|
||||
cu_seqlens_q[ori_bi]);
|
||||
}
|
||||
const int block_offset = write_seq_id % block_size;
|
||||
|
||||
const int write_q_idx =
|
||||
token_id * output_inner_dim * head_size + hi * head_size + h_bias;
|
||||
|
||||
const int bias_idx = hi * head_size + h_bias;
|
||||
Load<InT, VecSize>(&qkv[linear_index], &src_vec);
|
||||
if (qkv_biases) {
|
||||
Load<T, VecSize>(&qkv_biases[bias_idx], &bias_vec);
|
||||
}
|
||||
if (qkv_out_scales) {
|
||||
Load<float, VecSize>(&qkv_out_scales[bias_idx], &scale_vec);
|
||||
}
|
||||
if (hi < num_heads + gqa_group_size) {
|
||||
// q k rope
|
||||
const int64_t emb_idx = write_seq_id * half_head_size + h_bias / 2;
|
||||
uint32_t new_emb_idx = rope_3d ? emb_idx + ori_bi * max_seq_len * head_size : emb_idx;
|
||||
Load<float, HalfVecSize>(&cos_emb[new_emb_idx], &cos_emb_vec);
|
||||
Load<float, HalfVecSize>(&sin_emb[new_emb_idx], &sin_emb_vec);
|
||||
}
|
||||
float thread_m2 = 0.0f;
|
||||
float warp_m2 = 0.0f;
|
||||
#pragma unroll
|
||||
for (int i = 0; i < HalfVecSize; i++) {
|
||||
// add_bias + rope
|
||||
float input_left = static_cast<float>(src_vec[2 * i]);
|
||||
float input_right = static_cast<float>(src_vec[2 * i + 1]);
|
||||
if (qkv_out_scales) {
|
||||
input_left *= scale_vec[2 * i];
|
||||
input_right *= scale_vec[2 * i + 1];
|
||||
}
|
||||
if (qkv_biases) {
|
||||
input_left = input_left + static_cast<float>(bias_vec[2 * i]);
|
||||
input_right = input_right + static_cast<float>(bias_vec[2 * i + 1]);
|
||||
}
|
||||
if (hi < num_heads + gqa_group_size) {
|
||||
const float cos_tmp = cos_emb_vec[i];
|
||||
const float sin_tmp = sin_emb_vec[i];
|
||||
float tmp1 = input_left * cos_tmp - input_right * sin_tmp;
|
||||
float tmp2 = input_right * cos_tmp + input_left * sin_tmp;
|
||||
thread_m2 += tmp1 * tmp1 + tmp2 * tmp2;
|
||||
tmp_vec[2 * i] = tmp1;
|
||||
tmp_vec[2 * i + 1] = tmp2;
|
||||
} else {
|
||||
bias_vec[2 * i] = static_cast<T>(input_left);
|
||||
bias_vec[2 * i + 1] = static_cast<T>(input_right);
|
||||
}
|
||||
}
|
||||
if (hi < (num_heads + gqa_group_size)) {
|
||||
WelfordWarpAllReduce<float, 32>(thread_m2, &warp_m2);
|
||||
float row_variance =
|
||||
max(warp_m2 / head_size, 0.0f);
|
||||
float row_inv_var = Rsqrt(row_variance + rms_norm_eps);
|
||||
if (hi < num_heads) {
|
||||
Load<float, VecSize>(&q_norm_weight[threadIdx.x * VecSize], &q_norm_vec);
|
||||
#pragma unroll
|
||||
for (int i = 0; i < VecSize; i++) {
|
||||
bias_vec[i] = static_cast<T>(tmp_vec[i] * row_inv_var * q_norm_vec[i]);
|
||||
}
|
||||
} else {
|
||||
Load<float, VecSize>(&k_norm_weight[threadIdx.x * VecSize], &k_norm_vec);
|
||||
#pragma unroll
|
||||
for (int i = 0; i < VecSize; i++) {
|
||||
bias_vec[i] = static_cast<T>(tmp_vec[i] * row_inv_var * k_norm_vec[i]);
|
||||
}
|
||||
}
|
||||
}
|
||||
if (hi < num_heads) {
|
||||
// write q
|
||||
Store<T, VecSize>(bias_vec, &q_out[write_q_idx]);
|
||||
} else {
|
||||
// write k/v
|
||||
const int kv_head_idx = (hi - num_heads) % gqa_group_size;
|
||||
const int tgt_idx = (block_idx * gqa_group_size * block_size * head_size +
|
||||
kv_head_idx * block_size * head_size +
|
||||
block_offset * head_size + h_bias);
|
||||
// write
|
||||
if (hi < num_heads + gqa_group_size) {
|
||||
Store<T, VecSize>(bias_vec, &key_cache[tgt_idx]);
|
||||
} else {
|
||||
Store<T, VecSize>(bias_vec, &value_cache[tgt_idx]);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template <int VecSize = 4, int HeadDim = 128>
|
||||
__global__ void append_clear_cache_int8_block(
|
||||
uint8_t* __restrict__ key_cache, // [num_blocks, gqa_group_size,
|
||||
@@ -355,8 +193,7 @@ __global__ void append_speculate_cache_rope_kernel(
|
||||
const int head_size,
|
||||
const int block_size,
|
||||
const int elem_cnt,
|
||||
const int gqa_group_size,
|
||||
const bool rope_3d) {
|
||||
const int gqa_group_size) {
|
||||
using LoadT = AlignedVector<T, VecSize>;
|
||||
using LoadFloat = AlignedVector<float, VecSize>;
|
||||
using LoadInT = AlignedVector<InT, VecSize>;
|
||||
@@ -416,9 +253,8 @@ __global__ void append_speculate_cache_rope_kernel(
|
||||
if (hi < num_heads + gqa_group_size) {
|
||||
// q k rope
|
||||
const int64_t emb_idx = write_seq_id * half_head_size + h_bias / 2;
|
||||
int64_t new_emb_idx = rope_3d ? emb_idx + ori_bi * max_seq_len * head_size : emb_idx;
|
||||
Load<float, HalfVecSize>(&cos_emb[new_emb_idx], &cos_emb_vec);
|
||||
Load<float, HalfVecSize>(&sin_emb[new_emb_idx], &sin_emb_vec);
|
||||
Load<float, HalfVecSize>(&cos_emb[emb_idx], &cos_emb_vec);
|
||||
Load<float, HalfVecSize>(&sin_emb[emb_idx], &sin_emb_vec);
|
||||
}
|
||||
#pragma unroll
|
||||
for (int i = 0; i < HalfVecSize; i++) {
|
||||
@@ -490,8 +326,7 @@ __global__ void append_speculate_cache_neox_rope_kernel(
|
||||
const int head_size,
|
||||
const int block_size,
|
||||
const int elem_cnt,
|
||||
const int gqa_group_size,
|
||||
const bool rope_3d) {
|
||||
const int gqa_group_size) {
|
||||
using LoadT = AlignedVector<T, VecSize>;
|
||||
using LoadFloat = AlignedVector<float, VecSize>;
|
||||
using LoadInT = AlignedVector<InT, VecSize>;
|
||||
@@ -555,9 +390,8 @@ __global__ void append_speculate_cache_neox_rope_kernel(
|
||||
if (hi < num_heads + gqa_group_size) {
|
||||
// q k rope
|
||||
const int64_t emb_idx = write_seq_id * head_size + h_bias;
|
||||
int64_t new_emb_idx = rope_3d ? emb_idx + ori_bi * max_seq_len * head_size * 2: emb_idx;
|
||||
Load<float, VecSize>(&cos_emb[new_emb_idx], &cos_emb_vec);
|
||||
Load<float, VecSize>(&sin_emb[new_emb_idx], &sin_emb_vec);
|
||||
Load<float, VecSize>(&cos_emb[emb_idx], &cos_emb_vec);
|
||||
Load<float, VecSize>(&sin_emb[emb_idx], &sin_emb_vec);
|
||||
}
|
||||
#pragma unroll
|
||||
for (int i = 0; i < VecSize; i++) {
|
||||
@@ -642,8 +476,7 @@ __global__ void append_speculate_cache_int8_rope_kernel(
|
||||
const int block_size,
|
||||
const float max_bound,
|
||||
const float min_bound,
|
||||
const int gqa_group_size,
|
||||
const bool rope_3d) {
|
||||
const int gqa_group_size) {
|
||||
static_assert(HeadDim == 128, "just support HeadDim be 128 now!");
|
||||
static_assert(VecSize == 4, "just support VecSize be 4 now, 32 * 4!");
|
||||
constexpr int NUM_WARPS = 4;
|
||||
@@ -689,9 +522,8 @@ __global__ void append_speculate_cache_int8_rope_kernel(
|
||||
|
||||
// q rope
|
||||
const uint32_t emb_idx = write_seq_id * half_head_size + head_bias / 2;
|
||||
uint32_t new_emb_idx = rope_3d ? emb_idx + bid * max_seq_len * HeadDim : emb_idx;
|
||||
Load<float, HalfVecSize>(&cos_emb[new_emb_idx], &cos_emb_vec);
|
||||
Load<float, HalfVecSize>(&sin_emb[new_emb_idx], &sin_emb_vec);
|
||||
Load<float, HalfVecSize>(&cos_emb[emb_idx], &cos_emb_vec);
|
||||
Load<float, HalfVecSize>(&sin_emb[emb_idx], &sin_emb_vec);
|
||||
if (qkv_out_scales) {
|
||||
Load<float, VecSize>(&qkv_out_scales[bias_idx], &out_scale_vec);
|
||||
}
|
||||
@@ -751,11 +583,10 @@ __global__ void append_speculate_cache_int8_rope_kernel(
|
||||
T scale;
|
||||
if (head_idx < num_heads + gqa_group_size) {
|
||||
const uint32_t emb_idx = write_seq_id * half_head_size + head_bias / 2;
|
||||
uint32_t new_emb_idx = rope_3d ? emb_idx + bid * max_seq_len * HeadDim : emb_idx;
|
||||
Load<float, 1>(&cos_emb[new_emb_idx], &cos_emb_vec1);
|
||||
Load<float, 1>(&cos_emb[new_emb_idx + 4], &cos_emb_vec2);
|
||||
Load<float, 1>(&sin_emb[new_emb_idx], &sin_emb_vec1);
|
||||
Load<float, 1>(&sin_emb[new_emb_idx + 4], &sin_emb_vec2);
|
||||
Load<float, 1>(&cos_emb[emb_idx], &cos_emb_vec1);
|
||||
Load<float, 1>(&cos_emb[emb_idx + 4], &cos_emb_vec2);
|
||||
Load<float, 1>(&sin_emb[emb_idx], &sin_emb_vec1);
|
||||
Load<float, 1>(&sin_emb[emb_idx + 4], &sin_emb_vec2);
|
||||
scale = __ldg(&cache_k_scales[kv_head_idx]);
|
||||
} else {
|
||||
scale = __ldg(&cache_v_scales[kv_head_idx]);
|
||||
@@ -877,8 +708,7 @@ __global__ void append_speculate_cache_int8_neox_rope_kernel(
|
||||
const int block_size,
|
||||
const float max_bound,
|
||||
const float min_bound,
|
||||
const int gqa_group_size,
|
||||
const bool rope_3d) {
|
||||
const int gqa_group_size) {
|
||||
static_assert(HeadDim == 128, "just support HeadDim be 128 now!");
|
||||
static_assert(VecSize == 4, "just support VecSize be 4 now, 32 * 4!");
|
||||
constexpr int NUM_WARPS = 4;
|
||||
@@ -927,9 +757,8 @@ __global__ void append_speculate_cache_int8_neox_rope_kernel(
|
||||
|
||||
// q rope
|
||||
const uint32_t emb_idx = write_seq_id * HeadDim + head_bias;
|
||||
uint32_t new_emb_idx = rope_3d ? emb_idx + bid * max_seq_len * HeadDim * 2 : emb_idx;
|
||||
Load<float, VecSize>(&cos_emb[new_emb_idx], &cos_emb_vec);
|
||||
Load<float, VecSize>(&sin_emb[new_emb_idx], &sin_emb_vec);
|
||||
Load<float, VecSize>(&cos_emb[emb_idx], &cos_emb_vec);
|
||||
Load<float, VecSize>(&sin_emb[emb_idx], &sin_emb_vec);
|
||||
if (qkv_out_scales) {
|
||||
Load<float, VecSize>(&qkv_out_scales[bias_idx_left],
|
||||
&left_out_scale_vec);
|
||||
@@ -1024,11 +853,10 @@ __global__ void append_speculate_cache_int8_neox_rope_kernel(
|
||||
|
||||
T scale;
|
||||
const uint32_t emb_idx = write_seq_id * HeadDim + head_bias;
|
||||
uint32_t new_emb_idx = rope_3d ? emb_idx + bid * max_seq_len * HeadDim * 2 : emb_idx;
|
||||
Load<float, HALF_K_VEC_SIZE>(&cos_emb[new_emb_idx], &cos_emb_vec1);
|
||||
Load<float, HALF_K_VEC_SIZE>(&cos_emb[new_emb_idx + 8], &cos_emb_vec2);
|
||||
Load<float, HALF_K_VEC_SIZE>(&sin_emb[new_emb_idx], &sin_emb_vec1);
|
||||
Load<float, HALF_K_VEC_SIZE>(&sin_emb[new_emb_idx + 8], &sin_emb_vec2);
|
||||
Load<float, HALF_K_VEC_SIZE>(&cos_emb[emb_idx], &cos_emb_vec1);
|
||||
Load<float, HALF_K_VEC_SIZE>(&cos_emb[emb_idx + 8], &cos_emb_vec2);
|
||||
Load<float, HALF_K_VEC_SIZE>(&sin_emb[emb_idx], &sin_emb_vec1);
|
||||
Load<float, HALF_K_VEC_SIZE>(&sin_emb[emb_idx + 8], &sin_emb_vec2);
|
||||
scale = __ldg(&cache_k_scales[kv_head_idx]);
|
||||
#pragma unroll
|
||||
for (int i = 0; i < HALF_K_VEC_SIZE; i++) {
|
||||
@@ -1260,8 +1088,7 @@ __global__ void append_speculate_cache_int4_rope_kernel(
|
||||
const int block_size,
|
||||
const float max_bound,
|
||||
const float min_bound,
|
||||
const int gqa_group_size,
|
||||
const bool rope_3d) {
|
||||
const int gqa_group_size) {
|
||||
static_assert(HeadDim == 128, "just support HeadDim be 128 now!");
|
||||
static_assert(VecSize == 4, "just support VecSize be 4 now, 32 * 4!");
|
||||
constexpr int NUM_WARPS = 4;
|
||||
@@ -1318,9 +1145,8 @@ __global__ void append_speculate_cache_int4_rope_kernel(
|
||||
// Load<float, VecSize>(&qkv_out_scales[bias_idx], &out_scale_vec);
|
||||
// q rope
|
||||
const uint32_t emb_idx = write_seq_id * half_head_size + head_bias / 2;
|
||||
uint32_t new_emb_idx = rope_3d ? emb_idx + bid * max_seq_len * HeadDim : emb_idx;
|
||||
Load<float, HalfVecSize>(&cos_emb[new_emb_idx], &cos_emb_vec);
|
||||
Load<float, HalfVecSize>(&sin_emb[new_emb_idx], &sin_emb_vec);
|
||||
Load<float, HalfVecSize>(&cos_emb[emb_idx], &cos_emb_vec);
|
||||
Load<float, HalfVecSize>(&sin_emb[emb_idx], &sin_emb_vec);
|
||||
#pragma unroll
|
||||
for (int i = 0; i < HalfVecSize; i++) {
|
||||
// dequant + add_bias + rope
|
||||
@@ -1409,11 +1235,10 @@ __global__ void append_speculate_cache_int4_rope_kernel(
|
||||
// &out_scale_vec2);
|
||||
if (head_idx < num_heads + gqa_group_size) {
|
||||
const uint32_t emb_idx = write_seq_id * half_head_size + head_bias / 2;
|
||||
uint32_t new_emb_idx = rope_3d ? emb_idx + bid * max_seq_len * HeadDim : emb_idx;
|
||||
Load<float, 1>(&cos_emb[new_emb_idx], &cos_emb_vec1);
|
||||
Load<float, 1>(&cos_emb[new_emb_idx + 4], &cos_emb_vec2);
|
||||
Load<float, 1>(&sin_emb[new_emb_idx], &sin_emb_vec1);
|
||||
Load<float, 1>(&sin_emb[new_emb_idx + 4], &sin_emb_vec2);
|
||||
Load<float, 1>(&cos_emb[emb_idx], &cos_emb_vec1);
|
||||
Load<float, 1>(&cos_emb[emb_idx + 4], &cos_emb_vec2);
|
||||
Load<float, 1>(&sin_emb[emb_idx], &sin_emb_vec1);
|
||||
Load<float, 1>(&sin_emb[emb_idx + 4], &sin_emb_vec2);
|
||||
Load<T, HALF_K_VEC_SIZE>(&cache_k_scales[cache_idx], &scale_vec1);
|
||||
Load<T, HALF_K_VEC_SIZE>(&cache_k_scales[cache_idx + 8], &scale_vec2);
|
||||
Load<T, HALF_K_VEC_SIZE>(&cache_k_zero_points[cache_idx], &zp_vec1);
|
||||
@@ -1606,8 +1431,7 @@ __global__ void append_speculate_cache_int4_neox_rope_kernel(
|
||||
const int block_size,
|
||||
const float max_bound,
|
||||
const float min_bound,
|
||||
const int gqa_group_size,
|
||||
const bool rope_3d) {
|
||||
const int gqa_group_size) {
|
||||
static_assert(HeadDim == 128, "just support HeadDim be 128 now!");
|
||||
static_assert(VecSize == 4, "just support VecSize be 4 now, 32 * 4!");
|
||||
constexpr int NUM_WARPS = 4;
|
||||
@@ -1757,11 +1581,10 @@ __global__ void append_speculate_cache_int4_neox_rope_kernel(
|
||||
&right_out_scale_vec2);
|
||||
|
||||
const uint32_t emb_idx = write_seq_id * HeadDim + head_bias;
|
||||
uint32_t new_emb_idx = rope_3d ? emb_idx + bid * max_seq_len * HeadDim : emb_idx;
|
||||
Load<float, HALF_K_VEC_SIZE>(&cos_emb[new_emb_idx], &cos_emb_vec1);
|
||||
Load<float, HALF_K_VEC_SIZE>(&cos_emb[new_emb_idx + 8], &cos_emb_vec2);
|
||||
Load<float, HALF_K_VEC_SIZE>(&sin_emb[new_emb_idx], &sin_emb_vec1);
|
||||
Load<float, HALF_K_VEC_SIZE>(&sin_emb[new_emb_idx + 8], &sin_emb_vec2);
|
||||
Load<float, HALF_K_VEC_SIZE>(&cos_emb[emb_idx], &cos_emb_vec1);
|
||||
Load<float, HALF_K_VEC_SIZE>(&cos_emb[emb_idx + 8], &cos_emb_vec2);
|
||||
Load<float, HALF_K_VEC_SIZE>(&sin_emb[emb_idx], &sin_emb_vec1);
|
||||
Load<float, HALF_K_VEC_SIZE>(&sin_emb[emb_idx + 8], &sin_emb_vec2);
|
||||
Load<T, HALF_K_VEC_SIZE>(&cache_k_scales[left_cache_idx],
|
||||
&left_scale_vec1);
|
||||
Load<T, HALF_K_VEC_SIZE>(&cache_k_scales[left_cache_idx + 8],
|
||||
|
@@ -15,78 +15,6 @@
|
||||
#include "speculate_write_cache_with_rope_kernel.h"
|
||||
#include "utils.cuh"
|
||||
|
||||
template <typename T, typename QKV_TYPE>
|
||||
void append_speculate_cache_rope_qk_norm(const QKV_TYPE* qkv,
|
||||
T* key_cache,
|
||||
T* value_cache,
|
||||
T* qkv_out,
|
||||
const int* block_tables,
|
||||
const int* batch_id_per_token,
|
||||
const int* cu_seqlens_q,
|
||||
const int* seq_lens,
|
||||
const int* seq_lens_encoder,
|
||||
const float* cos_emb,
|
||||
const float* sin_emb,
|
||||
const float* qkv_out_scales,
|
||||
const T* qkv_biases,
|
||||
const int max_seq_len,
|
||||
const int max_blocks_per_seq,
|
||||
const int num_heads,
|
||||
const int kv_num_heads,
|
||||
const int dim_head,
|
||||
const int block_size,
|
||||
const int bsz,
|
||||
const int token_num,
|
||||
const cudaStream_t& stream,
|
||||
const bool use_neox_style,
|
||||
const float* q_norm_weight,
|
||||
const float* k_norm_weight,
|
||||
const float rms_norm_eps,
|
||||
const bool rope_3d) {
|
||||
int output_inner_dim = num_heads + 2 * kv_num_heads;
|
||||
const uint32_t elem_nums =
|
||||
use_neox_style ? token_num * (num_heads + 2 * kv_num_heads) * dim_head / 2
|
||||
: token_num * (num_heads + 2 * kv_num_heads) * dim_head;
|
||||
constexpr int HEAD_DIM = 128;
|
||||
|
||||
constexpr int PackSize = HEAD_DIM / kWarpSize;
|
||||
const int pack_num = elem_nums / PackSize;
|
||||
const int blocksize = 128;
|
||||
int grid_size = 1;
|
||||
GetNumBlocks<128>(pack_num, &grid_size);
|
||||
if (use_neox_style) {
|
||||
PD_THROW(
|
||||
"append_speculate_cache_rope_qk_norm not support neox rope yet");
|
||||
} else {
|
||||
dim3 block_dim(kWarpSize, blocksize / kWarpSize, 1);
|
||||
append_speculate_cache_T_rope_qk_norm_kernel<T, PackSize>
|
||||
<<<grid_size, block_dim, 0, stream>>>(qkv,
|
||||
key_cache,
|
||||
value_cache,
|
||||
qkv_out,
|
||||
block_tables,
|
||||
batch_id_per_token,
|
||||
cu_seqlens_q,
|
||||
seq_lens,
|
||||
cos_emb,
|
||||
sin_emb,
|
||||
qkv_out_scales,
|
||||
qkv_biases,
|
||||
max_seq_len,
|
||||
max_blocks_per_seq,
|
||||
num_heads,
|
||||
output_inner_dim,
|
||||
dim_head,
|
||||
block_size,
|
||||
elem_nums,
|
||||
kv_num_heads,
|
||||
q_norm_weight,
|
||||
k_norm_weight,
|
||||
rms_norm_eps,
|
||||
rope_3d);
|
||||
}
|
||||
}
|
||||
|
||||
// rope + write
|
||||
template <typename T, typename QKV_TYPE>
|
||||
void append_speculate_cache_rope(const QKV_TYPE* qkv,
|
||||
@@ -111,8 +39,7 @@ void append_speculate_cache_rope(const QKV_TYPE* qkv,
|
||||
const int bsz,
|
||||
const int token_num,
|
||||
const cudaStream_t& stream,
|
||||
const bool use_neox_style,
|
||||
const bool rope_3d) {
|
||||
const bool use_neox_style) {
|
||||
int output_inner_dim = num_heads + 2 * kv_num_heads;
|
||||
|
||||
const uint32_t elem_nums =
|
||||
@@ -146,8 +73,7 @@ void append_speculate_cache_rope(const QKV_TYPE* qkv,
|
||||
dim_head,
|
||||
block_size,
|
||||
elem_nums,
|
||||
kv_num_heads,
|
||||
rope_3d);
|
||||
kv_num_heads);
|
||||
} else {
|
||||
append_speculate_cache_rope_kernel<T, PackSize>
|
||||
<<<grid_size, threads_per_block, 0, stream>>>(
|
||||
@@ -170,8 +96,7 @@ void append_speculate_cache_rope(const QKV_TYPE* qkv,
|
||||
dim_head,
|
||||
block_size,
|
||||
elem_nums,
|
||||
kv_num_heads,
|
||||
rope_3d);
|
||||
kv_num_heads);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -200,8 +125,7 @@ void append_speculate_cache_int8_rope(const QKV_TYPE* qkv,
|
||||
const int bsz,
|
||||
const int token_num,
|
||||
const cudaStream_t& stream,
|
||||
const bool use_neox_style,
|
||||
const bool rope_3d) {
|
||||
const bool use_neox_style) {
|
||||
constexpr int num_warps = 4;
|
||||
const int all_warps =
|
||||
((num_heads + 2 * kv_num_heads) + num_warps - 1) / num_warps * num_warps;
|
||||
@@ -243,8 +167,7 @@ void append_speculate_cache_int8_rope(const QKV_TYPE* qkv,
|
||||
block_size,
|
||||
127.0f,
|
||||
-127.0f,
|
||||
kv_num_heads,
|
||||
rope_3d);
|
||||
kv_num_heads);
|
||||
} else {
|
||||
append_speculate_cache_int8_rope_kernel<T, 4, 0, 128, QKV_TYPE, IsFP8>
|
||||
<<<grids, num_warps * 32, 0, stream>>>(qkv,
|
||||
@@ -268,8 +191,7 @@ void append_speculate_cache_int8_rope(const QKV_TYPE* qkv,
|
||||
block_size,
|
||||
127.0f,
|
||||
-127.0f,
|
||||
kv_num_heads,
|
||||
rope_3d);
|
||||
kv_num_heads);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -300,8 +222,7 @@ void append_speculate_cache_int4_rope(const QKV_TYPE* qkv,
|
||||
const int bsz,
|
||||
const int token_num,
|
||||
const cudaStream_t& stream,
|
||||
const bool use_neox_style,
|
||||
const bool rope_3d) {
|
||||
const bool use_neox_style) {
|
||||
constexpr int num_warps = 4;
|
||||
const int all_warps =
|
||||
((num_heads + 2 * kv_num_heads) + num_warps - 1) / num_warps * num_warps;
|
||||
@@ -345,8 +266,7 @@ void append_speculate_cache_int4_rope(const QKV_TYPE* qkv,
|
||||
block_size,
|
||||
7.0f,
|
||||
-8.0f,
|
||||
kv_num_heads,
|
||||
rope_3d);
|
||||
kv_num_heads);
|
||||
} else {
|
||||
append_speculate_cache_int4_rope_kernel<T, 4>
|
||||
<<<grids, num_warps * 32, 0, stream>>>(qkv,
|
||||
@@ -372,8 +292,7 @@ void append_speculate_cache_int4_rope(const QKV_TYPE* qkv,
|
||||
block_size,
|
||||
7.0f,
|
||||
-8.0f,
|
||||
kv_num_heads,
|
||||
rope_3d);
|
||||
kv_num_heads);
|
||||
}
|
||||
}
|
||||
template <typename T, typename QKV_TYPE>
|
||||
@@ -394,15 +313,11 @@ void SpeculateWriteCacheWithRoPEKernel(
|
||||
const paddle::optional<paddle::Tensor>& cache_v_zp,
|
||||
const std::string& cache_quant_type_str,
|
||||
const bool use_neox_rotary_style,
|
||||
const bool rope_3d,
|
||||
const int max_seq_len,
|
||||
cudaStream_t& stream,
|
||||
paddle::Tensor* qkv_out,
|
||||
paddle::Tensor* key_cache_out,
|
||||
paddle::Tensor* value_cache_out,
|
||||
const paddle::optional<paddle::Tensor>& q_norm_weight,
|
||||
const paddle::optional<paddle::Tensor>& k_norm_weight,
|
||||
const float rms_norm_eps) {
|
||||
paddle::Tensor* value_cache_out) {
|
||||
typedef cascade_attn_type_traits<T> traits_;
|
||||
typedef cascade_attn_type_traits<QKV_TYPE> qkt_nv_type_;
|
||||
typedef typename traits_::type DataType_;
|
||||
@@ -427,185 +342,142 @@ void SpeculateWriteCacheWithRoPEKernel(
|
||||
? rotary_embs.get().data<float>() + max_seq_len * dim_head
|
||||
: rotary_embs.get().data<float>() + max_seq_len * dim_head / 2;
|
||||
}
|
||||
if (q_norm_weight && k_norm_weight) {
|
||||
if (cache_quant_type_str == "none") {
|
||||
append_speculate_cache_rope_qk_norm(
|
||||
reinterpret_cast<const QKV_TYPE*>(qkv_ptr),
|
||||
reinterpret_cast<DataType_*>(key_cache_out->data<T>()),
|
||||
reinterpret_cast<DataType_*>(value_cache_out->data<T>()),
|
||||
reinterpret_cast<DataType_*>(qkv_out->data<T>()),
|
||||
block_tables.data<int>(),
|
||||
batch_id_per_token.data<int>(),
|
||||
cu_seqlens_q.data<int>(),
|
||||
seq_lens.data<int>(),
|
||||
seq_lens_encoder.data<int>(),
|
||||
cos_emb,
|
||||
sin_emb,
|
||||
qkv_out_scales ? qkv_out_scales.get().data<float>() : nullptr,
|
||||
qkv_biases ? reinterpret_cast<DataType_*>(
|
||||
const_cast<T*>(qkv_biases.get().data<T>()))
|
||||
: nullptr,
|
||||
max_seq_len,
|
||||
max_blocks_per_seq,
|
||||
num_heads,
|
||||
kv_num_heads,
|
||||
dim_head,
|
||||
block_size,
|
||||
bsz,
|
||||
token_nums,
|
||||
stream,
|
||||
use_neox_rotary_style,
|
||||
reinterpret_cast<const float*>(q_norm_weight.get().data<float>()),
|
||||
reinterpret_cast<const float*>(k_norm_weight.get().data<float>()),
|
||||
rms_norm_eps,
|
||||
rope_3d);
|
||||
} else {
|
||||
PD_THROW(
|
||||
"append_decode_cache_rope_qk_norm not support cachekv quant yet");
|
||||
}
|
||||
|
||||
if (cache_quant_type_str == "none") {
|
||||
append_speculate_cache_rope(
|
||||
reinterpret_cast<const QKV_TYPE*>(qkv_ptr),
|
||||
reinterpret_cast<DataType_*>(key_cache_out->data<T>()),
|
||||
reinterpret_cast<DataType_*>(value_cache_out->data<T>()),
|
||||
reinterpret_cast<DataType_*>(qkv_out->data<T>()),
|
||||
block_tables.data<int>(),
|
||||
batch_id_per_token.data<int>(),
|
||||
cu_seqlens_q.data<int>(),
|
||||
seq_lens.data<int>(),
|
||||
seq_lens_encoder.data<int>(),
|
||||
cos_emb,
|
||||
sin_emb,
|
||||
qkv_out_scales ? qkv_out_scales.get().data<float>() : nullptr,
|
||||
qkv_biases ? reinterpret_cast<DataType_*>(
|
||||
const_cast<T*>(qkv_biases.get().data<T>()))
|
||||
: nullptr,
|
||||
max_seq_len,
|
||||
max_blocks_per_seq,
|
||||
num_heads,
|
||||
kv_num_heads,
|
||||
dim_head,
|
||||
block_size,
|
||||
bsz,
|
||||
token_nums,
|
||||
stream,
|
||||
use_neox_rotary_style);
|
||||
} else if (cache_quant_type_str == "cache_int8") {
|
||||
append_speculate_cache_int8_rope(
|
||||
reinterpret_cast<const QKV_TYPE*>(qkv_ptr),
|
||||
key_cache_out->data<uint8_t>(),
|
||||
value_cache_out->data<uint8_t>(),
|
||||
reinterpret_cast<DataType_*>(qkv_out->data<T>()),
|
||||
block_tables.data<int>(),
|
||||
batch_id_per_token.data<int>(),
|
||||
cu_seqlens_q.data<int>(),
|
||||
seq_lens.data<int>(),
|
||||
seq_lens_encoder.data<int>(),
|
||||
cos_emb,
|
||||
sin_emb,
|
||||
qkv_out_scales ? qkv_out_scales.get().data<float>() : nullptr,
|
||||
qkv_biases ? reinterpret_cast<DataType_*>(
|
||||
const_cast<T*>(qkv_biases.get().data<T>()))
|
||||
: nullptr,
|
||||
cache_k_scale ? reinterpret_cast<DataType_*>(
|
||||
const_cast<T*>(cache_k_scale.get().data<T>()))
|
||||
: nullptr,
|
||||
cache_v_scale ? reinterpret_cast<DataType_*>(
|
||||
const_cast<T*>(cache_v_scale.get().data<T>()))
|
||||
: nullptr,
|
||||
max_seq_len,
|
||||
max_blocks_per_seq,
|
||||
num_heads,
|
||||
kv_num_heads,
|
||||
dim_head,
|
||||
block_size,
|
||||
bsz,
|
||||
token_nums,
|
||||
stream,
|
||||
use_neox_rotary_style);
|
||||
} else if (cache_quant_type_str == "cache_fp8") {
|
||||
append_speculate_cache_int8_rope<DataType_, QKV_TYPE, true>(
|
||||
reinterpret_cast<const QKV_TYPE*>(qkv_ptr),
|
||||
key_cache_out->data<uint8_t>(),
|
||||
value_cache_out->data<uint8_t>(),
|
||||
reinterpret_cast<DataType_*>(qkv_out->data<T>()),
|
||||
block_tables.data<int>(),
|
||||
batch_id_per_token.data<int>(),
|
||||
cu_seqlens_q.data<int>(),
|
||||
seq_lens.data<int>(),
|
||||
seq_lens_encoder.data<int>(),
|
||||
cos_emb,
|
||||
sin_emb,
|
||||
qkv_out_scales ? qkv_out_scales.get().data<float>() : nullptr,
|
||||
qkv_biases ? reinterpret_cast<DataType_*>(
|
||||
const_cast<T*>(qkv_biases.get().data<T>()))
|
||||
: nullptr,
|
||||
cache_k_scale ? reinterpret_cast<DataType_*>(
|
||||
const_cast<T*>(cache_k_scale.get().data<T>()))
|
||||
: nullptr,
|
||||
cache_v_scale ? reinterpret_cast<DataType_*>(
|
||||
const_cast<T*>(cache_v_scale.get().data<T>()))
|
||||
: nullptr,
|
||||
max_seq_len,
|
||||
max_blocks_per_seq,
|
||||
num_heads,
|
||||
kv_num_heads,
|
||||
dim_head,
|
||||
block_size,
|
||||
bsz,
|
||||
token_nums,
|
||||
stream,
|
||||
use_neox_rotary_style);
|
||||
} else if (cache_quant_type_str == "cache_int4_zp") {
|
||||
append_speculate_cache_int4_rope(
|
||||
reinterpret_cast<const QKV_TYPE*>(qkv_ptr),
|
||||
key_cache_out->data<uint8_t>(),
|
||||
value_cache_out->data<uint8_t>(),
|
||||
reinterpret_cast<DataType_*>(const_cast<T*>(qkv_out->data<T>())),
|
||||
block_tables.data<int>(),
|
||||
batch_id_per_token.data<int>(),
|
||||
cu_seqlens_q.data<int>(),
|
||||
seq_lens.data<int>(),
|
||||
seq_lens_encoder.data<int>(),
|
||||
cos_emb,
|
||||
sin_emb,
|
||||
qkv_out_scales ? qkv_out_scales.get().data<float>() : nullptr,
|
||||
qkv_biases ? reinterpret_cast<DataType_*>(
|
||||
const_cast<T*>(qkv_biases.get().data<T>()))
|
||||
: nullptr,
|
||||
cache_k_scale ? reinterpret_cast<DataType_*>(
|
||||
const_cast<T*>(cache_k_scale.get().data<T>()))
|
||||
: nullptr,
|
||||
cache_v_scale ? reinterpret_cast<DataType_*>(
|
||||
const_cast<T*>(cache_v_scale.get().data<T>()))
|
||||
: nullptr,
|
||||
cache_k_zp ? reinterpret_cast<DataType_*>(
|
||||
const_cast<T*>(cache_k_zp.get().data<T>()))
|
||||
: nullptr,
|
||||
cache_v_zp ? reinterpret_cast<DataType_*>(
|
||||
const_cast<T*>(cache_v_zp.get().data<T>()))
|
||||
: nullptr,
|
||||
max_seq_len,
|
||||
max_blocks_per_seq,
|
||||
num_heads,
|
||||
kv_num_heads,
|
||||
dim_head,
|
||||
block_size,
|
||||
bsz,
|
||||
token_nums,
|
||||
stream,
|
||||
use_neox_rotary_style);
|
||||
} else {
|
||||
if (cache_quant_type_str == "none") {
|
||||
append_speculate_cache_rope(
|
||||
reinterpret_cast<const QKV_TYPE*>(qkv_ptr),
|
||||
reinterpret_cast<DataType_*>(key_cache_out->data<T>()),
|
||||
reinterpret_cast<DataType_*>(value_cache_out->data<T>()),
|
||||
reinterpret_cast<DataType_*>(qkv_out->data<T>()),
|
||||
block_tables.data<int>(),
|
||||
batch_id_per_token.data<int>(),
|
||||
cu_seqlens_q.data<int>(),
|
||||
seq_lens.data<int>(),
|
||||
seq_lens_encoder.data<int>(),
|
||||
cos_emb,
|
||||
sin_emb,
|
||||
qkv_out_scales ? qkv_out_scales.get().data<float>() : nullptr,
|
||||
qkv_biases ? reinterpret_cast<DataType_*>(
|
||||
const_cast<T*>(qkv_biases.get().data<T>()))
|
||||
: nullptr,
|
||||
max_seq_len,
|
||||
max_blocks_per_seq,
|
||||
num_heads,
|
||||
kv_num_heads,
|
||||
dim_head,
|
||||
block_size,
|
||||
bsz,
|
||||
token_nums,
|
||||
stream,
|
||||
use_neox_rotary_style,
|
||||
rope_3d);
|
||||
} else if (cache_quant_type_str == "cache_int8") {
|
||||
append_speculate_cache_int8_rope(
|
||||
reinterpret_cast<const QKV_TYPE*>(qkv_ptr),
|
||||
key_cache_out->data<uint8_t>(),
|
||||
value_cache_out->data<uint8_t>(),
|
||||
reinterpret_cast<DataType_*>(qkv_out->data<T>()),
|
||||
block_tables.data<int>(),
|
||||
batch_id_per_token.data<int>(),
|
||||
cu_seqlens_q.data<int>(),
|
||||
seq_lens.data<int>(),
|
||||
seq_lens_encoder.data<int>(),
|
||||
cos_emb,
|
||||
sin_emb,
|
||||
qkv_out_scales ? qkv_out_scales.get().data<float>() : nullptr,
|
||||
qkv_biases ? reinterpret_cast<DataType_*>(
|
||||
const_cast<T*>(qkv_biases.get().data<T>()))
|
||||
: nullptr,
|
||||
cache_k_scale ? reinterpret_cast<DataType_*>(
|
||||
const_cast<T*>(cache_k_scale.get().data<T>()))
|
||||
: nullptr,
|
||||
cache_v_scale ? reinterpret_cast<DataType_*>(
|
||||
const_cast<T*>(cache_v_scale.get().data<T>()))
|
||||
: nullptr,
|
||||
max_seq_len,
|
||||
max_blocks_per_seq,
|
||||
num_heads,
|
||||
kv_num_heads,
|
||||
dim_head,
|
||||
block_size,
|
||||
bsz,
|
||||
token_nums,
|
||||
stream,
|
||||
use_neox_rotary_style,
|
||||
rope_3d);
|
||||
} else if (cache_quant_type_str == "cache_fp8") {
|
||||
append_speculate_cache_int8_rope<DataType_, QKV_TYPE, true>(
|
||||
reinterpret_cast<const QKV_TYPE*>(qkv_ptr),
|
||||
key_cache_out->data<uint8_t>(),
|
||||
value_cache_out->data<uint8_t>(),
|
||||
reinterpret_cast<DataType_*>(qkv_out->data<T>()),
|
||||
block_tables.data<int>(),
|
||||
batch_id_per_token.data<int>(),
|
||||
cu_seqlens_q.data<int>(),
|
||||
seq_lens.data<int>(),
|
||||
seq_lens_encoder.data<int>(),
|
||||
cos_emb,
|
||||
sin_emb,
|
||||
qkv_out_scales ? qkv_out_scales.get().data<float>() : nullptr,
|
||||
qkv_biases ? reinterpret_cast<DataType_*>(
|
||||
const_cast<T*>(qkv_biases.get().data<T>()))
|
||||
: nullptr,
|
||||
cache_k_scale ? reinterpret_cast<DataType_*>(
|
||||
const_cast<T*>(cache_k_scale.get().data<T>()))
|
||||
: nullptr,
|
||||
cache_v_scale ? reinterpret_cast<DataType_*>(
|
||||
const_cast<T*>(cache_v_scale.get().data<T>()))
|
||||
: nullptr,
|
||||
max_seq_len,
|
||||
max_blocks_per_seq,
|
||||
num_heads,
|
||||
kv_num_heads,
|
||||
dim_head,
|
||||
block_size,
|
||||
bsz,
|
||||
token_nums,
|
||||
stream,
|
||||
use_neox_rotary_style,
|
||||
rope_3d);
|
||||
} else if (cache_quant_type_str == "cache_int4_zp") {
|
||||
append_speculate_cache_int4_rope(
|
||||
reinterpret_cast<const QKV_TYPE*>(qkv_ptr),
|
||||
key_cache_out->data<uint8_t>(),
|
||||
value_cache_out->data<uint8_t>(),
|
||||
reinterpret_cast<DataType_*>(const_cast<T*>(qkv_out->data<T>())),
|
||||
block_tables.data<int>(),
|
||||
batch_id_per_token.data<int>(),
|
||||
cu_seqlens_q.data<int>(),
|
||||
seq_lens.data<int>(),
|
||||
seq_lens_encoder.data<int>(),
|
||||
cos_emb,
|
||||
sin_emb,
|
||||
qkv_out_scales ? qkv_out_scales.get().data<float>() : nullptr,
|
||||
qkv_biases ? reinterpret_cast<DataType_*>(
|
||||
const_cast<T*>(qkv_biases.get().data<T>()))
|
||||
: nullptr,
|
||||
cache_k_scale ? reinterpret_cast<DataType_*>(
|
||||
const_cast<T*>(cache_k_scale.get().data<T>()))
|
||||
: nullptr,
|
||||
cache_v_scale ? reinterpret_cast<DataType_*>(
|
||||
const_cast<T*>(cache_v_scale.get().data<T>()))
|
||||
: nullptr,
|
||||
cache_k_zp ? reinterpret_cast<DataType_*>(
|
||||
const_cast<T*>(cache_k_zp.get().data<T>()))
|
||||
: nullptr,
|
||||
cache_v_zp ? reinterpret_cast<DataType_*>(
|
||||
const_cast<T*>(cache_v_zp.get().data<T>()))
|
||||
: nullptr,
|
||||
max_seq_len,
|
||||
max_blocks_per_seq,
|
||||
num_heads,
|
||||
kv_num_heads,
|
||||
dim_head,
|
||||
block_size,
|
||||
bsz,
|
||||
token_nums,
|
||||
stream,
|
||||
use_neox_rotary_style,
|
||||
rope_3d);
|
||||
} else {
|
||||
PD_THROW(
|
||||
"cache_quant_type_str should be one of [none, cache_int8, "
|
||||
"cache_int4_zp]");
|
||||
}
|
||||
PD_THROW(
|
||||
"cache_quant_type_str should be one of [none, cache_int8, "
|
||||
"cache_int4_zp]");
|
||||
}
|
||||
}
|
||||
|
||||
@@ -628,15 +500,11 @@ template void SpeculateWriteCacheWithRoPEKernel<paddle::bfloat16, int>(
|
||||
const paddle::optional<paddle::Tensor>& cache_v_zp,
|
||||
const std::string& cache_quant_type_str,
|
||||
const bool use_neox_rotary_style,
|
||||
const bool rope_3d,
|
||||
const int max_seq_len,
|
||||
cudaStream_t& stream,
|
||||
paddle::Tensor* qkv_out,
|
||||
paddle::Tensor* key_cache_out,
|
||||
paddle::Tensor* value_cache_out,
|
||||
const paddle::optional<paddle::Tensor>& q_norm_weight,
|
||||
const paddle::optional<paddle::Tensor>& k_norm_weight,
|
||||
const float rms_norm_eps);
|
||||
paddle::Tensor* value_cache_out);
|
||||
|
||||
template void
|
||||
SpeculateWriteCacheWithRoPEKernel<paddle::bfloat16, paddle::bfloat16>(
|
||||
@@ -658,15 +526,11 @@ SpeculateWriteCacheWithRoPEKernel<paddle::bfloat16, paddle::bfloat16>(
|
||||
const paddle::optional<paddle::Tensor>& cache_v_zp,
|
||||
const std::string& cache_quant_type_str,
|
||||
const bool use_neox_rotary_style,
|
||||
const bool rope_3d,
|
||||
const int max_seq_len,
|
||||
cudaStream_t& stream,
|
||||
paddle::Tensor* qkv_out,
|
||||
paddle::Tensor* key_cache_out,
|
||||
paddle::Tensor* value_cache_out,
|
||||
const paddle::optional<paddle::Tensor>& q_norm_weight,
|
||||
const paddle::optional<paddle::Tensor>& k_norm_weight,
|
||||
const float rms_norm_eps);
|
||||
paddle::Tensor* value_cache_out);
|
||||
|
||||
template void SpeculateWriteCacheWithRoPEKernel<paddle::float16, int>(
|
||||
const AppendAttnMetaData& meta_data,
|
||||
@@ -687,15 +551,11 @@ template void SpeculateWriteCacheWithRoPEKernel<paddle::float16, int>(
|
||||
const paddle::optional<paddle::Tensor>& cache_v_zp,
|
||||
const std::string& cache_quant_type_str,
|
||||
const bool use_neox_rotary_style,
|
||||
const bool rope_3d,
|
||||
const int max_seq_len,
|
||||
cudaStream_t& stream,
|
||||
paddle::Tensor* qkv_out,
|
||||
paddle::Tensor* key_cache_out,
|
||||
paddle::Tensor* value_cache_out,
|
||||
const paddle::optional<paddle::Tensor>& q_norm_weight,
|
||||
const paddle::optional<paddle::Tensor>& k_norm_weight,
|
||||
const float rms_norm_eps);
|
||||
paddle::Tensor* value_cache_out);
|
||||
|
||||
|
||||
template void
|
||||
@@ -718,12 +578,8 @@ SpeculateWriteCacheWithRoPEKernel<paddle::float16, paddle::float16>(
|
||||
const paddle::optional<paddle::Tensor>& cache_v_zp,
|
||||
const std::string& cache_quant_type_str,
|
||||
const bool use_neox_rotary_style,
|
||||
const bool rope_3d,
|
||||
const int max_seq_len,
|
||||
cudaStream_t& stream,
|
||||
paddle::Tensor* qkv_out,
|
||||
paddle::Tensor* key_cache_out,
|
||||
paddle::Tensor* value_cache_out,
|
||||
const paddle::optional<paddle::Tensor>& q_norm_weight,
|
||||
const paddle::optional<paddle::Tensor>& k_norm_weight,
|
||||
const float rms_norm_eps);
|
||||
paddle::Tensor* value_cache_out);
|
||||
|
@@ -35,12 +35,8 @@ void SpeculateWriteCacheWithRoPEKernel(
|
||||
const paddle::optional<paddle::Tensor>& cache_v_zp,
|
||||
const std::string& cache_quant_type_str,
|
||||
const bool use_neox_rotary_style,
|
||||
const bool rope_3d,
|
||||
const int max_seq_len,
|
||||
cudaStream_t& stream,
|
||||
paddle::Tensor* qkv_out,
|
||||
paddle::Tensor* key_cache_out,
|
||||
paddle::Tensor* value_cache_out,
|
||||
const paddle::optional<paddle::Tensor>& q_norm_weight,
|
||||
const paddle::optional<paddle::Tensor>& k_norm_weight,
|
||||
const float rms_norm_eps);
|
||||
paddle::Tensor* value_cache_out);
|
||||
|
@@ -56,7 +56,6 @@ CascadeAppendAttentionC8Kernel<paddle::bfloat16, paddle::bfloat16, false>(
|
||||
const bool causal,
|
||||
const bool is_decoder,
|
||||
const bool enable_prefill,
|
||||
const std::string& cache_quant_type_str,
|
||||
cudaStream_t& stream,
|
||||
paddle::Tensor* out);
|
||||
|
||||
@@ -104,6 +103,5 @@ CascadeAppendAttentionC8Kernel<paddle::bfloat16, paddle::bfloat16, true>(
|
||||
const bool causal,
|
||||
const bool is_decoder,
|
||||
const bool enable_prefill,
|
||||
const std::string& cache_quant_type_str,
|
||||
cudaStream_t& stream,
|
||||
paddle::Tensor* out);
|
||||
|
@@ -54,7 +54,6 @@ template void CascadeAppendAttentionC8Kernel<paddle::bfloat16, paddle::float8_e4
|
||||
const bool causal,
|
||||
const bool is_decoder,
|
||||
const bool enable_prefill,
|
||||
const std::string& cache_quant_type_str,
|
||||
cudaStream_t& stream,
|
||||
paddle::Tensor* out);
|
||||
|
||||
@@ -99,6 +98,5 @@ template void CascadeAppendAttentionC8Kernel<paddle::bfloat16, paddle::float8_e4
|
||||
const bool causal,
|
||||
const bool is_decoder,
|
||||
const bool enable_prefill,
|
||||
const std::string& cache_quant_type_str,
|
||||
cudaStream_t& stream,
|
||||
paddle::Tensor* out);
|
||||
|
@@ -54,7 +54,6 @@ template void CascadeAppendAttentionC8Kernel<paddle::bfloat16, int8_t, false>(
|
||||
const bool causal,
|
||||
const bool is_decoder,
|
||||
const bool enable_prefill,
|
||||
const std::string& cache_quant_type_str,
|
||||
cudaStream_t& stream,
|
||||
paddle::Tensor* out);
|
||||
|
||||
@@ -101,6 +100,5 @@ template void CascadeAppendAttentionC8Kernel<paddle::bfloat16, int8_t, true>(
|
||||
const bool causal,
|
||||
const bool is_decoder,
|
||||
const bool enable_prefill,
|
||||
const std::string& cache_quant_type_str,
|
||||
cudaStream_t& stream,
|
||||
paddle::Tensor* out);
|
||||
|
@@ -54,7 +54,6 @@ template void CascadeAppendAttentionC8Kernel<paddle::float16, paddle::float16, f
|
||||
const bool causal,
|
||||
const bool is_decoder,
|
||||
const bool enable_prefill,
|
||||
const std::string& cache_quant_type_str,
|
||||
cudaStream_t& stream,
|
||||
paddle::Tensor* out);
|
||||
|
||||
@@ -101,6 +100,5 @@ template void CascadeAppendAttentionC8Kernel<paddle::float16, paddle::float16, t
|
||||
const bool causal,
|
||||
const bool is_decoder,
|
||||
const bool enable_prefill,
|
||||
const std::string& cache_quant_type_str,
|
||||
cudaStream_t& stream,
|
||||
paddle::Tensor* out);
|
||||
|
@@ -54,7 +54,6 @@ template void CascadeAppendAttentionC8Kernel<paddle::float16, paddle::float8_e4m
|
||||
const bool causal,
|
||||
const bool is_decoder,
|
||||
const bool enable_prefill,
|
||||
const std::string& cache_quant_type_str,
|
||||
cudaStream_t& stream,
|
||||
paddle::Tensor* out);
|
||||
|
||||
@@ -100,6 +99,5 @@ template void CascadeAppendAttentionC8Kernel<paddle::float16, paddle::float8_e4m
|
||||
const bool causal,
|
||||
const bool is_decoder,
|
||||
const bool enable_prefill,
|
||||
const std::string& cache_quant_type_str,
|
||||
cudaStream_t& stream,
|
||||
paddle::Tensor* out);
|
||||
|
@@ -54,7 +54,6 @@ template void CascadeAppendAttentionC8Kernel<paddle::float16, int8_t, false>(
|
||||
const bool causal,
|
||||
const bool is_decoder,
|
||||
const bool enable_prefill,
|
||||
const std::string& cache_quant_type_str,
|
||||
cudaStream_t& stream,
|
||||
paddle::Tensor* out);
|
||||
|
||||
@@ -100,6 +99,5 @@ template void CascadeAppendAttentionC8Kernel<paddle::float16, int8_t, true>(
|
||||
const bool causal,
|
||||
const bool is_decoder,
|
||||
const bool enable_prefill,
|
||||
const std::string& cache_quant_type_str,
|
||||
cudaStream_t& stream,
|
||||
paddle::Tensor* out);
|
||||
|
@@ -441,15 +441,6 @@ __forceinline__ __host__ __device__ void vec_cast<nv_bfloat16, float>(
|
||||
PD_THROW("not support the group_size", group_size); \
|
||||
}
|
||||
|
||||
#define DISPATCH_DyCfp8(is_dynamic_cfp8, IsDynamicC8, ...) \
|
||||
if (is_dynamic_cfp8) { \
|
||||
constexpr bool IsDynamicC8 = true; \
|
||||
__VA_ARGS__ \
|
||||
} else { \
|
||||
constexpr bool IsDynamicC8 = false; \
|
||||
__VA_ARGS__ \
|
||||
}
|
||||
|
||||
#define DISPATCH_MLA_GROUP_SIZE(group_size, GROUP_SIZE, ...) \
|
||||
if (group_size == 8) { \
|
||||
constexpr size_t GROUP_SIZE = 8; \
|
||||
|
@@ -378,11 +378,9 @@ void RecoverDecodeTask(const paddle::Tensor &stop_flags,
|
||||
const paddle::Tensor &step_seq_lens_decoder,
|
||||
const paddle::Tensor &block_tables,
|
||||
const paddle::Tensor &is_block_step,
|
||||
const paddle::optional<paddle::Tensor> &draft_tokens,
|
||||
const paddle::optional<paddle::Tensor> &step_draft_tokens,
|
||||
const paddle::optional<paddle::Tensor> &step_seq_lens_this_time,
|
||||
const int block_size,
|
||||
const int max_draft_tokens);
|
||||
const int block_size);
|
||||
|
||||
|
||||
|
||||
paddle::Tensor
|
||||
GroupSwigluWithMasked(const paddle::Tensor &fc1_out_tensor,
|
||||
@@ -566,6 +564,7 @@ std::vector<paddle::Tensor> NoauxTc(
|
||||
int n_group,
|
||||
int topk_group,
|
||||
int topk,
|
||||
bool renormalize,
|
||||
float routed_scaling_factor);
|
||||
|
||||
#ifdef ENABLE_FP8
|
||||
@@ -709,22 +708,6 @@ void SpeculateSaveWithOutputMsgStatic(const paddle::Tensor& accept_tokens,
|
||||
void SpeculateClearAcceptNums(const paddle::Tensor& accept_num,
|
||||
const paddle::Tensor& seq_lens_decoder);
|
||||
|
||||
void SpeculateScheduleCache(const paddle::Tensor &draft_tokens,
|
||||
const paddle::Tensor &block_tables,
|
||||
const paddle::Tensor &stop_flags,
|
||||
const paddle::Tensor &seq_lens_this_time,
|
||||
const paddle::Tensor &seq_lens_decoder,
|
||||
const paddle::Tensor &step_seq_lens_decoder,
|
||||
const paddle::Tensor &step_draft_tokens,
|
||||
const paddle::Tensor &step_seq_lens_this_time,
|
||||
const paddle::Tensor &accept_num,
|
||||
const paddle::Tensor &accept_tokens,
|
||||
const paddle::Tensor &is_block_step,
|
||||
const paddle::Tensor ¬_need_stop,
|
||||
const paddle::Tensor &stop_nums,
|
||||
const int block_size,
|
||||
const int max_draft_tokens);
|
||||
|
||||
void NgramMatch(const paddle::Tensor &input_ids,
|
||||
const paddle::Tensor &input_ids_len,
|
||||
const paddle::Tensor &pre_ids,
|
||||
@@ -768,7 +751,6 @@ void DraftModelPreprocess(const paddle::Tensor& draft_tokens,
|
||||
const paddle::Tensor& seq_lens_decoder,
|
||||
const paddle::Tensor& step_idx,
|
||||
const paddle::Tensor& not_need_stop,
|
||||
const paddle::Tensor& is_block_step,
|
||||
const paddle::Tensor& batch_drop,
|
||||
const paddle::Tensor& pre_ids,
|
||||
const paddle::Tensor& accept_tokens,
|
||||
@@ -782,8 +764,7 @@ void DraftModelPreprocess(const paddle::Tensor& draft_tokens,
|
||||
const paddle::Tensor& base_model_draft_tokens,
|
||||
const int max_draft_token,
|
||||
const bool truncate_first_token,
|
||||
const bool splitwise_prefill,
|
||||
const bool kvcache_scheduler_v1);
|
||||
const bool splitwise_prefill);
|
||||
|
||||
|
||||
void DraftModelUpdate(const paddle::Tensor& inter_next_tokens,
|
||||
@@ -1248,8 +1229,6 @@ PYBIND11_MODULE(fastdeploy_ops, m) {
|
||||
|
||||
m.def("speculate_clear_accept_nums",&SpeculateClearAcceptNums, "speculate_clear_accept_nums function");
|
||||
|
||||
m.def("speculate_schedule_cache",&SpeculateScheduleCache, "SpeculateScheduleCache function");
|
||||
|
||||
m.def("ngram_match", &NgramMatch, "ngram_match function");
|
||||
|
||||
m.def("hybird_mtp_ngram", &HybridMtpNgram, "ngram_match_mixed function");
|
||||
|
@@ -39,6 +39,9 @@ void GetOutputTopK(const paddle::Tensor& x,
int k,
int64_t rank_id,
bool wait_flag) {
if (rank_id > 0) {
return;
}

static struct msgdata msg_rcv;
int msg_queue_id = 1;

@@ -151,6 +151,34 @@ inline int GetGPUComputeCapability(int id) {

#endif

#ifndef FP8_E4M3_MAX
#define FP8_E4M3_MAX 448.0
#endif

#ifndef DISPATCH_FLOAT_FP6_DTYPE
#define DISPATCH_FLOAT_FP6_DTYPE(pd_dtype, c_type, ...) \
switch (pd_dtype) { \
case phi::DataType::FLOAT32: { \
using c_type = float; \
__VA_ARGS__ \
break; \
} \
case phi::DataType::BFLOAT16: { \
using c_type = phi::dtype::bfloat16; \
__VA_ARGS__ \
break; \
} \
case phi::DataType::FLOAT16: { \
using c_type = phi::dtype::float16; \
__VA_ARGS__ \
break; \
} \
default: { \
PD_THROW("Only supported attr of input type in [fp32, fp16, bf16]."); \
} \
}
#endif
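
A hypothetical usage sketch of DISPATCH_FLOAT_FP6_DTYPE (the launcher name and signature are illustrative, not part of the diff): the macro binds c_type to the C++ type matching the runtime dtype and splices the trailing statement into each case.

template <typename T>
void LaunchScaleKernel(const T* in, T* out, int n, cudaStream_t stream);  // assumed to exist elsewhere

inline void DispatchScaleKernel(phi::DataType dtype, const void* in, void* out,
                                int n, cudaStream_t stream) {
  DISPATCH_FLOAT_FP6_DTYPE(dtype, c_type,
    LaunchScaleKernel<c_type>(static_cast<const c_type*>(in),
                              static_cast<c_type*>(out), n, stream);)
}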

inline constexpr uint32_t next_pow_2(uint32_t const num) {
if (num <= 1)
return num;
@@ -563,3 +591,28 @@ inline int GetSMVersion() {
return sm_version;

}

__device__ __forceinline__ float warpReduceMax(float value) {
value = fmaxf(value, __shfl_xor_sync(0xffffffff, value, 16));
value = fmaxf(value, __shfl_xor_sync(0xffffffff, value, 8));
value = fmaxf(value, __shfl_xor_sync(0xffffffff, value, 4));
value = fmaxf(value, __shfl_xor_sync(0xffffffff, value, 2));
value = fmaxf(value, __shfl_xor_sync(0xffffffff, value, 1));
return value;
}

__device__ __forceinline__ float blockReduceMax(float value) {
static __shared__ float warpLevelMaxs[WARP_SIZE];
const int laneId = threadIdx.x % WARP_SIZE;
const int warpId = threadIdx.x / WARP_SIZE;

value = warpReduceMax(value);

if (laneId == 0) warpLevelMaxs[warpId] = value;
__syncthreads();

value = (threadIdx.x < blockDim.x / WARP_SIZE) ? warpLevelMaxs[laneId] : 0;
if (warpId == 0) value = warpReduceMax(value);

return value;
}
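
A minimal usage sketch of blockReduceMax (illustrative only; the kernel, its name, and the one-block-per-row layout are assumptions, not part of the diff): derive a per-row dynamic FP8 scale from the row's absolute maximum.

__global__ void row_absmax_scale_kernel(const float* __restrict__ in,
                                        float* __restrict__ scale,
                                        int row_len) {
  // One block per row; blockDim.x is assumed to be a multiple of WARP_SIZE.
  const float* row = in + static_cast<int64_t>(blockIdx.x) * row_len;
  float local_max = 0.0f;
  for (int i = threadIdx.x; i < row_len; i += blockDim.x) {
    local_max = fmaxf(local_max, fabsf(row[i]));
  }
  // All threads participate in the block-wide reduction defined above.
  const float row_max = blockReduceMax(local_max);
  if (threadIdx.x == 0) {
    scale[blockIdx.x] = row_max > 0.0f ? FP8_E4M3_MAX / row_max : 1.0f;
  }
}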
@@ -30,10 +30,12 @@ paddle::Tensor mm(paddle::Tensor const& A, paddle::Tensor const& B,
std::optional<paddle::Tensor> const& maybe_token_scales,
std::string maybe_schedule) {
machete::ScalarType const b_type = machete::ScalarType::from_id(b_type_id);
std::optional<int64_t> maybe_group_size_opt;
std::optional<int64_t> maybe_group_size_opt = std::optional<int64_t>(maybe_group_size);
std::optional<std::string> maybe_schedule_opt;
if (maybe_schedule == "") {
maybe_schedule_opt = std::nullopt;
} else {
maybe_schedule_opt = std::optional<std::string>(maybe_schedule);
}
return machete::mm_dispatch({.A = A,
.B = B,
@@ -63,6 +65,8 @@ std::vector<paddle::Tensor> MacheteMMKernel(
paddle::DataType maybe_out_type;
if (b_type_str == "uint4b8") {
b_type_id = machete::kU4B8.id();
} else if (b_type_str == "uint8b128") {
b_type_id = machete::kU8B128.id();
} else {
PADDLE_ENFORCE(false, "b_type_str not supported!");
}

@@ -51,6 +51,8 @@ std::vector<paddle::Tensor> MachetePrepackBKernel(

if (b_type_str == "uint4b8") {
b_type_id = machete::kU4B8.id();
} else if (b_type_str == "uint8b128") {
b_type_id = machete::kU8B128.id();
} else {
PADDLE_ENFORCE(false, "b_type_str not supported!");
}
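
A hypothetical refactor sketch (not in the diff; the helper name and the int64_t id type are assumptions): the b_type_str-to-id mapping duplicated in MacheteMMKernel and MachetePrepackBKernel above could be factored into one helper.

inline int64_t ScalarTypeIdFromString(const std::string& b_type_str) {
  if (b_type_str == "uint4b8") {
    return machete::kU4B8.id();
  } else if (b_type_str == "uint8b128") {
    return machete::kU8B128.id();
  }
  PADDLE_ENFORCE(false, "b_type_str not supported!");
  return machete::kU4B8.id();  // unreachable
}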
@@ -33,11 +33,6 @@
|
||||
__VA_ARGS__ \
|
||||
break; \
|
||||
} \
|
||||
case 3: { \
|
||||
constexpr size_t NUM_EXPERTS_PER_RANK = 3; \
|
||||
__VA_ARGS__ \
|
||||
break; \
|
||||
} \
|
||||
case 6: { \
|
||||
constexpr size_t NUM_EXPERTS_PER_RANK = 6; \
|
||||
__VA_ARGS__ \
|
||||
@@ -453,71 +448,137 @@ void EPMoeDispatchKernel(const paddle::Tensor& input,
|
||||
auto place = input.place();
|
||||
const int gridx = min(132 * 8, num_rows);
|
||||
if (moe_quant_type == "w4a8") {
|
||||
DISPATCH_NUM_EXPERTS_PER_RANK(num_experts_per_rank, NUM_EXPERTS_PER_RANK,
|
||||
permute_x_kernel<data_t, int8_t, NUM_EXPERTS_PER_RANK><<<gridx, 512, 0, stream>>>(
|
||||
input.data<data_t>(),
|
||||
topk_ids.data<int64_t>(),
|
||||
topk_weights.data<float>(),
|
||||
token_nums_per_expert.data<int>(),
|
||||
up_gate_proj_in_scale ? up_gate_proj_in_scale.get().data<float>() : nullptr,
|
||||
moe_topk,
|
||||
num_rows,
|
||||
token_nums_this_rank,
|
||||
hidden_size,
|
||||
permute_input->data<int8_t>(),
|
||||
permute_indices_per_token->data<int>(),
|
||||
dst_weights->data<float>(),
|
||||
dst_indices->data<int>(),
|
||||
cumsum_idx_gpu->data<int>(),
|
||||
token_nums_per_expert_cumsum->data<int64_t>(),
|
||||
expert_idx_per_token->data<int64_t>(),
|
||||
127.0,
|
||||
-127.0
|
||||
);)
|
||||
if (num_experts_per_rank == 8) {
|
||||
permute_x_kernel<data_t, int8_t, 8><<<gridx, 512, 0, stream>>>(
|
||||
input.data<data_t>(),
|
||||
topk_ids.data<int64_t>(),
|
||||
topk_weights.data<float>(),
|
||||
token_nums_per_expert.data<int>(),
|
||||
up_gate_proj_in_scale ? up_gate_proj_in_scale.get().data<float>() : nullptr,
|
||||
moe_topk,
|
||||
num_rows,
|
||||
token_nums_this_rank,
|
||||
hidden_size,
|
||||
permute_input->data<int8_t>(),
|
||||
permute_indices_per_token->data<int>(),
|
||||
dst_weights->data<float>(),
|
||||
dst_indices->data<int>(),
|
||||
cumsum_idx_gpu->data<int>(),
|
||||
token_nums_per_expert_cumsum->data<int64_t>(),
|
||||
expert_idx_per_token->data<int64_t>(),
|
||||
127.0,
|
||||
-127.0
|
||||
);
|
||||
} else if (num_experts_per_rank == 16) {
|
||||
permute_x_kernel<data_t, int8_t, 16><<<gridx, 512, 0, stream>>>(
|
||||
input.data<data_t>(),
|
||||
topk_ids.data<int64_t>(),
|
||||
topk_weights.data<float>(),
|
||||
token_nums_per_expert.data<int>(),
|
||||
up_gate_proj_in_scale ? up_gate_proj_in_scale.get().data<float>() : nullptr,
|
||||
moe_topk,
|
||||
num_rows,
|
||||
token_nums_this_rank,
|
||||
hidden_size,
|
||||
permute_input->data<int8_t>(),
|
||||
permute_indices_per_token->data<int>(),
|
||||
dst_weights->data<float>(),
|
||||
dst_indices->data<int>(),
|
||||
cumsum_idx_gpu->data<int>(),
|
||||
token_nums_per_expert_cumsum->data<int64_t>(),
|
||||
expert_idx_per_token->data<int64_t>(),
|
||||
127.0,
|
||||
-127.0
|
||||
);
|
||||
}
|
||||
} else if (moe_quant_type == "w4afp8") {
|
||||
DISPATCH_NUM_EXPERTS_PER_RANK(num_experts_per_rank, NUM_EXPERTS_PER_RANK,
|
||||
permute_x_kernel<data_t, data_t_fp8, NUM_EXPERTS_PER_RANK, 512><<<gridx, 512, 0, stream>>>(
|
||||
input.data<data_t>(),
|
||||
topk_ids.data<int64_t>(),
|
||||
topk_weights.data<float>(),
|
||||
token_nums_per_expert.data<int>(),
|
||||
up_gate_proj_in_scale ? up_gate_proj_in_scale.get().data<float>() : nullptr,
|
||||
moe_topk,
|
||||
num_rows,
|
||||
token_nums_this_rank,
|
||||
hidden_size,
|
||||
permute_input->data<data_t_fp8>(),
|
||||
permute_indices_per_token->data<int>(),
|
||||
dst_weights->data<float>(),
|
||||
dst_indices->data<int>(),
|
||||
cumsum_idx_gpu->data<int>(),
|
||||
token_nums_per_expert_cumsum->data<int64_t>(),
|
||||
expert_idx_per_token->data<int64_t>(),
|
||||
448.0f,
|
||||
-448.0f
|
||||
);)
|
||||
if (num_experts_per_rank == 8) {
|
||||
permute_x_kernel<data_t, data_t_fp8, 8, 512><<<gridx, 512, 0, stream>>>(
|
||||
input.data<data_t>(),
|
||||
topk_ids.data<int64_t>(),
|
||||
topk_weights.data<float>(),
|
||||
token_nums_per_expert.data<int>(),
|
||||
up_gate_proj_in_scale ? up_gate_proj_in_scale.get().data<float>() : nullptr,
|
||||
moe_topk,
|
||||
num_rows,
|
||||
token_nums_this_rank,
|
||||
hidden_size,
|
||||
permute_input->data<data_t_fp8>(),
|
||||
permute_indices_per_token->data<int>(),
|
||||
dst_weights->data<float>(),
|
||||
dst_indices->data<int>(),
|
||||
cumsum_idx_gpu->data<int>(),
|
||||
token_nums_per_expert_cumsum->data<int64_t>(),
|
||||
expert_idx_per_token->data<int64_t>(),
|
||||
448.0f,
|
||||
-448.0f
|
||||
);
|
||||
} else if (num_experts_per_rank == 16) {
|
||||
permute_x_kernel<data_t, data_t_fp8, 16, 512><<<gridx, 512, 0, stream>>>(
|
||||
input.data<data_t>(),
|
||||
topk_ids.data<int64_t>(),
|
||||
topk_weights.data<float>(),
|
||||
token_nums_per_expert.data<int>(),
|
||||
up_gate_proj_in_scale ? up_gate_proj_in_scale.get().data<float>() : nullptr,
|
||||
moe_topk,
|
||||
num_rows,
|
||||
token_nums_this_rank,
|
||||
hidden_size,
|
||||
permute_input->data<data_t_fp8>(),
|
||||
permute_indices_per_token->data<int>(),
|
||||
dst_weights->data<float>(),
|
||||
dst_indices->data<int>(),
|
||||
cumsum_idx_gpu->data<int>(),
|
||||
token_nums_per_expert_cumsum->data<int64_t>(),
|
||||
expert_idx_per_token->data<int64_t>(),
|
||||
448.0f,
|
||||
-448.0f
|
||||
);
|
||||
}
|
||||
} else {
|
||||
DISPATCH_NUM_EXPERTS_PER_RANK(num_experts_per_rank, NUM_EXPERTS_PER_RANK,
|
||||
permute_x_kernel<data_t, data_t, NUM_EXPERTS_PER_RANK><<<gridx, 512, 0, stream>>>(
|
||||
input.data<data_t>(),
|
||||
topk_ids.data<int64_t>(),
|
||||
topk_weights.data<float>(),
|
||||
token_nums_per_expert.data<int>(),
|
||||
up_gate_proj_in_scale ? up_gate_proj_in_scale.get().data<float>() : nullptr,
|
||||
moe_topk,
|
||||
num_rows,
|
||||
token_nums_this_rank,
|
||||
hidden_size,
|
||||
permute_input->data<data_t>(),
|
||||
permute_indices_per_token->data<int>(),
|
||||
dst_weights->data<float>(),
|
||||
dst_indices->data<int>(),
|
||||
cumsum_idx_gpu->data<int>(),
|
||||
token_nums_per_expert_cumsum->data<int64_t>(),
|
||||
expert_idx_per_token->data<int64_t>(),
|
||||
127.0,
|
||||
-127.0
|
||||
);)
|
||||
if (num_experts_per_rank == 8) {
|
||||
permute_x_kernel<data_t, data_t, 8><<<gridx, 512, 0, stream>>>(
|
||||
input.data<data_t>(),
|
||||
topk_ids.data<int64_t>(),
|
||||
topk_weights.data<float>(),
|
||||
token_nums_per_expert.data<int>(),
|
||||
up_gate_proj_in_scale ? up_gate_proj_in_scale.get().data<float>() : nullptr,
|
||||
moe_topk,
|
||||
num_rows,
|
||||
token_nums_this_rank,
|
||||
hidden_size,
|
||||
permute_input->data<data_t>(),
|
||||
permute_indices_per_token->data<int>(),
|
||||
dst_weights->data<float>(),
|
||||
dst_indices->data<int>(),
|
||||
cumsum_idx_gpu->data<int>(),
|
||||
token_nums_per_expert_cumsum->data<int64_t>(),
|
||||
expert_idx_per_token->data<int64_t>(),
|
||||
127.0,
|
||||
-127.0
|
||||
);
|
||||
} else if (num_experts_per_rank == 16) {
|
||||
permute_x_kernel<data_t, data_t, 16><<<gridx, 512, 0, stream>>>(
|
||||
input.data<data_t>(),
|
||||
topk_ids.data<int64_t>(),
|
||||
topk_weights.data<float>(),
|
||||
token_nums_per_expert.data<int>(),
|
||||
up_gate_proj_in_scale ? up_gate_proj_in_scale.get().data<float>() : nullptr,
|
||||
moe_topk,
|
||||
num_rows,
|
||||
token_nums_this_rank,
|
||||
hidden_size,
|
||||
permute_input->data<data_t>(),
|
||||
permute_indices_per_token->data<int>(),
|
||||
dst_weights->data<float>(),
|
||||
dst_indices->data<int>(),
|
||||
cumsum_idx_gpu->data<int>(),
|
||||
token_nums_per_expert_cumsum->data<int64_t>(),
|
||||
expert_idx_per_token->data<int64_t>(),
|
||||
127.0,
|
||||
-127.0
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
@@ -26,6 +26,7 @@ std::vector<paddle::Tensor> NoauxTc(paddle::Tensor& scores,
|
||||
int n_group,
|
||||
int topk_group,
|
||||
int topk,
|
||||
bool renormalize,
|
||||
float routed_scaling_factor) {
|
||||
auto input_shape = scores_with_bias.shape();
|
||||
PD_CHECK(input_shape.size() == 2);
|
||||
@@ -48,6 +49,7 @@ std::vector<paddle::Tensor> NoauxTc(paddle::Tensor& scores,
|
||||
n_group,
|
||||
topk_group,
|
||||
topk,
|
||||
renormalize,
|
||||
routed_scaling_factor,
|
||||
stream);
|
||||
|
||||
@@ -76,6 +78,7 @@ PD_BUILD_STATIC_OP(noaux_tc)
|
||||
.Attrs({"n_group: int",
|
||||
"topk_group: int",
|
||||
"topk:int",
|
||||
"renormalize: bool",
|
||||
"routed_scaling_factor: float"})
|
||||
.SetKernelFn(PD_KERNEL(NoauxTc))
|
||||
.SetInferShapeFn(PD_INFER_SHAPE(NoauxTcInferShape))
|
||||
|
@@ -25,6 +25,23 @@ constexpr unsigned FULL_WARP_MASK = 0xffffffff;
constexpr int32_t BLOCK_SIZE = 512;
constexpr int32_t NUM_WARPS_PER_BLOCK = BLOCK_SIZE / WARP_SIZE;

template <typename T_OUT, typename T_IN>
__device__ inline T_OUT cuda_cast(T_IN val) {
return val;
}

template <>
__device__ inline float cuda_cast<float, __nv_bfloat16>(__nv_bfloat16 val) {
return __bfloat162float(val);
}

template <typename T>
__device__ inline T neg_inf() {
// cuda::std::numeric_limits<T>::infinity() returns `0` for [T=bf16 or fp16]
// so we need to cast from fp32
return cuda_cast<T, float>(-cuda::std::numeric_limits<float>::infinity());
}

namespace warp_topk {

template <int size, typename T>
@@ -41,10 +58,21 @@ constexpr __host__ __device__ bool isPowerOf2(T v) {
}

template <bool greater, typename T>
__device__ bool is_better_than(T val, T baseline) {
__forceinline__ __device__ bool is_better_than(T val, T baseline) {
return (val > baseline && greater) || (val < baseline && !greater);
}

template <bool greater, typename T, typename idxT>
__forceinline__ __device__ bool is_better_than(T val, T baseline, idxT index,
idxT baseline_index) {
bool res = (val > baseline && greater) || (val < baseline && !greater);
if (val == baseline) {
res = (index < baseline_index && greater) ||
(index < baseline_index && !greater);
}
return res;
}
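
A host-side reference (illustrative only; the function is not part of the diff) of the tie-break the stable is_better_than overload above encodes: when two scores tie, the element with the smaller index is preferred, which is what std::stable_sort over a value-only descending comparator yields.

#include <algorithm>
#include <numeric>
#include <vector>

std::vector<int> stable_topk_indices(const std::vector<float>& scores, int k) {
  std::vector<int> idx(scores.size());
  std::iota(idx.begin(), idx.end(), 0);
  // stable_sort preserves the original order of equal scores, so smaller
  // indices come first among ties, matching the is_stable comparison above.
  std::stable_sort(idx.begin(), idx.end(),
                   [&](int a, int b) { return scores[a] > scores[b]; });
  if (static_cast<int>(idx.size()) > k) idx.resize(k);
  return idx;
}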
|
||||
template <typename T, typename idxT>
|
||||
int calc_smem_size_for_block_wide(int num_of_warp, int64_t k) {
|
||||
int64_t cache_topk = (sizeof(T) + sizeof(idxT)) * num_of_warp * k;
|
||||
@@ -53,7 +81,8 @@ int calc_smem_size_for_block_wide(int num_of_warp, int64_t k) {
|
||||
round_up_to_multiple_of<256>(n * sizeof(T)) + n * sizeof(idxT));
|
||||
}
|
||||
|
||||
template <int size, bool ascending, typename T, typename idxT>
|
||||
template <int size, bool ascending, bool reverse, typename T, typename idxT,
|
||||
bool is_stable>
|
||||
struct BitonicMerge {
|
||||
// input should be a bitonic sequence, and sort it to be a monotonic sequence
|
||||
__device__ static void merge(T* __restrict__ val_arr,
|
||||
@@ -67,7 +96,15 @@ struct BitonicMerge {
|
||||
int const other_i = i + stride;
|
||||
T& val = val_arr[i];
|
||||
T& other_val = val_arr[other_i];
|
||||
if ((val > other_val && ascending) || (val < other_val && !ascending)) {
|
||||
bool is_better;
|
||||
if constexpr (is_stable) {
|
||||
is_better = is_better_than<ascending>(val, other_val, idx_arr[i],
|
||||
idx_arr[other_i]);
|
||||
} else {
|
||||
is_better = is_better_than<ascending>(val, other_val);
|
||||
}
|
||||
|
||||
if (is_better) {
|
||||
T tmp = val;
|
||||
val = other_val;
|
||||
other_val = tmp;
|
||||
@@ -78,13 +115,14 @@ struct BitonicMerge {
|
||||
}
|
||||
}
|
||||
|
||||
BitonicMerge<size / 2, ascending, T, idxT>::merge(val_arr, idx_arr);
|
||||
BitonicMerge<size / 2, ascending, T, idxT>::merge(val_arr + arr_len / 2,
|
||||
idx_arr + arr_len / 2);
|
||||
BitonicMerge<size / 2, ascending, reverse, T, idxT, is_stable>::merge(
|
||||
val_arr, idx_arr);
|
||||
BitonicMerge<size / 2, ascending, reverse, T, idxT, is_stable>::merge(
|
||||
val_arr + arr_len / 2, idx_arr + arr_len / 2);
|
||||
}
|
||||
};
|
||||
|
||||
template <int size, bool ascending, typename T, typename idxT>
|
||||
template <int size, bool ascending, typename T, typename idxT, bool is_stable>
|
||||
struct BitonicSort {
|
||||
__device__ static void sort(T* __restrict__ val_arr,
|
||||
idxT* __restrict__ idx_arr) {
|
||||
@@ -92,15 +130,16 @@ struct BitonicSort {
|
||||
static_assert(size >= 2 * WARP_SIZE);
|
||||
constexpr int arr_len = size / WARP_SIZE;
|
||||
|
||||
BitonicSort<size / 2, true, T, idxT>::sort(val_arr, idx_arr);
|
||||
BitonicSort<size / 2, false, T, idxT>::sort(val_arr + arr_len / 2,
|
||||
idx_arr + arr_len / 2);
|
||||
BitonicMerge<size, ascending, T, idxT>::merge(val_arr, idx_arr);
|
||||
BitonicSort<size / 2, true, T, idxT, is_stable>::sort(val_arr, idx_arr);
|
||||
BitonicSort<size / 2, false, T, idxT, is_stable>::sort(
|
||||
val_arr + arr_len / 2, idx_arr + arr_len / 2);
|
||||
BitonicMerge<size, ascending, ascending, T, idxT, is_stable>::merge(
|
||||
val_arr, idx_arr);
|
||||
}
|
||||
};
|
||||
|
||||
template <bool ascending, typename T, typename idxT>
|
||||
struct BitonicSort<32, ascending, T, idxT> {
|
||||
template <bool ascending, typename T, typename idxT, bool is_stable>
|
||||
struct BitonicSort<32, ascending, T, idxT, is_stable> {
|
||||
__device__ static void sort(T* __restrict__ val_arr,
|
||||
idxT* __restrict__ idx_arr) {
|
||||
int const lane = threadIdx.x % WARP_SIZE;
|
||||
@@ -114,19 +153,37 @@ struct BitonicSort<32, ascending, T, idxT> {
|
||||
|
||||
T other = __shfl_xor_sync(FULL_WARP_MASK, *val_arr, stride);
|
||||
idxT other_idx = __shfl_xor_sync(FULL_WARP_MASK, *idx_arr, stride);
|
||||
if (*val_arr != other && (*val_arr > other) != (reverse != is_second)) {
|
||||
|
||||
bool is_better;
|
||||
if constexpr (is_stable) {
|
||||
if constexpr (ascending) {
|
||||
is_better = ((*val_arr > other) ||
|
||||
((*val_arr == other) && (*idx_arr < other_idx))) !=
|
||||
(reverse != is_second);
|
||||
} else {
|
||||
is_better = ((*val_arr > other) ||
|
||||
((*val_arr == other) && (*idx_arr > other_idx))) !=
|
||||
(reverse != is_second);
|
||||
}
|
||||
} else {
|
||||
is_better = (*val_arr != other &&
|
||||
(*val_arr > other) != (reverse != is_second));
|
||||
}
|
||||
if (is_better) {
|
||||
*val_arr = other;
|
||||
*idx_arr = other_idx;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
BitonicMerge<32, ascending, T, idxT>::merge(val_arr, idx_arr);
|
||||
BitonicMerge<32, ascending, ascending, T, idxT, is_stable>::merge(val_arr,
|
||||
idx_arr);
|
||||
}
|
||||
};
|
||||
|
||||
template <bool ascending, typename T, typename idxT>
|
||||
struct BitonicMerge<32, ascending, T, idxT> {
|
||||
template <bool ascending, bool reverse, typename T, typename idxT,
|
||||
bool is_stable>
|
||||
struct BitonicMerge<32, ascending, reverse, T, idxT, is_stable> {
|
||||
__device__ static void merge(T* __restrict__ val_arr,
|
||||
idxT* __restrict__ idx_arr) {
|
||||
int const lane = threadIdx.x % WARP_SIZE;
|
||||
@@ -136,7 +193,24 @@ struct BitonicMerge<32, ascending, T, idxT> {
|
||||
T other = __shfl_xor_sync(FULL_WARP_MASK, val, stride);
|
||||
idxT& idx = *idx_arr;
|
||||
idxT other_idx = __shfl_xor_sync(FULL_WARP_MASK, idx, stride);
|
||||
if (val != other && ((val > other) == (ascending != is_second))) {
|
||||
|
||||
bool is_better;
|
||||
if constexpr (is_stable) {
|
||||
if constexpr (ascending) {
|
||||
is_better = ((*val_arr > other) ||
|
||||
((*val_arr == other) && (*idx_arr < other_idx))) ==
|
||||
(reverse != is_second); // for min
|
||||
} else {
|
||||
is_better = ((*val_arr > other) ||
|
||||
((*val_arr == other) && (*idx_arr > other_idx))) ==
|
||||
(reverse != is_second); // for max
|
||||
}
|
||||
} else {
|
||||
is_better =
|
||||
(val != other && ((val > other) == (ascending != is_second)));
|
||||
}
|
||||
|
||||
if (is_better) {
|
||||
val = other;
|
||||
idx = other_idx;
|
||||
}
|
||||
@@ -144,34 +218,42 @@ struct BitonicMerge<32, ascending, T, idxT> {
|
||||
}
|
||||
};
|
||||
|
||||
template <int capacity, bool greater, typename T, typename idxT>
|
||||
template <int capacity, bool greater, typename T, typename idxT, bool is_stable>
|
||||
class WarpSort {
|
||||
public:
|
||||
public:
|
||||
__device__ WarpSort(idxT k, T dummy)
|
||||
: lane_(threadIdx.x % WARP_SIZE), k_(k), dummy_(dummy) {
|
||||
static_assert(capacity >= WARP_SIZE && isPowerOf2(capacity));
|
||||
|
||||
for (int i = 0; i < max_arr_len_; ++i) {
|
||||
val_arr_[i] = dummy_;
|
||||
idx_arr_[i] = 0;
|
||||
}
|
||||
}
|
||||
|
||||
// load and merge k sorted values
|
||||
__device__ void load_sorted(T const* __restrict__ in,
|
||||
idxT const* __restrict__ in_idx,
|
||||
idxT start) {
|
||||
idxT const* __restrict__ in_idx, idxT start) {
|
||||
idxT idx = start + WARP_SIZE - 1 - lane_;
|
||||
for (int i = max_arr_len_ - 1; i >= 0; --i, idx += WARP_SIZE) {
|
||||
if (idx < start + k_) {
|
||||
T t = in[idx];
|
||||
if (is_better_than<greater>(t, val_arr_[i])) {
|
||||
bool is_better;
|
||||
if constexpr (is_stable) {
|
||||
is_better =
|
||||
is_better_than<greater>(t, val_arr_[i], in_idx[idx], idx_arr_[i]);
|
||||
} else {
|
||||
is_better = is_better_than<greater>(t, val_arr_[i]);
|
||||
}
|
||||
if (is_better) {
|
||||
val_arr_[i] = t;
|
||||
idx_arr_[i] = in_idx[idx];
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
BitonicMerge<capacity, !greater, T, idxT>::merge(val_arr_, idx_arr_);
|
||||
BitonicMerge<capacity, greater, !greater, T, idxT, is_stable>::merge(
|
||||
val_arr_, idx_arr_);
|
||||
}
|
||||
|
||||
__device__ void dump(T* __restrict__ out, idxT* __restrict__ out_idx) const {
|
||||
@@ -193,7 +275,7 @@ public:
|
||||
}
|
||||
}
|
||||
|
||||
protected:
|
||||
protected:
|
||||
static constexpr int max_arr_len_ = capacity / WARP_SIZE;
|
||||
|
||||
T val_arr_[max_arr_len_];
|
||||
@@ -205,11 +287,11 @@ protected:
|
||||
|
||||
}; // end class WarpSort
|
||||
|
||||
template <int capacity, bool greater, typename T, typename idxT>
|
||||
class WarpSelect : public WarpSort<capacity, greater, T, idxT> {
|
||||
public:
|
||||
template <int capacity, bool greater, typename T, typename idxT, bool is_stable>
|
||||
class WarpSelect : public WarpSort<capacity, greater, T, idxT, is_stable> {
|
||||
public:
|
||||
__device__ WarpSelect(idxT k, T dummy)
|
||||
: WarpSort<capacity, greater, T, idxT>(k, dummy),
|
||||
: WarpSort<capacity, greater, T, idxT, is_stable>(k, dummy),
|
||||
k_th_(dummy),
|
||||
k_th_lane_((k - 1) % WARP_SIZE) {
|
||||
extern __shared__ char smem_buf[]; // extern __shared__ T smem_buf[];
|
||||
@@ -234,7 +316,13 @@ public:
|
||||
}
|
||||
|
||||
__device__ void add(T val, idxT idx) {
|
||||
bool do_add = is_better_than<greater>(val, k_th_);
|
||||
bool do_add;
|
||||
if constexpr (is_stable) {
|
||||
do_add = is_better_than<greater>(val, k_th_, idx, k_th_idx_);
|
||||
} else {
|
||||
do_add = is_better_than<greater>(val, k_th_);
|
||||
}
|
||||
|
||||
uint32_t mask = __ballot_sync(FULL_WARP_MASK, do_add);
|
||||
if (mask == 0) {
|
||||
return;
|
||||
@@ -271,37 +359,52 @@ public:
|
||||
__syncthreads();
|
||||
}
|
||||
|
||||
private:
|
||||
private:
|
||||
__device__ void set_k_th_() {
|
||||
k_th_ = __shfl_sync(FULL_WARP_MASK, val_arr_[max_arr_len_ - 1], k_th_lane_);
|
||||
if constexpr (is_stable) {
|
||||
k_th_idx_ =
|
||||
__shfl_sync(FULL_WARP_MASK, idx_arr_[max_arr_len_ - 1], k_th_lane_);
|
||||
}
|
||||
}
|
||||
|
||||
__device__ void merge_buf_(T val, idxT idx) {
|
||||
BitonicSort<WARP_SIZE, greater, T, idxT>::sort(&val, &idx);
|
||||
BitonicSort<WARP_SIZE, greater, T, idxT, is_stable>::sort(&val, &idx);
|
||||
|
||||
T& old = val_arr_[max_arr_len_ - 1];
|
||||
if (is_better_than<greater>(val, old)) {
|
||||
|
||||
bool is_better;
|
||||
if constexpr (is_stable) {
|
||||
is_better =
|
||||
is_better_than<greater>(val, old, idx, idx_arr_[max_arr_len_ - 1]);
|
||||
} else {
|
||||
is_better = is_better_than<greater>(val, old);
|
||||
}
|
||||
|
||||
if (is_better) {
|
||||
old = val;
|
||||
idx_arr_[max_arr_len_ - 1] = idx;
|
||||
}
|
||||
|
||||
BitonicMerge<capacity, !greater, T, idxT>::merge(val_arr_, idx_arr_);
|
||||
BitonicMerge<capacity, greater, !greater, T, idxT, is_stable>::merge(
|
||||
val_arr_, idx_arr_);
|
||||
|
||||
set_k_th_();
|
||||
}
|
||||
|
||||
using WarpSort<capacity, greater, T, idxT>::max_arr_len_;
|
||||
using WarpSort<capacity, greater, T, idxT>::val_arr_;
|
||||
using WarpSort<capacity, greater, T, idxT>::idx_arr_;
|
||||
using WarpSort<capacity, greater, T, idxT>::lane_;
|
||||
using WarpSort<capacity, greater, T, idxT>::k_;
|
||||
using WarpSort<capacity, greater, T, idxT>::dummy_;
|
||||
using WarpSort<capacity, greater, T, idxT, is_stable>::max_arr_len_;
|
||||
using WarpSort<capacity, greater, T, idxT, is_stable>::val_arr_;
|
||||
using WarpSort<capacity, greater, T, idxT, is_stable>::idx_arr_;
|
||||
using WarpSort<capacity, greater, T, idxT, is_stable>::lane_;
|
||||
using WarpSort<capacity, greater, T, idxT, is_stable>::k_;
|
||||
using WarpSort<capacity, greater, T, idxT, is_stable>::dummy_;
|
||||
|
||||
  T* val_smem_;
  idxT* idx_smem_;
  int smem_buf_len_ = 0;

  T k_th_;
  idxT k_th_idx_;
  int const k_th_lane_;
};  // end class WarpSelect
}  // namespace warp_topk
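// The is_stable flag threaded through WarpSort/WarpSelect above implies two
// comparator shapes. A minimal sketch of what they presumably look like
// follows (names and exact signatures are assumptions, not the helper defined
// elsewhere in this file): the unstable form orders by value only, while the
// stable form breaks value ties by preferring the smaller index, so repeated
// runs select the same expert deterministically.
template <bool greater, typename T>
__device__ __forceinline__ bool is_better_than_sketch(T val, T baseline) {
  return greater ? (val > baseline) : (val < baseline);
}

template <bool greater, typename T, typename idxT>
__device__ __forceinline__ bool is_better_than_sketch(T val, T baseline,
                                                      idxT index,
                                                      idxT baseline_index) {
  bool better = greater ? (val > baseline) : (val < baseline);
  if (val == baseline) {
    better = (index < baseline_index);  // deterministic tie-break
  }
  return better;
}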
@@ -313,8 +416,8 @@ __device__ void topk_with_k2(T* output,
|
||||
int32_t const lane_id,
|
||||
int const num_experts_per_group) {
|
||||
// Get the top2 per thread
|
||||
T largest = cuda::std::numeric_limits<T>::min();
|
||||
T second_largest = cuda::std::numeric_limits<T>::min();
|
||||
T largest = neg_inf<T>();
|
||||
T second_largest = neg_inf<T>();
|
||||
|
||||
if (num_experts_per_group > WARP_SIZE) {
|
||||
for (int i = lane_id; i < num_experts_per_group; i += WARP_SIZE) {
|
||||
@@ -368,8 +471,14 @@ __global__ void topk_with_k2_kernel(T* output,
|
||||
cg::thread_block block = cg::this_thread_block();
|
||||
cg::thread_block_tile<32> tile = cg::tiled_partition<32>(block);
|
||||
|
||||
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
|
||||
asm volatile("griddepcontrol.wait;");
|
||||
#endif
|
||||
topk_with_k2(output, input, tile, lane_id, num_experts_per_group);
|
||||
}
|
||||
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
|
||||
asm volatile("griddepcontrol.launch_dependents;");
|
||||
#endif
|
||||
}
|
||||
|
||||
template <typename T, typename IdxT>
|
||||
@@ -385,6 +494,7 @@ __global__ void group_idx_and_topk_idx_kernel(
|
||||
int64_t const topk,
|
||||
int64_t const num_experts,
|
||||
int64_t const num_experts_per_group,
|
||||
bool const renormalize,
|
||||
double routed_scaling_factor) {
|
||||
int32_t warp_id = threadIdx.x / WARP_SIZE;
|
||||
int32_t lane_id = threadIdx.x % WARP_SIZE;
|
||||
@@ -403,19 +513,29 @@ __global__ void group_idx_and_topk_idx_kernel(
|
||||
|
||||
extern __shared__ char smem_buf[]; // NOTE: reuse the shared memory here to
|
||||
// store the target topk idx
|
||||
int32_t* s_topk_idx = reinterpret_cast<int32_t*>(smem_buf) + warp_id * topk;
|
||||
int32_t* s_topk_idx = reinterpret_cast<int32_t*>(smem_buf);
|
||||
T* s_topk_value =
|
||||
reinterpret_cast<T*>(s_topk_idx + NUM_WARPS_PER_BLOCK * topk) +
|
||||
warp_id * topk;
|
||||
s_topk_idx += warp_id * topk;
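  // Shared-memory layout sketch (an assumption read off the pointer
  // arithmetic above, not a quote of calc_smem_size_for_block_wide):
  //   [ int32_t s_topk_idx  [NUM_WARPS_PER_BLOCK * topk] ]
  //   [ T       s_topk_value[NUM_WARPS_PER_BLOCK * topk] ]
  // so the helper presumably reserves
  //   NUM_WARPS_PER_BLOCK * topk * (sizeof(int32_t) + sizeof(T)) bytes,
  // and each warp indexes its own topk-sized slice via warp_id * topk.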
|
||||
|
||||
T value = cuda::std::numeric_limits<T>::min();
|
||||
T topk_group_value = cuda::std::numeric_limits<T>::min();
|
||||
T value = neg_inf<T>();
|
||||
T topk_group_value = neg_inf<T>();
|
||||
int32_t num_equalto_topkth_group;
|
||||
|
||||
if ((n_group > topk_group) && (case_id < num_tokens)) {
|
||||
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
|
||||
asm volatile("griddepcontrol.wait;"); // I think all prolog can be put before
|
||||
// acqbulk because it's ptr arithmetic
|
||||
#endif
|
||||
|
||||
if (case_id < num_tokens) {
|
||||
// calculate group_idx
|
||||
int32_t target_num_min = WARP_SIZE - n_group + topk_group;
|
||||
if (lane_id < n_group) {
|
||||
if (lane_id < n_group &&
|
||||
(isfinite(cuda_cast<float, T>(
|
||||
group_scores[lane_id])))) // The check is necessary to avoid
|
||||
// abnormal input
|
||||
{
|
||||
value = group_scores[lane_id];
|
||||
}
|
||||
|
||||
@@ -426,22 +546,23 @@ __global__ void group_idx_and_topk_idx_kernel(
|
||||
__syncwarp(); // Ensure all threads have valid data before reduction
|
||||
topk_group_value = cg::reduce(tile, value, cg::greater<T>());
|
||||
if (value == topk_group_value) {
|
||||
value = cuda::std::numeric_limits<T>::min();
|
||||
value = neg_inf<T>();
|
||||
}
|
||||
pre_count_equal_to_top_value = count_equal_to_top_value;
|
||||
count_equal_to_top_value = __popc(__ballot_sync(
|
||||
FULL_WARP_MASK, (value == cuda::std::numeric_limits<T>::min())));
|
||||
FULL_WARP_MASK, (value == neg_inf<T>())));
|
||||
}
|
||||
num_equalto_topkth_group = target_num_min - pre_count_equal_to_top_value;
|
||||
}
|
||||
__syncthreads();
|
||||
|
||||
warp_topk::WarpSelect</*capability*/ WARP_SIZE, /*greater*/ true, T, int32_t>
|
||||
queue((int32_t)topk, cuda::std::numeric_limits<T>::min());
|
||||
warp_topk::WarpSelect</*capability*/ WARP_SIZE, /*greater*/ true, T, int32_t,
|
||||
/* is_stable */ true>
|
||||
queue((int32_t)topk, neg_inf<T>());
|
||||
|
||||
int count_equalto_topkth_group = 0;
|
||||
bool if_proceed_next_topk = (topk_group_value != cuda::std::numeric_limits<T>::min());
|
||||
if (case_id < num_tokens) {
|
||||
bool if_proceed_next_topk = (topk_group_value != neg_inf<T>());
|
||||
if (case_id < num_tokens && if_proceed_next_topk) {
|
||||
for (int i_group = 0; i_group < n_group; i_group++) {
|
||||
if ((group_scores[i_group] > topk_group_value) ||
|
||||
((group_scores[i_group] == topk_group_value) &&
|
||||
@@ -449,9 +570,11 @@ __global__ void group_idx_and_topk_idx_kernel(
|
||||
int32_t offset = i_group * num_experts_per_group;
|
||||
for (int32_t i = lane_id; i < align_num_experts_per_group;
|
||||
i += WARP_SIZE) {
|
||||
T candidates = i < num_experts_per_group
|
||||
? scores_with_bias[offset + i]
|
||||
: cuda::std::numeric_limits<T>::min();
|
||||
T candidates =
|
||||
(i < num_experts_per_group) && isfinite(cuda_cast<float, T>(
|
||||
scores_with_bias[offset + i]))
|
||||
? scores_with_bias[offset + i]
|
||||
: neg_inf<T>();
|
||||
queue.add(candidates, offset + i);
|
||||
}
|
||||
if (group_scores[i_group] == topk_group_value) {
|
||||
@@ -469,7 +592,7 @@ __global__ void group_idx_and_topk_idx_kernel(
|
||||
// Load the valid score value
|
||||
// Calculate the summation
|
||||
float topk_sum = 1e-20;
|
||||
if (case_id < num_tokens) {
|
||||
if (case_id < num_tokens && if_proceed_next_topk) {
|
||||
for (int i = lane_id;
|
||||
i < warp_topk::round_up_to_multiple_of<WARP_SIZE>(topk);
|
||||
i += WARP_SIZE) {
|
||||
@@ -478,33 +601,45 @@ __global__ void group_idx_and_topk_idx_kernel(
|
||||
if (i < topk) {
|
||||
s_topk_value[i] = value;
|
||||
}
|
||||
topk_sum += reduce(tile, value, cg::plus<float>());
|
||||
topk_sum += reduce(tile, cuda_cast<float, T>(value), cg::plus<float>());
|
||||
}
|
||||
}
|
||||
|
||||
__syncthreads();
|
||||
if (case_id < num_tokens) {
|
||||
|
||||
if (case_id < num_tokens && if_proceed_next_topk) {
|
||||
for (int i = lane_id; i < num_experts; i += WARP_SIZE) {
|
||||
scores[i] = 0;
|
||||
}
|
||||
}
|
||||
__threadfence();
|
||||
__syncthreads();
|
||||
__syncwarp();
|
||||
|
||||
if (case_id < num_tokens) {
|
||||
for (int i = lane_id; i < topk; i += WARP_SIZE) {
|
||||
float value = s_topk_value[i] / topk_sum * routed_scaling_factor;
|
||||
scores[s_topk_idx[i]] = value;
|
||||
if (if_proceed_next_topk) {
|
||||
if (if_proceed_next_topk) {
|
||||
for (int i = lane_id; i < topk; i += WARP_SIZE) {
|
||||
float value;
|
||||
if (renormalize) {
|
||||
value = cuda_cast<float, T>(s_topk_value[i]) / topk_sum *
|
||||
routed_scaling_factor;
|
||||
} else {
|
||||
value = cuda_cast<float, T>(s_topk_value[i]) * routed_scaling_factor;
|
||||
}
|
||||
scores[s_topk_idx[i]] = value;
|
||||
topk_indices[i] = s_topk_idx[i];
|
||||
topk_values[i] = static_cast<T>(value);
|
||||
topk_values[i] = cuda_cast<T, float>(value);
|
||||
}
|
||||
else {
|
||||
} else {
|
||||
for (int i = lane_id; i < topk; i += WARP_SIZE) {
|
||||
topk_indices[i] = i;
|
||||
topk_values[i] = static_cast<float>(1.0f / topk);
|
||||
topk_values[i] = cuda_cast<T, float>(1.0f / topk);
|
||||
}
|
||||
}
|
||||
// Note: when if_proceed_next_topk==false, choose the first 8 experts as the
|
||||
// default result.
|
||||
}
|
||||
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
|
||||
asm volatile("griddepcontrol.launch_dependents;");
|
||||
#endif
|
||||
}
|
||||
|
||||
template <typename T, typename IdxT>
|
||||
@@ -518,17 +653,24 @@ void invokeNoAuxTc(T* scores,
|
||||
int64_t const n_group,
|
||||
int64_t const topk_group,
|
||||
int64_t const topk,
|
||||
bool const renormalize,
|
||||
double const routed_scaling_factor,
|
||||
cudaStream_t const stream) {
|
||||
int64_t num_cases = num_tokens * n_group;
|
||||
int64_t topk_with_k2_num_blocks = (num_cases - 1) / NUM_WARPS_PER_BLOCK + 1;
|
||||
topk_with_k2_kernel<T><<<topk_with_k2_num_blocks, BLOCK_SIZE, 0, stream>>>(
|
||||
group_scores,
|
||||
scores_with_bias,
|
||||
num_tokens,
|
||||
num_cases,
|
||||
n_group,
|
||||
num_experts / n_group);
|
||||
auto* kernel_instance1 = &topk_with_k2_kernel<T>;
|
||||
cudaLaunchConfig_t config;
|
||||
config.gridDim = topk_with_k2_num_blocks;
|
||||
config.blockDim = BLOCK_SIZE;
|
||||
config.dynamicSmemBytes = 0;
|
||||
config.stream = stream;
|
||||
cudaLaunchAttribute attrs[1];
|
||||
attrs[0].id = cudaLaunchAttributeProgrammaticStreamSerialization;
|
||||
attrs[0].val.programmaticStreamSerializationAllowed = false;
|
||||
config.numAttrs = 1;
|
||||
config.attrs = attrs;
|
||||
cudaLaunchKernelEx(&config, kernel_instance1, group_scores, scores_with_bias,
|
||||
num_tokens, num_cases, n_group, num_experts / n_group);
|
||||
|
||||
int64_t topk_with_k_group_num_blocks =
|
||||
(num_tokens - 1) / NUM_WARPS_PER_BLOCK + 1;
|
||||
@@ -536,21 +678,19 @@ void invokeNoAuxTc(T* scores,
|
||||
warp_topk::calc_smem_size_for_block_wide<T, int32_t>(NUM_WARPS_PER_BLOCK,
|
||||
topk);
|
||||
|
||||
group_idx_and_topk_idx_kernel<T><<<topk_with_k_group_num_blocks,
|
||||
BLOCK_SIZE,
|
||||
dynamic_smem_in_bytes,
|
||||
stream>>>(scores,
|
||||
group_scores,
|
||||
topk_values,
|
||||
topk_indices,
|
||||
scores_with_bias,
|
||||
num_tokens,
|
||||
n_group,
|
||||
topk_group,
|
||||
topk,
|
||||
num_experts,
|
||||
num_experts / n_group,
|
||||
routed_scaling_factor);
|
||||
auto* kernel_instance2 = &group_idx_and_topk_idx_kernel<T, IdxT>;
|
||||
config.gridDim = topk_with_k_group_num_blocks;
|
||||
config.blockDim = BLOCK_SIZE;
|
||||
config.dynamicSmemBytes = dynamic_smem_in_bytes;
|
||||
config.stream = stream;
|
||||
attrs[0].id = cudaLaunchAttributeProgrammaticStreamSerialization;
|
||||
attrs[0].val.programmaticStreamSerializationAllowed = false;
|
||||
config.numAttrs = 1;
|
||||
config.attrs = attrs;
|
||||
cudaLaunchKernelEx(&config, kernel_instance2, scores, group_scores,
|
||||
topk_values, topk_indices, scores_with_bias, num_tokens,
|
||||
n_group, topk_group, topk, num_experts,
|
||||
num_experts / n_group, renormalize, routed_scaling_factor);
|
||||
}
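// Hedged reading of the launch attributes above: with
// programmaticStreamSerializationAllowed left at 0, the two cudaLaunchKernelEx
// launches still serialize on the stream in the usual way, and the
// griddepcontrol PTX inside the kernels stays harmless. Opting into
// programmatic dependent launch would, as an illustration only, flip that
// value:
//   attrs[0].id = cudaLaunchAttributeProgrammaticStreamSerialization;
//   attrs[0].val.programmaticStreamSerializationAllowed = 1;
// after which griddepcontrol.wait in the consumer kernel is what holds it
// back until the producer kernel issues griddepcontrol.launch_dependents.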
|
||||
|
||||
#define INSTANTIATE_NOAUX_TC(T, IdxT) \
|
||||
@@ -564,6 +704,7 @@ void invokeNoAuxTc(T* scores,
|
||||
int64_t const n_group, \
|
||||
int64_t const topk_group, \
|
||||
int64_t const topk, \
|
||||
bool const renormalize, \
|
||||
double const routed_scaling_factor, \
|
||||
cudaStream_t const stream);
|
||||
|
||||
|
@@ -3,6 +3,158 @@
|
||||
|
||||
#include "quantization/common.cuh"
|
||||
|
||||
// adapted from: https://github.com/sgl-project/sglang/blob/v0.5.2rc2/sgl-kernel/csrc/gemm/per_token_quant_fp8.cu
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// 1. Warp‑local, no shared memory
|
||||
// • One warp handles one token.
|
||||
// • Eight tokens per 256‑thread CTA.
|
||||
// ---------------------------------------------------------------------------
|
||||
template <typename T, typename DST_DTYPE, int kTokensPerCTA = 8, int kVecSize = 16>
|
||||
__global__ void per_token_quant_fp8_kernel(
|
||||
const T* __restrict__ input,
|
||||
DST_DTYPE* __restrict__ output_q,
|
||||
float* __restrict__ output_s,
|
||||
const float scale_ub,
|
||||
const int64_t hidden_size,
|
||||
const int64_t num_tokens) {
|
||||
const int warp_id = threadIdx.x / WARP_SIZE; // 0‑7 (8 warps)
|
||||
const int lane_id = threadIdx.x & (WARP_SIZE - 1); // 0‑31
|
||||
const int token_id = blockIdx.x * kTokensPerCTA + warp_id;
|
||||
if (token_id >= num_tokens) return;
|
||||
|
||||
// Global tensors for this token
|
||||
const T* token_input = input + token_id * hidden_size;
|
||||
DST_DTYPE* token_output = output_q + token_id * hidden_size;
|
||||
float* token_scale = output_s + token_id;
|
||||
|
||||
//
|
||||
// Pass-1: Perform a warp reduce to find the max_value of a token's hidden_size
|
||||
//
|
||||
float max_value = 0.f;
|
||||
using vec_t = AlignedVector<T, kVecSize>;
|
||||
const int32_t num_vec_elems = hidden_size / kVecSize;
|
||||
|
||||
for (int32_t i = lane_id; i < num_vec_elems; i += WARP_SIZE) {
|
||||
vec_t input_vec;
|
||||
Load(token_input + i * kVecSize, &input_vec);
|
||||
|
||||
#pragma unroll
|
||||
for (uint32_t j = 0; j < kVecSize; ++j) {
|
||||
max_value = fmaxf(max_value, fabsf(static_cast<float>(input_vec[j])));
|
||||
}
|
||||
}
|
||||
|
||||
float warp_max = warpReduceMax(max_value);
|
||||
if (scale_ub > 0){
|
||||
warp_max = fminf(warp_max, scale_ub);
|
||||
}
|
||||
float scale;
|
||||
scale = warp_max / FP8_E4M3_MAX;
|
||||
// Broadcast scale
|
||||
if (lane_id == 0) {
|
||||
token_scale[0] = scale;
|
||||
}
|
||||
float scale_inv = (scale == 0.f) ? 0.f : 1.0f / scale;
|
||||
|
||||
//
|
||||
// Pass-2: quantize and write back
|
||||
//
|
||||
for (int i = lane_id; i < num_vec_elems; i += WARP_SIZE) {
|
||||
vec_t input_vec;
|
||||
Load(token_input + i * kVecSize, &input_vec);
|
||||
DST_DTYPE output_arr[kVecSize];
|
||||
#pragma unroll
|
||||
for (uint32_t j = 0; j < kVecSize; ++j) {
|
||||
float val = static_cast<float>(input_vec[j]) * scale_inv;
|
||||
val = fmaxf(fminf(val, FP8_E4M3_MAX), -FP8_E4M3_MAX);
|
||||
output_arr[j] = static_cast<DST_DTYPE>(val);
|
||||
}
|
||||
if constexpr (kVecSize == 16) {
|
||||
*(uint4*)(token_output + i * kVecSize) = *(uint4*)output_arr;
|
||||
} else {
|
||||
// Use element-wise copy for vector size 8 to ensure correctness
|
||||
for (int k = 0; k < kVecSize; ++k) {
|
||||
token_output[i * kVecSize + k] = output_arr[k];
|
||||
}
|
||||
}
|
||||
}
|
||||
}
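// Reference sketch (illustration only, not part of this patch) of the
// per-token scaling rule the kernel above applies. FP8_E4M3_MAX is assumed to
// be 448.f, the largest finite e4m3 value; the real kernel finishes with a
// narrowing cast to __nv_fp8_e4m3, which is omitted here.
__host__ __device__ inline float per_token_scale_sketch(const float* token,
                                                        int hidden_size,
                                                        float scale_ub) {
  float max_abs = 0.f;
  for (int i = 0; i < hidden_size; ++i) {
    max_abs = fmaxf(max_abs, fabsf(token[i]));
  }
  if (scale_ub > 0.f) {
    max_abs = fminf(max_abs, scale_ub);
  }
  return max_abs / 448.f;  // one scale per token
}

__host__ __device__ inline float quantize_one_sketch(float v, float scale) {
  float inv = (scale == 0.f) ? 0.f : 1.f / scale;
  // Clamp to the representable e4m3 range before the narrowing cast.
  return fmaxf(fminf(v * inv, 448.f), -448.f);
}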
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// 2. Baseline kernel (1 token / CTA, CUB block reduce)
|
||||
// ---------------------------------------------------------------------------
|
||||
template <typename T, typename DST_DTYPE, int kVecSize = 16>
|
||||
__global__ void per_token_quant_fp8_small_batch_kernel(
|
||||
const T* __restrict__ input,
|
||||
DST_DTYPE* __restrict__ output_q,
|
||||
float* __restrict__ output_s,
|
||||
const float scale_ub,
|
||||
const int64_t hidden_size,
|
||||
const int64_t num_tokens) {
|
||||
const int token_idx = blockIdx.x;
|
||||
if (token_idx >= num_tokens) return;
|
||||
|
||||
const int tid = threadIdx.x;
|
||||
const int block_dim = blockDim.x;
|
||||
|
||||
const T* token_input = input + token_idx * hidden_size;
|
||||
DST_DTYPE* token_output = output_q + token_idx * hidden_size;
|
||||
|
||||
float max_value = 0.0f;
|
||||
|
||||
// Use template parameter for vector size
|
||||
using vec_t = AlignedVector<T, kVecSize>;
|
||||
const int32_t num_vec_elems = hidden_size / kVecSize;
|
||||
|
||||
// Find max using vectorized loads
|
||||
for (int32_t i = tid; i < num_vec_elems; i += block_dim) {
|
||||
vec_t input_vec;
|
||||
Load(token_input + i * kVecSize, &input_vec);
|
||||
|
||||
#pragma unroll
|
||||
for (uint32_t j = 0; j < kVecSize; ++j) {
|
||||
float val = static_cast<float>(input_vec[j]);
|
||||
max_value = fmaxf(max_value, fabsf(val));
|
||||
}
|
||||
}
|
||||
|
||||
max_value = blockReduceMax(max_value);
|
||||
if (scale_ub > 0){
|
||||
max_value = fminf(max_value, scale_ub);
|
||||
}
|
||||
__shared__ float scale;
|
||||
if (tid == 0) {
|
||||
scale = max_value / FP8_E4M3_MAX;
|
||||
output_s[token_idx] = scale;
|
||||
}
|
||||
__syncthreads();
|
||||
|
||||
const float scale_inv = 1.0f / scale;
|
||||
|
||||
// Quantize using vectorized loads
|
||||
for (int32_t i = tid; i < num_vec_elems; i += block_dim) {
|
||||
vec_t input_vec;
|
||||
Load(token_input + i * kVecSize, &input_vec);
|
||||
|
||||
DST_DTYPE output_arr[kVecSize];
|
||||
#pragma unroll
|
||||
for (uint32_t j = 0; j < kVecSize; ++j) {
|
||||
float val = fmaxf(fminf(static_cast<float>(input_vec[j]) * scale_inv, FP8_E4M3_MAX), -FP8_E4M3_MAX);
|
||||
output_arr[j] = static_cast<DST_DTYPE>(val);
|
||||
}
|
||||
|
||||
if constexpr (kVecSize == 16) {
|
||||
*(uint4*)(token_output + i * kVecSize) = *(uint4*)output_arr;
|
||||
} else {
|
||||
// Use element-wise copy for vector size 8 to ensure correctness
|
||||
for (int k = 0; k < kVecSize; ++k) {
|
||||
token_output[i * kVecSize + k] = output_arr[k];
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
namespace fastdeploy {
|
||||
|
||||
template <typename scalar_t, typename fp8_type>
|
||||
@@ -179,39 +331,78 @@ void DynamicPerTokenScaledFp8Quant(paddle::Tensor &out, // [..., d]
|
||||
auto rank = input.dims().size();
|
||||
int const hidden_size = input.dims()[rank - 1];
|
||||
int const num_tokens = input.numel() / hidden_size;
|
||||
cudaStream_t stream = input.stream();
|
||||
|
||||
if (hidden_size % 8 == 0){
|
||||
int device = 0;
|
||||
cudaGetDevice(&device);
|
||||
int sm_count = 0;
|
||||
cudaDeviceGetAttribute(&sm_count, cudaDevAttrMultiProcessorCount, device);
|
||||
const int TOKENS_PER_CTA = 8;
|
||||
const bool use_warp_kernel = (num_tokens >= sm_count * 2 * TOKENS_PER_CTA);
|
||||
const bool use_vec16 = (hidden_size % 16 == 0);
|
||||
DISPATCH_FLOAT_FP6_DTYPE(input.dtype(), scalar_t, {
|
||||
if (use_warp_kernel) {
|
||||
// -------- warp‑local ---------------------------------------------------
|
||||
constexpr int THREADS = TOKENS_PER_CTA * WARP_SIZE; // 256
|
||||
dim3 grid((num_tokens + TOKENS_PER_CTA - 1) / TOKENS_PER_CTA);
|
||||
dim3 block(THREADS);
|
||||
|
||||
if (use_vec16) {
|
||||
per_token_quant_fp8_kernel<scalar_t, __nv_fp8_e4m3, TOKENS_PER_CTA, 16><<<grid, block, 0, stream>>>(
|
||||
reinterpret_cast<const scalar_t*>(input.data<scalar_t>()),
|
||||
reinterpret_cast<__nv_fp8_e4m3*>(out.data<fp8_t>()),
|
||||
reinterpret_cast<float*>(scales.data<float>()),
|
||||
scale_ub,
|
||||
hidden_size,
|
||||
num_tokens);
|
||||
} else {
|
||||
per_token_quant_fp8_kernel<scalar_t, __nv_fp8_e4m3, TOKENS_PER_CTA, 8><<<grid, block, 0, stream>>>(
|
||||
reinterpret_cast<const scalar_t*>(input.data<scalar_t>()),
|
||||
reinterpret_cast<__nv_fp8_e4m3*>(out.data<fp8_t>()),
|
||||
reinterpret_cast<float*>(scales.data<float>()),
|
||||
scale_ub,
|
||||
hidden_size,
|
||||
num_tokens);
|
||||
}
|
||||
} else {
|
||||
// -------- baseline -----------------------------------------------------
|
||||
constexpr int THREADS = 256;
|
||||
dim3 grid(num_tokens);
|
||||
dim3 block(THREADS);
|
||||
|
||||
if (use_vec16) {
|
||||
per_token_quant_fp8_small_batch_kernel<scalar_t, __nv_fp8_e4m3, 16><<<grid, block, 0, stream>>>(
|
||||
reinterpret_cast<const scalar_t*>(input.data<scalar_t>()),
|
||||
reinterpret_cast<__nv_fp8_e4m3*>(out.data<fp8_t>()),
|
||||
reinterpret_cast<float*>(scales.data<float>()),
|
||||
scale_ub,
|
||||
hidden_size,
|
||||
num_tokens);
|
||||
} else {
|
||||
per_token_quant_fp8_small_batch_kernel<scalar_t, __nv_fp8_e4m3, 8><<<grid, block, 0, stream>>>(
|
||||
reinterpret_cast<const scalar_t*>(input.data<scalar_t>()),
|
||||
reinterpret_cast<__nv_fp8_e4m3*>(out.data<fp8_t>()),
|
||||
reinterpret_cast<float*>(scales.data<float>()),
|
||||
scale_ub,
|
||||
hidden_size,
|
||||
num_tokens);
|
||||
}
|
||||
}
|
||||
});
|
||||
return;
|
||||
}
|
||||
|
||||
dim3 const grid(num_tokens);
|
||||
dim3 const block(std::min(hidden_size, 1024));
|
||||
|
||||
cudaStream_t stream = input.stream();
|
||||
DISPATCH_FLOAT_FP6_DTYPE(input.dtype(), scalar_t, {
|
||||
fastdeploy::dynamic_per_token_scaled_fp8_quant_kernel<scalar_t, fp8_t>
|
||||
<<<grid, block, 0, stream>>>(out.data<fp8_t>(), scales.data<float>(),
|
||||
input.data<scalar_t>(), scale_ub,
|
||||
hidden_size);
|
||||
});
|
||||
|
||||
switch (input.dtype()) {
|
||||
case paddle::DataType::FLOAT32: {
|
||||
using scalar_t = float;
|
||||
fastdeploy::dynamic_per_token_scaled_fp8_quant_kernel<scalar_t, fp8_t>
|
||||
<<<grid, block, 0, stream>>>(out.data<fp8_t>(), scales.data<float>(),
|
||||
input.data<scalar_t>(), scale_ub,
|
||||
hidden_size);
|
||||
break;
|
||||
}
|
||||
case paddle::DataType::FLOAT16: {
|
||||
using scalar_t = phi::dtype::float16;
|
||||
fastdeploy::dynamic_per_token_scaled_fp8_quant_kernel<scalar_t, fp8_t>
|
||||
<<<grid, block, 0, stream>>>(out.data<fp8_t>(), scales.data<float>(),
|
||||
input.data<scalar_t>(), scale_ub,
|
||||
hidden_size);
|
||||
break;
|
||||
}
|
||||
case paddle::DataType::BFLOAT16: {
|
||||
using scalar_t = phi::dtype::bfloat16;
|
||||
fastdeploy::dynamic_per_token_scaled_fp8_quant_kernel<scalar_t, fp8_t>
|
||||
<<<grid, block, 0, stream>>>(out.data<fp8_t>(), scales.data<float>(),
|
||||
input.data<scalar_t>(), scale_ub,
|
||||
hidden_size);
|
||||
break;
|
||||
}
|
||||
default:
|
||||
PD_THROW("Only supported attr of input type in [fp32, fp16, bf16].");
|
||||
}
|
||||
}
|
||||
|
||||
PD_BUILD_STATIC_OP(static_scaled_fp8_quant)
|
||||
|
@@ -15,72 +15,31 @@
|
||||
#include "helper.h"
|
||||
|
||||
__global__ void recover_decode_task(bool *stop_flags,
|
||||
int *seq_lens_this_time,
|
||||
int *seq_lens_encoder,
|
||||
int *seq_lens_decoder,
|
||||
int *step_seq_lens_decoder,
|
||||
int *block_tables,
|
||||
bool *is_block_step,
|
||||
const int bsz,
|
||||
const int block_num_per_seq,
|
||||
const int block_size) {
|
||||
int *seq_lens_this_time,
|
||||
int *seq_lens_encoder,
|
||||
int *seq_lens_decoder,
|
||||
int *step_seq_lens_decoder,
|
||||
int *block_tables,
|
||||
bool *is_block_step,
|
||||
const int bsz,
|
||||
const int block_num_per_seq,
|
||||
const int block_size) {
|
||||
int thread_idx = threadIdx.x;
|
||||
if (thread_idx < bsz) {
|
||||
if(is_block_step[thread_idx] == true) {
|
||||
int *block_table_now = block_tables + thread_idx * block_num_per_seq;
|
||||
if (block_table_now[step_seq_lens_decoder[thread_idx] / block_size] != -1) {
|
||||
// can be recovered for decoding
|
||||
is_block_step[thread_idx] = false;
|
||||
seq_lens_this_time[thread_idx]= 1;
|
||||
stop_flags[thread_idx] = false;
|
||||
seq_lens_encoder[thread_idx] = 0;
|
||||
seq_lens_decoder[thread_idx] = step_seq_lens_decoder[thread_idx];
|
||||
|
||||
}
|
||||
}
|
||||
}
|
||||
}
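// Worked example of the recovery check above (numbers are illustrative
// assumptions): with block_size = 64 and a paused request whose
// step_seq_lens_decoder is 130, the kernel inspects
// block_table_now[130 / 64] == block_table_now[2]; only if that slot is still
// mapped (!= -1) does the request leave the block-step state and resume
// decoding with seq_lens_this_time = 1.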
|
||||
|
||||
__global__ void recover_spec_decode_task(bool *stop_flags,
|
||||
int *seq_lens_this_time,
|
||||
int *seq_lens_encoder,
|
||||
int *seq_lens_decoder,
|
||||
int *step_seq_lens_decoder,
|
||||
int *block_tables,
|
||||
bool *is_block_step,
|
||||
int64_t *draft_tokens,
|
||||
const int64_t *step_draft_tokens,
|
||||
const int *step_seq_lens_this_time,
|
||||
const int bsz,
|
||||
const int block_num_per_seq,
|
||||
const int block_size,
|
||||
const int draft_tokens_len,
|
||||
const int num_extra_tokens) {
|
||||
int thread_idx = threadIdx.x;
|
||||
if (thread_idx < bsz) {
|
||||
if(is_block_step[thread_idx] == true) {
|
||||
int *block_table_now = block_tables + thread_idx * block_num_per_seq;
|
||||
int max_possible_block_idx = (step_seq_lens_decoder[thread_idx] + num_extra_tokens) / block_size;
|
||||
max_possible_block_idx = min(max_possible_block_idx, block_num_per_seq);
|
||||
if (block_table_now[max_possible_block_idx] != -1) {
|
||||
// can be recovered for decoding
|
||||
int64_t *draft_tokens_now = draft_tokens + thread_idx * draft_tokens_len;
|
||||
const int64_t *step_draft_tokens_now = step_draft_tokens + thread_idx * draft_tokens_len;
|
||||
is_block_step[thread_idx] = false;
|
||||
seq_lens_this_time[thread_idx] = step_seq_lens_this_time[thread_idx];
|
||||
stop_flags[thread_idx] = false;
|
||||
seq_lens_encoder[thread_idx] = 0;
|
||||
seq_lens_decoder[thread_idx] = step_seq_lens_decoder[thread_idx];
|
||||
for (int i = 0; i < seq_lens_this_time[thread_idx]; i++) {
|
||||
draft_tokens_now[i] = step_draft_tokens_now[i];
|
||||
// can be recovered for decoding
|
||||
is_block_step[thread_idx] = false;
|
||||
seq_lens_this_time[thread_idx]= 1;
|
||||
stop_flags[thread_idx] = false;
|
||||
seq_lens_encoder[thread_idx] = 0;
|
||||
seq_lens_decoder[thread_idx] = step_seq_lens_decoder[thread_idx];
|
||||
}
|
||||
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
void RecoverDecodeTask(const paddle::Tensor &stop_flags,
|
||||
const paddle::Tensor &seq_lens_this_time,
|
||||
const paddle::Tensor &seq_lens_encoder,
|
||||
@@ -88,11 +47,7 @@ void RecoverDecodeTask(const paddle::Tensor &stop_flags,
|
||||
const paddle::Tensor &step_seq_lens_decoder,
|
||||
const paddle::Tensor &block_tables,
|
||||
const paddle::Tensor &is_block_step,
|
||||
const paddle::optional<paddle::Tensor> &draft_tokens,
|
||||
const paddle::optional<paddle::Tensor> &step_draft_tokens,
|
||||
const paddle::optional<paddle::Tensor> &step_seq_lens_this_time,
|
||||
const int block_size,
|
||||
const int max_draft_tokens) {
|
||||
const int block_size) {
|
||||
#ifdef PADDLE_WITH_CUSTOM_DEVICE
|
||||
auto dev_ctx = static_cast<const phi::CustomContext*>(paddle::experimental::DeviceContextPool::Instance().Get(seq_lens_this_time.place()));
|
||||
auto cu_stream = dev_ctx->stream();
|
||||
@@ -101,38 +56,17 @@ void RecoverDecodeTask(const paddle::Tensor &stop_flags,
|
||||
#endif
|
||||
const int bsz = seq_lens_this_time.shape()[0];
|
||||
const int block_num_per_seq = block_tables.shape()[1];
|
||||
if (draft_tokens) {
|
||||
const int draft_tokens_len = draft_tokens.get_ptr()->shape()[1];
|
||||
recover_spec_decode_task<<<1, 1024, 0, cu_stream>>>(
|
||||
const_cast<bool *>(stop_flags.data<bool>()),
|
||||
const_cast<int *>(seq_lens_this_time.data<int>()),
|
||||
const_cast<int *>(seq_lens_encoder.data<int>()),
|
||||
const_cast<int *>(seq_lens_decoder.data<int>()),
|
||||
const_cast<int *>(step_seq_lens_decoder.data<int>()),
|
||||
const_cast<int *>(block_tables.data<int>()),
|
||||
const_cast<bool *>(is_block_step.data<bool>()),
|
||||
const_cast<int64_t *>(draft_tokens.get_ptr()->data<int64_t>()),
|
||||
step_draft_tokens.get_ptr()->data<int64_t>(),
|
||||
step_seq_lens_this_time.get_ptr()->data<int>(),
|
||||
bsz,
|
||||
block_num_per_seq,
|
||||
block_size,
|
||||
draft_tokens_len,
|
||||
max_draft_tokens * 2 + 1);
|
||||
|
||||
} else {
|
||||
recover_decode_task<<<1, 1024, 0, cu_stream>>>(
|
||||
const_cast<bool *>(stop_flags.data<bool>()),
|
||||
const_cast<int *>(seq_lens_this_time.data<int>()),
|
||||
const_cast<int *>(seq_lens_encoder.data<int>()),
|
||||
const_cast<int *>(seq_lens_decoder.data<int>()),
|
||||
const_cast<int *>(step_seq_lens_decoder.data<int>()),
|
||||
const_cast<int *>(block_tables.data<int>()),
|
||||
const_cast<bool *>(is_block_step.data<bool>()),
|
||||
bsz,
|
||||
block_num_per_seq,
|
||||
block_size);
|
||||
}
|
||||
recover_decode_task<<<1, 1024, 0, cu_stream>>>(
|
||||
const_cast<bool *>(stop_flags.data<bool>()),
|
||||
const_cast<int *>(seq_lens_this_time.data<int>()),
|
||||
const_cast<int *>(seq_lens_encoder.data<int>()),
|
||||
const_cast<int *>(seq_lens_decoder.data<int>()),
|
||||
const_cast<int *>(step_seq_lens_decoder.data<int>()),
|
||||
const_cast<int *>(block_tables.data<int>()),
|
||||
const_cast<bool *>(is_block_step.data<bool>()),
|
||||
bsz,
|
||||
block_num_per_seq,
|
||||
block_size);
|
||||
}
|
||||
|
||||
PD_BUILD_STATIC_OP(recover_decode_task)
|
||||
@@ -142,11 +76,8 @@ PD_BUILD_STATIC_OP(recover_decode_task)
|
||||
"seq_lens_decoder",
|
||||
"step_seq_lens_decoder",
|
||||
"block_tables",
|
||||
"is_block_step",
|
||||
paddle::Optional("draft_tokens"),
|
||||
paddle::Optional("step_draft_tokens"),
|
||||
paddle::Optional("step_seq_lens_this_time")})
|
||||
.Attrs({"block_size: int", "max_draft_tokens: int"})
|
||||
"is_block_step"})
|
||||
.Attrs({"block_size: int"})
|
||||
.Outputs({"seq_lens_this_time_out",
|
||||
"seq_lens_encoder_out",
|
||||
"seq_lens_decoder_out",
|
||||
|
@@ -15,48 +15,7 @@
|
||||
#include "helper.h"
|
||||
#include "paddle/extension.h"
|
||||
|
||||
|
||||
#define DISPATCH_BLOCKSIZE(BLOCK_SIZE, ...) \
|
||||
do { \
|
||||
constexpr int BlockSize = BLOCK_SIZE; \
|
||||
__VA_ARGS__; \
|
||||
} while (0)
|
||||
|
||||
#define DISPATCH_TRUNCATE_FIRST_TOKEN(truncate_first_token, TRUNCATE_FIRST_TOKEN, ...) \
|
||||
do { \
|
||||
if (truncate_first_token) { \
|
||||
constexpr bool TRUNCATE_FIRST_TOKEN = true; \
|
||||
__VA_ARGS__; \
|
||||
} else { \
|
||||
constexpr bool TRUNCATE_FIRST_TOKEN = false; \
|
||||
__VA_ARGS__; \
|
||||
} \
|
||||
} while (0)
|
||||
|
||||
#define DISPATCH_KVCACHE_SCHEDULER(kvcache_scheduler_v1, KVCACHE_SCHEDULER_V1, ...) \
|
||||
do { \
|
||||
if (kvcache_scheduler_v1) { \
|
||||
constexpr bool KVCACHE_SCHEDULER_V1 = true; \
|
||||
__VA_ARGS__; \
|
||||
} else { \
|
||||
constexpr bool KVCACHE_SCHEDULER_V1 = false; \
|
||||
__VA_ARGS__; \
|
||||
} \
|
||||
} while (0)
|
||||
|
||||
#define DISPATCH_SPLITWISE_PREFILL(splitwise_prefill, SPLITWISE_PREFILL, ...) \
|
||||
do { \
|
||||
if (splitwise_prefill) { \
|
||||
constexpr bool SPLITWISE_PREFILL = true; \
|
||||
__VA_ARGS__; \
|
||||
} else { \
|
||||
constexpr bool SPLITWISE_PREFILL = false; \
|
||||
__VA_ARGS__; \
|
||||
} \
|
||||
} while (0)
|
||||
|
||||
|
||||
template <int THREADBLOCK_SIZE, bool TRUNCATE_FIRST_TOKEN, bool KVCACHE_SCHEDULER_V1>
|
||||
template <int THREADBLOCK_SIZE, bool TRCUNCATE_FIRST_TOKEN>
|
||||
__global__ void process_splitwise_prefill(
|
||||
int64_t* draft_tokens,
|
||||
int64_t* input_ids,
|
||||
@@ -66,7 +25,6 @@ __global__ void process_splitwise_prefill(
|
||||
int* seq_lens_decoder,
|
||||
int64_t* step_idx,
|
||||
bool* not_need_stop,
|
||||
bool* is_block_step,
|
||||
bool* batch_drop,
|
||||
int64_t* pre_ids,
|
||||
const int64_t* accept_tokens,
|
||||
@@ -100,7 +58,7 @@ __global__ void process_splitwise_prefill(
|
||||
stop_flags[tid] = false;
|
||||
int64_t base_model_first_token = accept_tokens_now[0];
|
||||
int position = seq_len_encoder;
|
||||
if (TRUNCATE_FIRST_TOKEN) {
|
||||
if (TRCUNCATE_FIRST_TOKEN) {
|
||||
input_ids_now[position - 1] = base_model_first_token;
|
||||
seq_lens_this_time[tid] = seq_len_encoder;
|
||||
} else {
|
||||
@@ -126,7 +84,7 @@ __global__ void process_splitwise_prefill(
|
||||
|
||||
|
||||
|
||||
template <int THREADBLOCK_SIZE, bool TRUNCATE_FIRST_TOKEN, bool KVCACHE_SCHEDULER_V1>
|
||||
template <int THREADBLOCK_SIZE, bool TRCUNCATE_FIRST_TOKEN>
|
||||
__global__ void draft_model_preprocess_kernel(
|
||||
int64_t* draft_tokens,
|
||||
int64_t* input_ids,
|
||||
@@ -136,7 +94,6 @@ __global__ void draft_model_preprocess_kernel(
|
||||
int* seq_lens_decoder,
|
||||
int64_t* step_idx,
|
||||
bool* not_need_stop,
|
||||
bool* is_block_step,
|
||||
bool* batch_drop,
|
||||
int64_t* pre_ids,
|
||||
const int64_t* accept_tokens,
|
||||
@@ -177,26 +134,14 @@ __global__ void draft_model_preprocess_kernel(
|
||||
base_model_draft_tokens_now[i] = -1;
|
||||
}
|
||||
|
||||
// 1. process block_step situation
|
||||
// -- In v0 mode, block_step will drop mtp query.
|
||||
// -- In v1 mode, block_step will continue to infer.
|
||||
if constexpr(KVCACHE_SCHEDULER_V1) {
|
||||
if (base_model_stop_flags[tid] && base_model_is_block_step[tid]) {
|
||||
stop_flags[tid] = true;
|
||||
is_block_step[tid] = true;
|
||||
// Need to continue infer
|
||||
}
|
||||
} else {
|
||||
if (base_model_stop_flags[tid] && base_model_is_block_step[tid]) {
|
||||
batch_drop[tid] = true;
|
||||
stop_flags[tid] = true;
|
||||
}
|
||||
if (base_model_stop_flags[tid] && base_model_is_block_step[tid]) {
|
||||
batch_drop[tid] = true;
|
||||
stop_flags[tid] = true;
|
||||
}
|
||||
|
||||
// 2. process normal query, not in any special case.
|
||||
if (!(base_model_stop_flags[tid] || batch_drop[tid])) {
|
||||
not_stop_flag = 1;
|
||||
// prefill generation
|
||||
// 1. first token
|
||||
if (seq_lens_encoder[tid] > 0) {
|
||||
// Can be extended to first few tokens
|
||||
int seq_len_encoder = seq_lens_encoder[tid];
|
||||
@@ -204,20 +149,14 @@ __global__ void draft_model_preprocess_kernel(
|
||||
int64_t base_model_first_token = accept_tokens_now[0];
|
||||
pre_ids_now[0] = base_model_first_token;
|
||||
int position = seq_len_encoder;
|
||||
if (TRUNCATE_FIRST_TOKEN) {
|
||||
if (TRCUNCATE_FIRST_TOKEN) {
|
||||
input_ids_now[position - 1] = base_model_first_token;
|
||||
seq_lens_this_time[tid] = seq_len_encoder;
|
||||
} else {
|
||||
input_ids_now[position] = base_model_first_token;
|
||||
seq_lens_this_time[tid] = seq_len_encoder + 1;
|
||||
}
|
||||
} else { // decode generation
|
||||
if constexpr (KVCACHE_SCHEDULER_V1) {
|
||||
// 3. try to recover mtp infer in V1 mode
|
||||
if (!base_model_is_block_step[tid] && is_block_step[tid]) {
|
||||
is_block_step[tid] = false;
|
||||
}
|
||||
}
|
||||
} else {
|
||||
if (stop_flags[tid]) {
|
||||
stop_flags[tid] = false;
|
||||
// TODO: check
|
||||
@@ -250,8 +189,99 @@ __global__ void draft_model_preprocess_kernel(
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
template <bool TRCUNCATE_FIRST_TOKEN>
|
||||
void DispatchRunner(
|
||||
const cudaStream_t& stream,
|
||||
int64_t* draft_tokens,
|
||||
int64_t* input_ids,
|
||||
bool* stop_flags,
|
||||
int* seq_lens_this_time,
|
||||
int* seq_lens_encoder,
|
||||
int* seq_lens_decoder,
|
||||
int64_t* step_idx,
|
||||
bool* not_need_stop,
|
||||
bool* batch_drop,
|
||||
int64_t* pre_ids,
|
||||
const int64_t* accept_tokens,
|
||||
const int* accept_num,
|
||||
const int* base_model_seq_lens_this_time,
|
||||
const int* base_model_seq_lens_encoder,
|
||||
const int* base_model_seq_lens_decoder,
|
||||
const int64_t* base_model_step_idx,
|
||||
const bool* base_model_stop_flags,
|
||||
const bool* base_model_is_block_step,
|
||||
int64_t* base_model_draft_tokens,
|
||||
const int bsz,
|
||||
const int num_model_step,
|
||||
const int accept_tokens_len,
|
||||
const int draft_tokens_len,
|
||||
const int input_ids_len,
|
||||
const int base_model_draft_tokens_len,
|
||||
const int pre_ids_len,
|
||||
const bool splitwise_prefill) {
|
||||
constexpr int BlockSize = 512;
|
||||
if (splitwise_prefill) {
|
||||
process_splitwise_prefill<BlockSize, TRCUNCATE_FIRST_TOKEN>
|
||||
<<<1, BlockSize, 0, stream>>>(
|
||||
draft_tokens,
|
||||
input_ids,
|
||||
stop_flags,
|
||||
seq_lens_this_time,
|
||||
seq_lens_encoder,
|
||||
seq_lens_decoder,
|
||||
step_idx,
|
||||
not_need_stop,
|
||||
batch_drop,
|
||||
pre_ids,
|
||||
accept_tokens,
|
||||
accept_num,
|
||||
base_model_seq_lens_this_time,
|
||||
base_model_seq_lens_encoder,
|
||||
base_model_seq_lens_decoder,
|
||||
base_model_step_idx,
|
||||
base_model_stop_flags,
|
||||
base_model_is_block_step,
|
||||
base_model_draft_tokens,
|
||||
bsz,
|
||||
num_model_step,
|
||||
accept_tokens_len,
|
||||
draft_tokens_len,
|
||||
input_ids_len,
|
||||
base_model_draft_tokens_len,
|
||||
pre_ids_len);
|
||||
} else {
|
||||
draft_model_preprocess_kernel<BlockSize, TRCUNCATE_FIRST_TOKEN>
|
||||
<<<1, BlockSize, 0, stream>>>(
|
||||
draft_tokens,
|
||||
input_ids,
|
||||
stop_flags,
|
||||
seq_lens_this_time,
|
||||
seq_lens_encoder,
|
||||
seq_lens_decoder,
|
||||
step_idx,
|
||||
not_need_stop,
|
||||
batch_drop,
|
||||
pre_ids,
|
||||
accept_tokens,
|
||||
accept_num,
|
||||
base_model_seq_lens_this_time,
|
||||
base_model_seq_lens_encoder,
|
||||
base_model_seq_lens_decoder,
|
||||
base_model_step_idx,
|
||||
base_model_stop_flags,
|
||||
base_model_is_block_step,
|
||||
base_model_draft_tokens,
|
||||
bsz,
|
||||
num_model_step,
|
||||
accept_tokens_len,
|
||||
draft_tokens_len,
|
||||
input_ids_len,
|
||||
base_model_draft_tokens_len,
|
||||
pre_ids_len);
|
||||
}
|
||||
}
|
||||
|
||||
void DispatchTokenMode(
|
||||
const cudaStream_t &stream,
|
||||
int64_t* draft_tokens,
|
||||
int64_t* input_ids,
|
||||
@@ -261,7 +291,6 @@ void DispatchRunner(
|
||||
int* seq_lens_decoder,
|
||||
int64_t* step_idx,
|
||||
bool* not_need_stop,
|
||||
bool* is_block_step,
|
||||
bool* batch_drop,
|
||||
int64_t* pre_ids,
|
||||
const int64_t* accept_tokens,
|
||||
@@ -281,79 +310,75 @@ void DispatchRunner(
|
||||
const int base_model_draft_tokens_len,
|
||||
const int pre_ids_len,
|
||||
const bool truncate_first_token,
|
||||
const bool splitwise_prefill,
|
||||
const bool kvcache_scheduler_v1) {
|
||||
DISPATCH_BLOCKSIZE(512, {
|
||||
DISPATCH_TRUNCATE_FIRST_TOKEN(truncate_first_token, TRUNCATE_FIRST_TOKEN, {
|
||||
DISPATCH_KVCACHE_SCHEDULER(kvcache_scheduler_v1, KVCACHE_SCHEDULER_V1, {
|
||||
DISPATCH_SPLITWISE_PREFILL(splitwise_prefill, SPLITWISE_PREFILL, {
|
||||
if constexpr (SPLITWISE_PREFILL) {
|
||||
process_splitwise_prefill<BlockSize, TRUNCATE_FIRST_TOKEN, KVCACHE_SCHEDULER_V1>
|
||||
<<<1, BlockSize, 0, stream>>>(
|
||||
draft_tokens,
|
||||
input_ids,
|
||||
stop_flags,
|
||||
seq_lens_this_time,
|
||||
seq_lens_encoder,
|
||||
seq_lens_decoder,
|
||||
step_idx,
|
||||
not_need_stop,
|
||||
is_block_step,
|
||||
batch_drop,
|
||||
pre_ids,
|
||||
accept_tokens,
|
||||
accept_num,
|
||||
base_model_seq_lens_this_time,
|
||||
base_model_seq_lens_encoder,
|
||||
base_model_seq_lens_decoder,
|
||||
base_model_step_idx,
|
||||
base_model_stop_flags,
|
||||
base_model_is_block_step,
|
||||
base_model_draft_tokens,
|
||||
bsz,
|
||||
num_model_step,
|
||||
accept_tokens_len,
|
||||
draft_tokens_len,
|
||||
input_ids_len,
|
||||
base_model_draft_tokens_len,
|
||||
pre_ids_len);
|
||||
} else {
|
||||
draft_model_preprocess_kernel<BlockSize, TRUNCATE_FIRST_TOKEN, KVCACHE_SCHEDULER_V1>
|
||||
<<<1, BlockSize, 0, stream>>>(
|
||||
draft_tokens,
|
||||
input_ids,
|
||||
stop_flags,
|
||||
seq_lens_this_time,
|
||||
seq_lens_encoder,
|
||||
seq_lens_decoder,
|
||||
step_idx,
|
||||
not_need_stop,
|
||||
is_block_step,
|
||||
batch_drop,
|
||||
pre_ids,
|
||||
accept_tokens,
|
||||
accept_num,
|
||||
base_model_seq_lens_this_time,
|
||||
base_model_seq_lens_encoder,
|
||||
base_model_seq_lens_decoder,
|
||||
base_model_step_idx,
|
||||
base_model_stop_flags,
|
||||
base_model_is_block_step,
|
||||
base_model_draft_tokens,
|
||||
bsz,
|
||||
num_model_step,
|
||||
accept_tokens_len,
|
||||
draft_tokens_len,
|
||||
input_ids_len,
|
||||
base_model_draft_tokens_len,
|
||||
pre_ids_len);
|
||||
}
|
||||
});
|
||||
});
|
||||
});
|
||||
});
|
||||
const bool splitwise_prefill) {
|
||||
if (truncate_first_token) {
|
||||
DispatchRunner<true>(
|
||||
stream,
|
||||
draft_tokens,
|
||||
input_ids,
|
||||
stop_flags,
|
||||
seq_lens_this_time,
|
||||
seq_lens_encoder,
|
||||
seq_lens_decoder,
|
||||
step_idx,
|
||||
not_need_stop,
|
||||
batch_drop,
|
||||
pre_ids,
|
||||
accept_tokens,
|
||||
accept_num,
|
||||
base_model_seq_lens_this_time,
|
||||
base_model_seq_lens_encoder,
|
||||
base_model_seq_lens_decoder,
|
||||
base_model_step_idx,
|
||||
base_model_stop_flags,
|
||||
base_model_is_block_step,
|
||||
base_model_draft_tokens,
|
||||
bsz,
|
||||
num_model_step,
|
||||
accept_tokens_len,
|
||||
draft_tokens_len,
|
||||
input_ids_len,
|
||||
base_model_draft_tokens_len,
|
||||
pre_ids_len,
|
||||
splitwise_prefill
|
||||
);
|
||||
} else {
|
||||
DispatchRunner<false>(
|
||||
stream,
|
||||
draft_tokens,
|
||||
input_ids,
|
||||
stop_flags,
|
||||
seq_lens_this_time,
|
||||
seq_lens_encoder,
|
||||
seq_lens_decoder,
|
||||
step_idx,
|
||||
not_need_stop,
|
||||
batch_drop,
|
||||
pre_ids,
|
||||
accept_tokens,
|
||||
accept_num,
|
||||
base_model_seq_lens_this_time,
|
||||
base_model_seq_lens_encoder,
|
||||
base_model_seq_lens_decoder,
|
||||
base_model_step_idx,
|
||||
base_model_stop_flags,
|
||||
base_model_is_block_step,
|
||||
base_model_draft_tokens,
|
||||
bsz,
|
||||
num_model_step,
|
||||
accept_tokens_len,
|
||||
draft_tokens_len,
|
||||
input_ids_len,
|
||||
base_model_draft_tokens_len,
|
||||
pre_ids_len,
|
||||
splitwise_prefill
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
|
||||
|
||||
void DraftModelPreprocess(const paddle::Tensor& draft_tokens,
|
||||
const paddle::Tensor& input_ids,
|
||||
const paddle::Tensor& stop_flags,
|
||||
@@ -362,7 +387,6 @@ void DraftModelPreprocess(const paddle::Tensor& draft_tokens,
|
||||
const paddle::Tensor& seq_lens_decoder,
|
||||
const paddle::Tensor& step_idx,
|
||||
const paddle::Tensor& not_need_stop,
|
||||
const paddle::Tensor& is_block_step,
|
||||
const paddle::Tensor& batch_drop,
|
||||
const paddle::Tensor& pre_ids,
|
||||
const paddle::Tensor& accept_tokens,
|
||||
@@ -376,8 +400,7 @@ void DraftModelPreprocess(const paddle::Tensor& draft_tokens,
|
||||
const paddle::Tensor& base_model_draft_tokens,
|
||||
const int num_model_step,
|
||||
const bool truncate_first_token,
|
||||
const bool splitwise_prefill,
|
||||
const bool kvcache_scheduler_v1) {
|
||||
const bool splitwise_prefill) {
|
||||
int real_bsz = seq_lens_this_time.shape()[0];
|
||||
int accept_tokens_len = accept_tokens.shape()[1];
|
||||
int input_ids_len = input_ids.shape()[1];
|
||||
@@ -389,38 +412,36 @@ void DraftModelPreprocess(const paddle::Tensor& draft_tokens,
|
||||
auto not_need_stop_gpu =
|
||||
not_need_stop.copy_to(seq_lens_this_time.place(), false);
|
||||
|
||||
DispatchRunner(
|
||||
cu_stream,
|
||||
const_cast<int64_t*>(draft_tokens.data<int64_t>()),
|
||||
const_cast<int64_t*>(input_ids.data<int64_t>()),
|
||||
const_cast<bool*>(stop_flags.data<bool>()),
|
||||
const_cast<int*>(seq_lens_this_time.data<int>()),
|
||||
const_cast<int*>(seq_lens_encoder.data<int>()),
|
||||
const_cast<int*>(seq_lens_decoder.data<int>()),
|
||||
const_cast<int64_t*>(step_idx.data<int64_t>()),
|
||||
const_cast<bool*>(not_need_stop_gpu.data<bool>()),
|
||||
const_cast<bool*>(is_block_step.data<bool>()),
|
||||
const_cast<bool*>(batch_drop.data<bool>()),
|
||||
const_cast<int64_t*>(pre_ids.data<int64_t>()),
|
||||
accept_tokens.data<int64_t>(),
|
||||
accept_num.data<int>(),
|
||||
base_model_seq_lens_this_time.data<int>(),
|
||||
base_model_seq_lens_encoder.data<int>(),
|
||||
base_model_seq_lens_decoder.data<int>(),
|
||||
base_model_step_idx.data<int64_t>(),
|
||||
base_model_stop_flags.data<bool>(),
|
||||
base_model_is_block_step.data<bool>(),
|
||||
const_cast<int64_t*>(base_model_draft_tokens.data<int64_t>()),
|
||||
real_bsz,
|
||||
num_model_step,
|
||||
accept_tokens_len,
|
||||
draft_tokens_len,
|
||||
input_ids_len,
|
||||
base_model_draft_tokens_len,
|
||||
pre_ids_len,
|
||||
truncate_first_token,
|
||||
splitwise_prefill,
|
||||
kvcache_scheduler_v1);
|
||||
DispatchTokenMode(
|
||||
cu_stream,
|
||||
const_cast<int64_t*>(draft_tokens.data<int64_t>()),
|
||||
const_cast<int64_t*>(input_ids.data<int64_t>()),
|
||||
const_cast<bool*>(stop_flags.data<bool>()),
|
||||
const_cast<int*>(seq_lens_this_time.data<int>()),
|
||||
const_cast<int*>(seq_lens_encoder.data<int>()),
|
||||
const_cast<int*>(seq_lens_decoder.data<int>()),
|
||||
const_cast<int64_t*>(step_idx.data<int64_t>()),
|
||||
const_cast<bool*>(not_need_stop_gpu.data<bool>()),
|
||||
const_cast<bool*>(batch_drop.data<bool>()),
|
||||
const_cast<int64_t*>(pre_ids.data<int64_t>()),
|
||||
accept_tokens.data<int64_t>(),
|
||||
accept_num.data<int>(),
|
||||
base_model_seq_lens_this_time.data<int>(),
|
||||
base_model_seq_lens_encoder.data<int>(),
|
||||
base_model_seq_lens_decoder.data<int>(),
|
||||
base_model_step_idx.data<int64_t>(),
|
||||
base_model_stop_flags.data<bool>(),
|
||||
base_model_is_block_step.data<bool>(),
|
||||
const_cast<int64_t*>(base_model_draft_tokens.data<int64_t>()),
|
||||
real_bsz,
|
||||
num_model_step,
|
||||
accept_tokens_len,
|
||||
draft_tokens_len,
|
||||
input_ids_len,
|
||||
base_model_draft_tokens_len,
|
||||
pre_ids_len,
|
||||
truncate_first_token,
|
||||
splitwise_prefill);
|
||||
|
||||
auto not_need_stop_cpu =
|
||||
not_need_stop_gpu.copy_to(not_need_stop.place(), false);
|
||||
@@ -438,7 +459,6 @@ PD_BUILD_STATIC_OP(draft_model_preprocess)
|
||||
"seq_lens_decoder",
|
||||
"step_idx",
|
||||
"not_need_stop",
|
||||
"is_block_step",
|
||||
"batch_drop",
|
||||
"pre_ids",
|
||||
"accept_tokens",
|
||||
@@ -460,7 +480,7 @@ PD_BUILD_STATIC_OP(draft_model_preprocess)
|
||||
"not_need_stop_out",
|
||||
"batch_drop_out",
|
||||
"pre_ids_out"})
|
||||
.Attrs({"num_model_step: int", "truncate_first_token: bool", "splitwise_prefill: bool", "kvcache_scheduler_v1: bool"})
|
||||
.Attrs({"num_model_step: int", "truncate_first_token: bool", "splitwise_prefill: bool"})
|
||||
.SetInplaceMap({{"draft_tokens", "draft_tokens_out"},
|
||||
{"input_ids", "input_ids_out"},
|
||||
{"stop_flags", "stop_flags_out"},
|
||||
|
@@ -1,176 +0,0 @@
|
||||
// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#include "helper.h"
|
||||
|
||||
template <int THREADBLOCK_SIZE>
|
||||
__global__ void speculate_schedula_cache(
|
||||
const int64_t *draft_tokens,
|
||||
int *block_tables,
|
||||
bool *stop_flags,
|
||||
int *seq_lens_this_time,
|
||||
int *seq_lens_decoder,
|
||||
int *step_seq_lens_decoder,
|
||||
int64_t *step_draft_tokens,
|
||||
int *step_seq_lens_this_time,
|
||||
int *accept_num,
|
||||
int64_t *accept_tokens,
|
||||
bool *is_block_step,
|
||||
bool *not_need_stop,
|
||||
const int64_t *stop_nums,
|
||||
const int real_bsz,
|
||||
const int max_bsz,
|
||||
const int max_next_step_tokens,
|
||||
const int draft_tokens_len,
|
||||
const int accept_tokens_len,
|
||||
const int block_size,
|
||||
const int block_num_per_seq) {
|
||||
const int bid = threadIdx.x;
|
||||
int stop_flag_now_int = 0;
|
||||
if (bid < real_bsz) {
|
||||
if (!stop_flags[bid]) {
|
||||
const int64_t *draft_tokens_now = draft_tokens + bid * draft_tokens_len;
|
||||
int64_t *step_draft_tokens_now = step_draft_tokens + bid * draft_tokens_len;
|
||||
int *block_table_now = block_tables + bid * block_num_per_seq;
|
||||
int64_t *accept_tokens_now = accept_tokens + bid * accept_tokens_len;
|
||||
const int max_possible_block_idx = (seq_lens_decoder[bid] + max_next_step_tokens) / block_size;
|
||||
if (max_possible_block_idx < block_num_per_seq && block_table_now[max_possible_block_idx] == -1) {
|
||||
is_block_step[bid] = true;
|
||||
step_seq_lens_this_time[bid] = seq_lens_this_time[bid];
|
||||
seq_lens_this_time[bid] = 0;
|
||||
stop_flags[bid] = true;
|
||||
stop_flag_now_int = 1;
|
||||
step_seq_lens_decoder[bid] = seq_lens_decoder[bid];
|
||||
seq_lens_decoder[bid] = 0;
|
||||
accept_num[bid] = 0;
|
||||
for (int i = 0; i < accept_tokens_len; i++) {
|
||||
accept_tokens_now[i] = -1;
|
||||
}
|
||||
for (int i = 0; i < draft_tokens_len; i++) {
|
||||
step_draft_tokens_now[i] = draft_tokens_now[i];
|
||||
}
|
||||
}
|
||||
} else {
|
||||
stop_flag_now_int = 1;
|
||||
}
|
||||
} else if (bid >= real_bsz && bid < max_bsz) {
|
||||
stop_flag_now_int = 1;
|
||||
}
|
||||
__syncthreads();
|
||||
typedef cub::BlockReduce<int64_t, THREADBLOCK_SIZE> BlockReduce;
|
||||
__shared__ typename BlockReduce::TempStorage temp_storage;
|
||||
|
||||
// printf("stop_flag_now_int %d \n", stop_flag_now_int);
|
||||
int64_t stop_sum = BlockReduce(temp_storage).Sum(stop_flag_now_int);
|
||||
|
||||
if (threadIdx.x == 0) {
|
||||
// printf("stop_sum %d \n", stop_sum);
|
||||
not_need_stop[0] = stop_sum < stop_nums[0];
|
||||
}
|
||||
}
|
||||
|
||||
void SpeculateScheduleCache(const paddle::Tensor &draft_tokens,
|
||||
const paddle::Tensor &block_tables,
|
||||
const paddle::Tensor &stop_flags,
|
||||
const paddle::Tensor &seq_lens_this_time,
|
||||
const paddle::Tensor &seq_lens_decoder,
|
||||
const paddle::Tensor &step_seq_lens_decoder,
|
||||
const paddle::Tensor &step_draft_tokens,
|
||||
const paddle::Tensor &step_seq_lens_this_time,
|
||||
const paddle::Tensor &accept_num,
|
||||
const paddle::Tensor &accept_tokens,
|
||||
const paddle::Tensor &is_block_step,
|
||||
const paddle::Tensor ¬_need_stop,
|
||||
const paddle::Tensor &stop_nums,
|
||||
const int block_size,
|
||||
const int max_draft_tokens) {
|
||||
const int real_bsz = seq_lens_this_time.shape()[0];
|
||||
const int max_bsz = stop_flags.shape()[0];
|
||||
const int accept_tokens_len = accept_tokens.shape()[1];
|
||||
const int draft_token_len = draft_tokens.shape()[1];
|
||||
const int block_num_per_seq = block_tables.shape()[1];
|
||||
|
||||
constexpr int BlockSize = 512;
|
||||
const int max_next_step_tokens = 2 * max_draft_tokens + 2;
|
||||
|
||||
auto not_need_stop_gpu = not_need_stop.copy_to(stop_flags.place(), false);
|
||||
speculate_schedula_cache<BlockSize><<<1, BlockSize, 0, seq_lens_this_time.stream()>>>(
|
||||
draft_tokens.data<int64_t>(),
|
||||
const_cast<int *>(block_tables.data<int>()),
|
||||
const_cast<bool *>(stop_flags.data<bool>()),
|
||||
const_cast<int *>(seq_lens_this_time.data<int>()),
|
||||
const_cast<int *>(seq_lens_decoder.data<int>()),
|
||||
const_cast<int *>(step_seq_lens_decoder.data<int>()),
|
||||
const_cast<int64_t *>(step_draft_tokens.data<int64_t>()),
|
||||
const_cast<int *>(step_seq_lens_this_time.data<int>()),
|
||||
const_cast<int *>(accept_num.data<int>()),
|
||||
const_cast<int64_t *>(accept_tokens.data<int64_t>()),
|
||||
const_cast<bool *>(is_block_step.data<bool>()),
|
||||
const_cast<bool *>(not_need_stop_gpu.data<bool>()),
|
||||
stop_nums.data<int64_t>(),
|
||||
real_bsz,
|
||||
max_bsz,
|
||||
max_next_step_tokens,
|
||||
draft_token_len,
|
||||
accept_tokens_len,
|
||||
block_size,
|
||||
block_num_per_seq
|
||||
);
|
||||
|
||||
auto not_need_stop_cpu =
|
||||
not_need_stop_gpu.copy_to(not_need_stop.place(), true);
|
||||
bool *not_need_stop_data = const_cast<bool *>(not_need_stop.data<bool>());
|
||||
not_need_stop_data[0] = not_need_stop_cpu.data<bool>()[0];
|
||||
}
|
||||
|
||||
PD_BUILD_STATIC_OP(speculate_schedule_cache)
|
||||
.Inputs({"draft_tokens",
|
||||
"block_tables",
|
||||
"stop_flags",
|
||||
"seq_lens_this_time",
|
||||
"seq_lens_decoder",
|
||||
"step_seq_lens_decoder",
|
||||
"step_draft_tokens",
|
||||
"step_seq_lens_this_time",
|
||||
"accept_num",
|
||||
"accept_tokens",
|
||||
"is_block_step",
|
||||
"not_need_stop",
|
||||
"stop_nums"})
|
||||
.Attrs({"block_size: int", "max_draft_tokens: int"})
|
||||
.Outputs({"draft_tokens_out",
|
||||
"block_tables_out",
|
||||
"stop_flags_out",
|
||||
"seq_lens_this_time_out",
|
||||
"seq_lens_decoder_out",
|
||||
"step_seq_lens_decoder_out",
|
||||
"step_draft_tokens_out",
|
||||
"step_seq_lens_this_time_out",
|
||||
"accept_num_out",
|
||||
"accept_tokens_out",
|
||||
"is_block_step_out",
|
||||
"not_need_stop_out"})
|
||||
.SetInplaceMap({{"draft_tokens", "draft_tokens_out"},
|
||||
{"block_tables", "block_tables_out"},
|
||||
{"stop_flags", "stop_flags_out"},
|
||||
{"seq_lens_this_time", "seq_lens_this_time_out"},
|
||||
{"seq_lens_decoder", "seq_lens_decoder_out"},
|
||||
{"step_seq_lens_decoder", "step_seq_lens_decoder_out"},
|
||||
{"step_draft_tokens", "step_draft_tokens_out"},
|
||||
{"step_seq_lens_this_time", "step_seq_lens_this_time_out"},
|
||||
{"accept_num", "accept_num_out"},
|
||||
{"accept_tokens", "accept_tokens_out"},
|
||||
{"is_block_step", "is_block_step_out"},
|
||||
{"not_need_stop", "not_need_stop_out"},})
|
||||
.SetKernelFn(PD_KERNEL(SpeculateScheduleCache));
|
@@ -38,20 +38,14 @@ __device__ int64_t topp_sampling_kernel(const int64_t *candidate_ids,
  const int tid = threadIdx.x;

  float sum_scores = 0.0f;
  for (int i = 0; i < candidate_len; i++) {
    sum_scores += candidate_scores[i];
  }
  float tgt_topp = sum_scores < topp ? sum_scores : topp;

  sum_scores = 0.0f;
  float rand_top_p = curand_uniform(dev_curand_states + tid) * tgt_topp;
  float rand_top_p = curand_uniform(dev_curand_states + tid) * topp;
  for (int i = 0; i < candidate_len; i++) {
    sum_scores += candidate_scores[i];
    if (rand_top_p <= sum_scores) {
      return candidate_ids[i];
      return candidate_ids[i];
    }
  }
  return candidate_ids[0];
  return candidate_ids[0];
}
|
||||
|
||||
__global__ void setup_kernel(curandState_t *state, const uint64_t seed,
|
||||
|
@@ -467,9 +467,6 @@ __global__ void KeMatrixTopPBeamTopKFt(
        break;
      }
    }
    if (top_p_value == 1.0 && actual_candidates_lens[token_id] == 0){
      actual_candidates_lens[token_id] = max_cadidate_len;
    }
  }
}
|
||||
|
||||
|
71
custom_ops/gpu_ops/unset_data_ipc.cu
Normal file
@@ -0,0 +1,71 @@
|
||||
// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#include "helper.h"
|
||||
#include "cuda_multiprocess.h"
|
||||
|
||||
#if !defined(_WIN32)
|
||||
#include <errno.h>
|
||||
#include <string.h>
|
||||
#include <fcntl.h>
|
||||
#include <sys/mman.h>
|
||||
#include <sys/stat.h>
|
||||
#endif
|
||||
|
||||
// Optional: only delete/unlink the named shared-memory object (does not
// depend on a previously saved addr/fd).
static inline int sharedMemoryUnlinkByName(const char* name) {
#if defined(WIN32) || defined(_WIN32) || defined(WIN64) || defined(_WIN64)
  // Windows has no shm_unlink semantics; a named object disappears once the
  // last handle is closed. Best effort here: open the mapping and close it
  // immediately to drop one reference.
  HANDLE hMap = OpenFileMappingA(FILE_MAP_ALL_ACCESS, FALSE, name);
  if (hMap) {
    CloseHandle(hMap);
    return 0;
  }
  // Treat "already gone" as success.
  return 0;
#else
  // POSIX: remove the name so it can no longer be opened; existing mappings
  // stay alive until munmap.
  if (shm_unlink(name) != 0) {
    if (errno == ENOENT) return 0;  // not existing counts as success
    return errno;
  }
  return 0;
#endif
}

void UnsetDataIpc(const paddle::Tensor& tmp_input,
                  const std::string& shm_name,
                  bool close_ipc,
                  bool unlink_shm) {
  // 1) Close the IPC mapping imported by the consumer (only when
  //    close_ipc=true and the pointer really came from cudaIpcOpenMemHandle).
  if (close_ipc) {
    void* ptr = const_cast<void*>(tmp_input.data());
    checkCudaErrors(cudaIpcCloseMemHandle(ptr));
  }

  // 2) Unlink the named shared-memory object (this only removes the name; it
  //    does not guarantee tearing down existing mappings).
  if (unlink_shm) {
    int rc = sharedMemoryUnlinkByName(shm_name.c_str());
    if (rc != 0) {
      PD_THROW("Unlink shared memory failed: name=%s, err=%d",
               shm_name.c_str(), rc);
    }
  }
}
|
||||
|
||||
PD_BUILD_STATIC_OP(unset_data_ipc)
|
||||
.Inputs({"tmp_input"})
|
||||
.Attrs({"shm_name: std::string", "close_ipc: bool", "unlink_shm: bool"})
|
||||
.SetKernelFn(PD_KERNEL(UnsetDataIpc));
|
@@ -37,6 +37,52 @@ def load_module_from_path(module_name, path):
|
||||
return module
|
||||
|
||||
|
||||
def update_git_repo():
|
||||
try:
|
||||
print("update third party repo...", flush=True)
|
||||
original_dir = os.getcwd()
|
||||
submodule_dir = os.path.dirname(os.path.abspath(__file__))
|
||||
third_party_path = os.path.join(submodule_dir, "third_party")
|
||||
root_path = Path(third_party_path)
|
||||
|
||||
# check if third_party is empty
|
||||
update_third_party = False
|
||||
for dirpath in root_path.iterdir():
|
||||
if dirpath.is_dir():
|
||||
has_content = any(dirpath.iterdir())
|
||||
if not has_content:
|
||||
update_third_party = True
|
||||
|
||||
if update_third_party:
|
||||
os.chdir(submodule_dir)
|
||||
subprocess.run(
|
||||
"git submodule sync --recursive && git submodule update --init --recursive",
|
||||
shell=True,
|
||||
check=True,
|
||||
text=True,
|
||||
)
|
||||
else:
|
||||
print(
|
||||
"\033[33m[===WARNING===]third_party directory already exists, skip clone and update.\033[0m",
|
||||
flush=True,
|
||||
)
|
||||
|
||||
# apply deep gemm patch
|
||||
deep_gemm_dir = "third_party/DeepGEMM"
|
||||
dst_path = os.path.join(submodule_dir, deep_gemm_dir)
|
||||
patch = "0001-DeepGEMM-95e81b3.patch"
|
||||
patch_source = os.path.join(submodule_dir, patch)
|
||||
patch_destination = os.path.join(dst_path, patch)
|
||||
if not os.path.exists(patch_destination):
|
||||
shutil.copy(patch_source, patch_destination)
|
||||
apply_cmd = ["git", "apply", patch]
|
||||
os.chdir(dst_path)
|
||||
subprocess.run(apply_cmd, check=True)
|
||||
os.chdir(original_dir)
|
||||
except subprocess.CalledProcessError:
|
||||
raise Exception("Git submodule update and apply patch failed. Maybe network connection is poor.")
|
||||
|
||||
|
||||
ROOT_DIR = Path(__file__).parent.parent
|
||||
|
||||
# cannot import envs directly because it depends on fastdeploy,
|
||||
@@ -46,6 +92,8 @@ envs = load_module_from_path("envs", os.path.join(ROOT_DIR, "fastdeploy", "envs.
|
||||
archs = json.loads(envs.FD_BUILDING_ARCS)
|
||||
use_bf16 = envs.FD_CPU_USE_BF16 == "True"
|
||||
|
||||
update_git_repo()
|
||||
|
||||
|
||||
def download_and_extract(url, destination_directory):
|
||||
"""
|
||||
@@ -78,52 +126,6 @@ def download_and_extract(url, destination_directory):
|
||||
print(f"Error extracting file: {e}")
|
||||
|
||||
|
||||
def clone_git_repo(version, repo_url, destination_path):
|
||||
"""
|
||||
Clone git repo to destination path.
|
||||
"""
|
||||
try:
|
||||
subprocess.run(
|
||||
[
|
||||
"git",
|
||||
"clone",
|
||||
"-b",
|
||||
version,
|
||||
"--single-branch",
|
||||
repo_url,
|
||||
destination_path,
|
||||
],
|
||||
check=True,
|
||||
)
|
||||
return True
|
||||
except subprocess.CalledProcessError:
|
||||
return False
|
||||
|
||||
|
||||
def process_git_repo(cur_path, dst_path, commit_id=None, patch=None):
|
||||
"""
|
||||
reset git repo to destination commit and apply patch.
|
||||
"""
|
||||
if commit_id is not None:
|
||||
reset_cmd = ["git", "reset", "--hard", commit_id]
|
||||
if patch is not None:
|
||||
patch_source = os.path.join(cur_path, patch)
|
||||
patch_destination = os.path.join(dst_path, patch)
|
||||
shutil.copy(patch_source, patch_destination)
|
||||
apply_cmd = ["git", "apply", patch]
|
||||
|
||||
try:
|
||||
os.chdir(dst_path)
|
||||
if commit_id is not None:
|
||||
subprocess.run(reset_cmd, check=True)
|
||||
if patch is not None:
|
||||
subprocess.run(apply_cmd, check=True)
|
||||
os.chdir(cur_path)
|
||||
return True
|
||||
except subprocess.CalledProcessError:
|
||||
return False
|
||||
|
||||
|
||||
def get_sm_version(archs):
|
||||
"""
|
||||
Get sm version of paddle.
|
||||
@@ -191,13 +193,6 @@ def find_end_files(directory, end_str):
|
||||
if paddle.is_compiled_with_rocm():
|
||||
# NOTE(@duanyanhui): paddle.is_compiled_with_cuda() returns True when paddle compiled with rocm.
|
||||
# so we need to check if paddle compiled with rocm at first.
|
||||
json_dir = "third_party/nlohmann_json"
|
||||
if not os.path.exists(json_dir) or not os.listdir(json_dir):
|
||||
if not os.path.exists(json_dir):
|
||||
os.makedirs(json_dir)
|
||||
clone_git_repo("v3.11.3", "https://bgithub.xyz/nlohmann/json.git", json_dir)
|
||||
if not os.listdir(json_dir):
|
||||
raise ValueError("Git clone nlohmann_json failed!")
|
||||
sources = [
|
||||
"gpu_ops/save_with_output_msg.cc",
|
||||
"gpu_ops/get_output.cc",
|
||||
@@ -213,6 +208,7 @@ if paddle.is_compiled_with_rocm():
|
||||
"gpu_ops/rebuild_padding.cu",
|
||||
"gpu_ops/step.cu",
|
||||
"gpu_ops/set_data_ipc.cu",
|
||||
"gpu_ops/unset_data_ipc.cu",
|
||||
"gpu_ops/moe/tritonmoe_preprocess.cu",
|
||||
"gpu_ops/step_system_cache.cu",
|
||||
"gpu_ops/get_output_ep.cc",
|
||||
@@ -283,6 +279,7 @@ elif paddle.is_compiled_with_cuda():
|
||||
"gpu_ops/beam_search_softmax.cu",
|
||||
"gpu_ops/rebuild_padding.cu",
|
||||
"gpu_ops/set_data_ipc.cu",
|
||||
"gpu_ops/unset_data_ipc.cu",
|
||||
"gpu_ops/read_data_ipc.cu",
|
||||
"gpu_ops/enforce_generation.cu",
|
||||
"gpu_ops/dequant_int8.cu",
|
||||
@@ -316,28 +313,6 @@ elif paddle.is_compiled_with_cuda():
|
||||
"gpu_ops/ipc_sent_key_value_cache_by_remote_ptr.cu",
|
||||
]
|
||||
|
||||
cutlass_dir = "third_party/cutlass"
|
||||
if not os.path.exists(cutlass_dir) or not os.listdir(cutlass_dir):
|
||||
if not os.path.exists(cutlass_dir):
|
||||
os.makedirs(cutlass_dir)
|
||||
clone_git_repo("v3.8.0", "https://github.com/NVIDIA/cutlass.git", cutlass_dir)
|
||||
if not os.listdir(cutlass_dir):
|
||||
raise ValueError("Git clone cutlass failed!")
|
||||
|
||||
# deep gemm
|
||||
deep_gemm_dir = "third_party/DeepGEMM"
|
||||
if not os.path.exists(deep_gemm_dir) or not os.listdir(deep_gemm_dir):
|
||||
if not os.path.exists(deep_gemm_dir):
|
||||
os.makedirs(deep_gemm_dir)
|
||||
clone_git_repo("main", "https://github.com/deepseek-ai/DeepGEMM.git", deep_gemm_dir)
|
||||
if not os.listdir(deep_gemm_dir):
|
||||
raise ValueError("Git clone DeepGEMM failed!")
|
||||
cur_path = os.path.dirname(os.path.abspath(__file__))
|
||||
dst_path = os.path.join(cur_path, deep_gemm_dir)
|
||||
commit_id = "95e81b3dd6704e279e5f4757c5b94776ac988a8d"
|
||||
patch = "0001-DeepGEMM-95e81b3.patch"
|
||||
process_git_repo(cur_path, dst_path, commit_id, patch)
|
||||
|
||||
dg_third_party_include_dirs = (
|
||||
"third_party/cutlass/include/cute",
|
||||
"third_party/cutlass/include/cutlass",
|
||||
@@ -365,14 +340,6 @@ elif paddle.is_compiled_with_cuda():
|
||||
except Exception as e:
|
||||
raise RuntimeError(f"Failed to copy from {src_dir} to {dst_dir}: {e}")
|
||||
|
||||
json_dir = "third_party/nlohmann_json"
|
||||
if not os.path.exists(json_dir) or not os.listdir(json_dir):
|
||||
if not os.path.exists(json_dir):
|
||||
os.makedirs(json_dir)
|
||||
clone_git_repo("v3.11.3", "https://github.com/nlohmann/json.git", json_dir)
|
||||
if not os.listdir(json_dir):
|
||||
raise ValueError("Git clone nlohmann_json failed!")
|
||||
|
||||
cc_compile_args = []
|
||||
nvcc_compile_args = get_gencode_flags(archs)
|
||||
nvcc_compile_args += ["-DPADDLE_DEV"]
|
||||
@@ -593,13 +560,6 @@ elif paddle.is_compiled_with_custom_device("gcu"):
|
||||
)
|
||||
elif paddle.device.is_compiled_with_custom_device("metax_gpu"):
|
||||
maca_path = os.getenv("MACA_PATH", "/opt/maca")
|
||||
json_dir = "third_party/nlohmann_json"
|
||||
if not os.path.exists(json_dir) or not os.listdir(json_dir):
|
||||
if not os.path.exists(json_dir):
|
||||
os.makedirs(json_dir)
|
||||
clone_git_repo("v3.11.3", "https://gitee.com/learnlov/mirrors_nlohmann_json.git", json_dir)
|
||||
if not os.listdir(json_dir):
|
||||
raise ValueError("Git clone nlohmann_json failed!")
|
||||
sources = [
|
||||
"gpu_ops/update_inputs_v1.cu",
|
||||
"gpu_ops/save_with_output_msg.cc",
|
||||
|
1
custom_ops/third_party/DeepGEMM
vendored
Submodule
1
custom_ops/third_party/cutlass
vendored
Submodule
1
custom_ops/third_party/nlohmann_json
vendored
Submodule
@@ -1,6 +1,6 @@
|
||||
FROM ccr-2vdh3abv-pub.cnc.bj.baidubce.com/paddlepaddle/fastdeploy-cuda-12.6:2.1.0
|
||||
ARG PADDLE_VERSION=3.1.1
|
||||
ARG FD_VERSION=2.1.0
|
||||
FROM ccr-2vdh3abv-pub.cnc.bj.baidubce.com/paddlepaddle/fastdeploy-cuda-12.6:2.2.0
|
||||
ARG PADDLE_VERSION=3.2.0
|
||||
ARG FD_VERSION=2.2.0
|
||||
|
||||
ENV DEBIAN_FRONTEND=noninteractive
|
||||
|
||||
|
BIN
docs/assets/images/favicon.ico
Normal file
After Width: | Height: | Size: 4.2 KiB |
BIN
docs/assets/images/logo.jpg
Normal file
After Width: | Height: | Size: 14 KiB |
@@ -19,22 +19,23 @@ The minimum number of GPUs required to deploy `ERNIE-4.5-0.3B` on the following
|
||||
### 1.2 Install fastdeploy
|
||||
- Installation: For detail, please refer to [Fastdeploy Installation](../get_started/installation/README.md).
|
||||
|
||||
- Model Download,For detail, please refer to [Supported Models](../supported_models.md). **Please note that models with Paddle suffix need to be used for Fastdeploy**:
|
||||
- Model Download: For details, please refer to [Supported Models](../supported_models.md).
|
||||
|
||||
## 2.How to Use
|
||||
### 2.1 Basic: Launching the Service
|
||||
Start the service with the following command:
|
||||
```bash
|
||||
export ENABLE_V1_KVCACHE_SCHEDULER=1
|
||||
python -m fastdeploy.entrypoints.openai.api_server \
|
||||
--model baidu/ERNIE-4.5-0.3B-Paddle \
|
||||
--tensor-parallel-size 1 \
|
||||
--quantization wint4 \
|
||||
--max-model-len 32768 \
|
||||
--max-num-seqs 128
|
||||
--max-num-seqs 128 \
|
||||
--load_choices "default_v1"
|
||||
```
|
||||
- `--quantization`: indicates the quantization strategy used by the model. Different quantization strategies will result in different performance and accuracy of the model. It could be one of `wint8` / `wint4` / `block_wise_fp8`(Hopper is needed).
|
||||
- `--max-model-len`: Indicates the maximum number of tokens supported by the currently deployed service. The larger the value, the longer the context length the model can support, but the more GPU memory is occupied, which may affect the concurrency.
|
||||
- `--load_choices`: indicates the version of the loader. "default_v1" means enabling the v1 version of the loader, which has faster loading speed and less memory usage.
|
||||
|
||||
For more parameter meanings and default settings, see [FastDeploy Parameter Documentation](../parameters.md).
|
||||
|
||||
@@ -42,17 +43,14 @@ For more parameter meanings and default settings, see [FastDeploy Parameter Docu
|
||||
#### 2.2.1 Correctly set parameters that match the application scenario
|
||||
Evaluate average input length, average output length, and maximum context length
|
||||
- Set max-model-len according to the maximum context length. For example, if the average input length is 1000 and the output length is 30000, then it is recommended to set it to 32768
|
||||
- **Enable the service management global block**
|
||||
|
||||
```
|
||||
export ENABLE_V1_KVCACHE_SCHEDULER=1
|
||||
```
|
||||
|
||||
#### 2.2.2 Prefix Caching
|
||||
**Idea:** The core idea of Prefix Caching is to avoid repeated calculations by caching the intermediate calculation results of the input sequence (KV Cache), thereby speeding up the response speed of multiple requests with the same prefix. For details, refer to [prefix-cache](../features/prefix_caching.md)
|
||||
|
||||
**How to enable:**
|
||||
Add the following lines to the startup parameters, where `--enable-prefix-caching` enables prefix caching, and `--swap-space` enables CPU cache in addition to GPU cache. The size is GB and should be adjusted according to the actual situation of the machine.
|
||||
Since version 2.2 (including the develop branch), Prefix Caching has been enabled by default.
|
||||
|
||||
For versions 2.1 and earlier, you need to enable it manually by adding following lines to the startup parameters, where `--enable-prefix-caching` enables prefix caching, and `--swap-space` enables CPU cache in addition to GPU cache. The size is GB and should be adjusted according to the actual situation of the machine. The recommended value is `(total machine memory - model size) * 20%`. If the service fails to start because other programs are occupying memory, try reducing the `--swap-space` value.
|
||||
```
|
||||
--enable-prefix-caching
|
||||
--swap-space 50
|
||||
@@ -61,7 +59,10 @@ Add the following lines to the startup parameters, where `--enable-prefix-cachin
|
||||
#### 2.2.3 Chunked Prefill
|
||||
**Idea:** This strategy is adopted to split the prefill stage request into small-scale sub-chunks, and execute them in batches mixed with the decode request. This can better balance the computation-intensive (Prefill) and memory-intensive (Decode) operations, optimize GPU resource utilization, reduce the computational workload and memory usage of a single Prefill, thereby reducing the peak memory usage and avoiding the problem of insufficient memory. For details, please refer to [Chunked Prefill](../features/chunked_prefill.md)
|
||||
|
||||
**How to enable:** Add the following lines to the startup parameters
|
||||
**How to enable:**
|
||||
Since version 2.2 (including the develop branch), Chunked Prefill has been enabled by default.
|
||||
|
||||
For versions 2.1 and earlier, you need to enable it manually by adding
|
||||
```
|
||||
--enable-chunked-prefill
|
||||
```
|
||||
@@ -79,7 +80,7 @@ Notes:
|
||||
|
||||
- Usually, no additional parameters need to be set, but CUDAGraph will generate some additional memory overhead, which may need to be adjusted in some scenarios with limited memory. For detailed parameter adjustments, please refer to [GraphOptimizationBackend](../features/graph_optimization.md) for related configuration parameter descriptions
|
||||
|
||||
#### 2.2.6 Rejection Sampling
|
||||
#### 2.2.5 Rejection Sampling
|
||||
**Idea:**
|
||||
Rejection sampling is to generate samples from a proposal distribution that is easy to sample, avoiding explicit sorting to increase the sampling speed, which has a significant improvement on small-sized models.
|
||||
|
||||
|
@@ -19,22 +19,23 @@ The minimum number of GPUs required to deploy `ERNIE-4.5-21B-A3B` on the followi
|
||||
### 1.2 Install fastdeploy and prepare the model
|
||||
- Installation: For detail, please refer to [Fastdeploy Installation](../get_started/installation/README.md).
|
||||
|
||||
- Model Download,For detail, please refer to [Supported Models](../supported_models.md). **Please note that models with Paddle suffix need to be used for Fastdeploy**:
|
||||
- Model Download: For details, please refer to [Supported Models](../supported_models.md).
|
||||
|
||||
## 2.How to Use
|
||||
### 2.1 Basic: Launching the Service
|
||||
Start the service with the following command:
|
||||
```bash
|
||||
export ENABLE_V1_KVCACHE_SCHEDULER=1
|
||||
python -m fastdeploy.entrypoints.openai.api_server \
|
||||
--model baidu/ERNIE-4.5-21B-A3B-Paddle \
|
||||
--tensor-parallel-size 1 \
|
||||
--quantization wint4 \
|
||||
--max-model-len 32768 \
|
||||
--max-num-seqs 128
|
||||
--max-num-seqs 128 \
|
||||
--load_choices "default_v1"
|
||||
```
|
||||
- `--quantization`: indicates the quantization strategy used by the model. Different quantization strategies will result in different performance and accuracy of the model. It could be one of `wint8` / `wint4` / `block_wise_fp8`(Hopper is needed).
|
||||
- `--max-model-len`: Indicates the maximum number of tokens supported by the currently deployed service. The larger the value, the longer the context length the model can support, but the more GPU memory is occupied, which may affect the concurrency.
|
||||
- `--load_choices`: indicates the version of the loader. "default_v1" means enabling the v1 version of the loader, which has faster loading speed and less memory usage.
|
||||
|
||||
For more parameter meanings and default settings, see [FastDeploy Parameter Documentation](../parameters.md).
|
||||
|
||||
@@ -42,17 +43,14 @@ For more parameter meanings and default settings, see [FastDeploy Parameter Docu
|
||||
#### 2.2.1 Correctly set parameters that match the application scenario
|
||||
Evaluate average input length, average output length, and maximum context length
|
||||
- Set max-model-len according to the maximum context length. For example, if the average input length is 1000 and the output length is 30000, then it is recommended to set it to 32768
|
||||
- **Enable the service management global block**
|
||||
|
||||
```
|
||||
export ENABLE_V1_KVCACHE_SCHEDULER=1
|
||||
```
|
||||
|
||||
#### 2.2.2 Prefix Caching
|
||||
**Idea:** The core idea of Prefix Caching is to avoid repeated calculations by caching the intermediate calculation results of the input sequence (KV Cache), thereby speeding up the response speed of multiple requests with the same prefix. For details, refer to [prefix-cache](../features/prefix_caching.md)
|
||||
|
||||
**How to enable:**
|
||||
Add the following lines to the startup parameters, where `--enable-prefix-caching` enables prefix caching, and `--swap-space` enables CPU cache in addition to GPU cache. The size is GB and should be adjusted according to the actual situation of the machine. The recommended value is `(total machine memory - model size) * 20%`. If the service fails to start because other programs are occupying memory, try reducing the `--swap-space` value.
|
||||
Since version 2.2 (including the develop branch), Prefix Caching has been enabled by default.
|
||||
|
||||
For versions 2.1 and earlier, you need to enable it manually by adding following lines to the startup parameters, where `--enable-prefix-caching` enables prefix caching, and `--swap-space` enables CPU cache in addition to GPU cache. The size is GB and should be adjusted according to the actual situation of the machine. The recommended value is `(total machine memory - model size) * 20%`. If the service fails to start because other programs are occupying memory, try reducing the `--swap-space` value.
|
||||
```
|
||||
--enable-prefix-caching
|
||||
--swap-space 50
|
||||
@@ -61,7 +59,10 @@ Add the following lines to the startup parameters, where `--enable-prefix-cachin
|
||||
#### 2.2.3 Chunked Prefill
|
||||
**Idea:** This strategy is adopted to split the prefill stage request into small-scale sub-chunks, and execute them in batches mixed with the decode request. This can better balance the computation-intensive (Prefill) and memory-intensive (Decode) operations, optimize GPU resource utilization, reduce the computational workload and memory usage of a single Prefill, thereby reducing the peak memory usage and avoiding the problem of insufficient memory. For details, please refer to [Chunked Prefill](../features/chunked_prefill.md)
|
||||
|
||||
**How to enable:** Add the following lines to the startup parameters
|
||||
**How to enable:**
|
||||
Since version 2.2 (including the develop branch), Chunked Prefill has been enabled by default.
|
||||
|
||||
For versions 2.1 and earlier, you need to enable it manually by adding
|
||||
```
|
||||
--enable-chunked-prefill
|
||||
```
|
||||
@@ -77,7 +78,9 @@ Add the following lines to the startup parameters
|
||||
```
|
||||
Notes:
|
||||
1. MTP currently does not support simultaneous use with Prefix Caching, Chunked Prefill, and CUDAGraph.
|
||||
2. MTP currently does not support service management global blocks, i.e. do not run with `export ENABLE_V1_KVCACHE_SCHEDULER=1`
|
||||
- Use `export FD_DISABLE_CHUNKED_PREFILL=1` to disable Chunked Prefill.
|
||||
- When setting `speculative-config`, Prefix Caching will be automatically disabled.
|
||||
2. MTP currently does not support service management global blocks. When `speculative-config` is set, service management global blocks are disabled automatically.
|
||||
3. MTP currently does not support rejection sampling, i.e. do not run with `export FD_SAMPLING_CLASS=rejection`
|
||||
|
||||
#### 2.2.5 CUDAGraph
|
||||
@@ -110,7 +113,6 @@ export FD_SAMPLING_CLASS=rejection
|
||||
# prefill
|
||||
export CUDA_VISIBLE_DEVICES=0,1,2,3
|
||||
export INFERENCE_MSG_QUEUE_ID=1315
|
||||
export FLAGS_max_partition_size=2048
|
||||
export FD_ATTENTION_BACKEND=FLASH_ATTN
|
||||
export FD_LOG_DIR="prefill_log"
|
||||
|
||||
@@ -130,7 +132,6 @@ python -m fastdeploy.entrypoints.openai.api_server --model baidu/ERNIE-4.5-21B-A
|
||||
# decode
|
||||
export CUDA_VISIBLE_DEVICES=4,5,6,7
|
||||
export INFERENCE_MSG_QUEUE_ID=1215
|
||||
export FLAGS_max_partition_size=2048
|
||||
export FD_LOG_DIR="decode_log"
|
||||
|
||||
quant_type=block_wise_fp8
|
||||
|
@@ -16,22 +16,23 @@ The minimum number of GPUs required to deploy `ERNIE-4.5-300B-A47B` on the follo
|
||||
### 1.2 Install fastdeploy
|
||||
- Installation: For detail, please refer to [Fastdeploy Installation](../get_started/installation/README.md).
|
||||
|
||||
- Model Download,For detail, please refer to [Supported Models](../supported_models.md). **Please note that models with Paddle suffix need to be used for Fastdeploy**:
|
||||
- Model Download: For details, please refer to [Supported Models](../supported_models.md).
|
||||
|
||||
## 2.How to Use
|
||||
### 2.1 Basic: Launching the Service
|
||||
Start the service with the following command:
|
||||
```bash
|
||||
export ENABLE_V1_KVCACHE_SCHEDULER=1
|
||||
python -m fastdeploy.entrypoints.openai.api_server \
|
||||
--model baidu/ERNIE-4.5-300B-A47B-Paddle \
|
||||
--tensor-parallel-size 8 \
|
||||
--quantization wint4 \
|
||||
--max-model-len 32768 \
|
||||
--max-num-seqs 128
|
||||
--max-num-seqs 128 \
|
||||
--load_choices "default_v1"
|
||||
```
|
||||
- `--quantization`: indicates the quantization strategy used by the model. Different quantization strategies will result in different performance and accuracy of the model. It could be one of `wint8` / `wint4` / `block_wise_fp8`(Hopper is needed).
|
||||
- `--max-model-len`: Indicates the maximum number of tokens supported by the currently deployed service. The larger the value, the longer the context length the model can support, but the more GPU memory is occupied, which may affect the concurrency.
|
||||
- `--load_choices`: indicates the version of the loader. "default_v1" means enabling the v1 version of the loader, which has faster loading speed and less memory usage.
|
||||
|
||||
For more parameter meanings and default settings, see [FastDeploy Parameter Documentation](../parameters.md).
|
||||
|
||||
@@ -39,17 +40,14 @@ For more parameter meanings and default settings, see [FastDeploy Parameter Docu
|
||||
#### 2.2.1 Correctly set parameters that match the application scenario
|
||||
Evaluate average input length, average output length, and maximum context length
|
||||
- Set max-model-len according to the maximum context length. For example, if the average input length is 1000 and the output length is 30000, then it is recommended to set it to 32768
|
||||
- **Enable the service management global block**
|
||||
|
||||
```
|
||||
export ENABLE_V1_KVCACHE_SCHEDULER=1
|
||||
```
|
||||
|
||||
#### 2.2.2 Prefix Caching
|
||||
**Idea:** The core idea of Prefix Caching is to avoid repeated calculations by caching the intermediate calculation results of the input sequence (KV Cache), thereby speeding up the response speed of multiple requests with the same prefix. For details, refer to [prefix-cache](../features/prefix_caching.md)
|
||||
|
||||
**How to enable:**
|
||||
Add the following lines to the startup parameters, where `--enable-prefix-caching` enables prefix caching, and `--swap-space` enables CPU cache in addition to GPU cache. The size is GB and should be adjusted according to the actual situation of the machine. The recommended value is `(total machine memory - model size) * 20%`. If the service fails to start because other programs are occupying memory, try reducing the `--swap-space` value.
|
||||
Since version 2.2 (including the develop branch), Prefix Caching has been enabled by default.
|
||||
|
||||
For versions 2.1 and earlier, you need to enable it manually by adding following lines to the startup parameters, where `--enable-prefix-caching` enables prefix caching, and `--swap-space` enables CPU cache in addition to GPU cache. The size is GB and should be adjusted according to the actual situation of the machine. The recommended value is `(total machine memory - model size) * 20%`. If the service fails to start because other programs are occupying memory, try reducing the `--swap-space` value.
|
||||
```
|
||||
--enable-prefix-caching
|
||||
--swap-space 50
|
||||
@@ -58,7 +56,10 @@ Add the following lines to the startup parameters, where `--enable-prefix-cachin
|
||||
#### 2.2.3 Chunked Prefill
|
||||
**Idea:** This strategy is adopted to split the prefill stage request into small-scale sub-chunks, and execute them in batches mixed with the decode request. This can better balance the computation-intensive (Prefill) and memory-intensive (Decode) operations, optimize GPU resource utilization, reduce the computational workload and memory usage of a single Prefill, thereby reducing the peak memory usage and avoiding the problem of insufficient memory. For details, please refer to [Chunked Prefill](../features/chunked_prefill.md)
|
||||
|
||||
**How to enable:** Add the following lines to the startup parameters
|
||||
**How to enable:**
|
||||
Since version 2.2 (including the develop branch), Chunked Prefill has been enabled by default.
|
||||
|
||||
For versions 2.1 and earlier, you need to enable it manually by adding
|
||||
```
|
||||
--enable-chunked-prefill
|
||||
```
|
||||
@@ -74,7 +75,9 @@ Add the following lines to the startup parameters
|
||||
```
|
||||
Notes:
|
||||
1. MTP currently does not support simultaneous use with Prefix Caching, Chunked Prefill, and CUDAGraph.
|
||||
2. MTP currently does not support service management global blocks, i.e. do not run with `export ENABLE_V1_KVCACHE_SCHEDULER=1`
|
||||
- Use `export FD_DISABLE_CHUNKED_PREFILL=1` to disable Chunked Prefill.
|
||||
- When setting `speculative-config`, Prefix Caching will be automatically disabled.
|
||||
2. MTP currently does not support service management global blocks. When `speculative-config` is set, service management global blocks are disabled automatically.
|
||||
3. MTP currently does not support rejection sampling, i.e. do not run with `export FD_SAMPLING_CLASS=rejection`
|
||||
|
||||
#### 2.2.5 W4A8C8 Quantization
|
||||
@@ -87,6 +90,9 @@ Just specify the corresponding model name in the startup command, `baidu/ERNIE-4
|
||||
--model baidu/ERNIE-4.5-300B-A47B-W4A8C8-TP4-Paddle
|
||||
```
|
||||
|
||||
Note:
|
||||
- W4A8C8 quantized models are not supported when loaded via `--load_choices "default_v1"`.
|
||||
|
||||
#### 2.2.6 Rejection Sampling
|
||||
**Idea:**
|
||||
Rejection sampling is to generate samples from a proposal distribution that is easy to sample, avoiding explicit sorting to increase the sampling speed, which has a significant improvement on small-sized models.
|
||||
|
@@ -18,15 +18,10 @@ The minimum number of cards required for deployment on the following hardware is
|
||||
|
||||
Installation process reference documentation [FastDeploy GPU Install](../get_started/installation/nvidia_gpu.md)
|
||||
|
||||
> ⚠️ Precautions:
|
||||
> - FastDeploy only supports models in Paddle format – please ensure to download models with the `-Paddle` file extension.
|
||||
> - The model name will trigger an automatic download. If the model has already been downloaded, you can directly use the absolute path to the model's download location.
|
||||
|
||||
## 2.How to Use
|
||||
### 2.1 Basic: Launching the Service
|
||||
**Example 1:** Deploying a 32K Context Service on a Single RTX 4090 GPU
|
||||
```shell
|
||||
export ENABLE_V1_KVCACHE_SCHEDULER=1
|
||||
python -m fastdeploy.entrypoints.openai.api_server \
|
||||
--model baidu/ERNIE-4.5-VL-28B-A3B-Paddle \
|
||||
--port 8180 \
|
||||
@@ -38,14 +33,11 @@ python -m fastdeploy.entrypoints.openai.api_server \
|
||||
--limit-mm-per-prompt '{"image": 100, "video": 100}' \
|
||||
--reasoning-parser ernie-45-vl \
|
||||
--gpu-memory-utilization 0.9 \
|
||||
--enable-chunked-prefill \
|
||||
--max-num-batched-tokens 384 \
|
||||
--quantization wint4 \
|
||||
--enable-mm
|
||||
--quantization wint4
|
||||
```
|
||||
**Example 2:** Deploying a 128K Context Service on Dual H800 GPUs
|
||||
```shell
|
||||
export ENABLE_V1_KVCACHE_SCHEDULER=1
|
||||
python -m fastdeploy.entrypoints.openai.api_server \
|
||||
--model baidu/ERNIE-4.5-VL-28B-A3B-Paddle \
|
||||
--port 8180 \
|
||||
@@ -57,14 +49,10 @@ python -m fastdeploy.entrypoints.openai.api_server \
|
||||
--limit-mm-per-prompt '{"image": 100, "video": 100}' \
|
||||
--reasoning-parser ernie-45-vl \
|
||||
--gpu-memory-utilization 0.9 \
|
||||
--enable-chunked-prefill \
|
||||
--max-num-batched-tokens 384 \
|
||||
--quantization wint4 \
|
||||
--enable-mm
|
||||
--quantization wint4
|
||||
```
|
||||
|
||||
> ⚠️ For versions 2.1 and above, the new scheduler needs to be enabled via an environment variable `ENABLE_V1_KVCACHE_SCHEDULER=1`. Otherwise, some requests may be truncated before reaching the maximum length or return empty results.
|
||||
|
||||
The example above is a set of configurations that runs stably while delivering relatively good performance. If you have further requirements for precision or performance, please continue reading the content below.
|
||||
### 2.2 Advanced: How to Achieve Better Performance
|
||||
|
||||
@@ -92,8 +80,8 @@ An example is a set of configurations that can run stably while also delivering
|
||||
|
||||
#### 2.2.2 Chunked Prefill
|
||||
- **Parameters:** `--enable-chunked-prefill`
|
||||
- **Description:** Enabling `chunked prefill` can **reduce peak GPU memory usage** and **improve service throughput**.
|
||||
- **Other relevant configurations**:
|
||||
- **Description:** Enabling `chunked prefill` can reduce peak GPU memory usage and improve service throughput. Since version 2.2 it is **enabled by default**; for versions prior to 2.2, you need to enable it manually (see the 2.1 best-practices documentation).
|
||||
- **Relevant configurations**:
|
||||
|
||||
`--max-num-batched-tokens`: Limit the maximum number of tokens per chunk, with a recommended setting of 384.
|
||||
|
||||
@@ -115,12 +103,7 @@ An example is a set of configurations that can run stably while also delivering
|
||||
- **Description:** Rejection sampling involves generating samples from a proposal distribution that is easy to sample from, thereby avoiding explicit sorting and achieving an effect of improving sampling speed, which can enhance inference performance.
|
||||
- **Recommendation:** This is a relatively aggressive optimization strategy that affects the results, and we are still conducting comprehensive validation of its impact. If you have high performance requirements and can accept potential compromises in results, you may consider enabling this strategy.
|
||||
|
||||
> **Attention Hyperparameter:**`FLAGS_max_partition_size=1024`
|
||||
- **Description:** The hyperparameters for the Append Attention (default) backend have been tested on commonly used datasets, and our results show that setting it to 1024 can significantly improve decoding speed, especially in long-text scenarios.
|
||||
- **Recommendation:** In the future, it will be modified to an automatic adjustment mechanism. If you have high performance requirements, you may consider enabling it.
|
||||
|
||||
## 3. FAQ
|
||||
**Note:** Deploying multimodal services requires adding parameters to the configuration `--enable-mm`.
|
||||
|
||||
### 3.1 Out of Memory
|
||||
If the service prompts "Out of Memory" during startup, please try the following solutions:
|
||||
|
@@ -15,15 +15,10 @@ The minimum number of cards required for deployment on the following hardware is
|
||||
|
||||
Installation process reference documentation [FastDeploy GPU Install](../get_started/installation/nvidia_gpu.md)
|
||||
|
||||
> ⚠️ Precautions:
|
||||
> - FastDeploy only supports models in Paddle format – please ensure to download models with the `-Paddle` file extension.
|
||||
> - The model name will trigger an automatic download. If the model has already been downloaded, you can directly use the absolute path to the model's download location.
|
||||
|
||||
## 2.How to Use
|
||||
### 2.1 Basic: Launching the Service
|
||||
**Example 1:** Deploying a 128K context service on 8x H800 GPUs.
|
||||
```shell
|
||||
export ENABLE_V1_KVCACHE_SCHEDULER=1
|
||||
python -m fastdeploy.entrypoints.openai.api_server \
|
||||
--model baidu/ERNIE-4.5-VL-424B-A47B-Paddle \
|
||||
--port 8180 \
|
||||
@@ -34,15 +29,11 @@ python -m fastdeploy.entrypoints.openai.api_server \
|
||||
--max-num-seqs 16 \
|
||||
--limit-mm-per-prompt '{"image": 100, "video": 100}' \
|
||||
--reasoning-parser ernie-45-vl \
|
||||
--gpu-memory-utilization 0.8 \
|
||||
--enable-chunked-prefill \
|
||||
--gpu-memory-utilization 0.85 \
|
||||
--max-num-batched-tokens 384 \
|
||||
--quantization wint4 \
|
||||
--enable-mm
|
||||
--quantization wint4
|
||||
```
|
||||
|
||||
> ⚠️ For versions 2.1 and above, the new scheduler needs to be enabled via an environment variable `ENABLE_V1_KVCACHE_SCHEDULER=1`. Otherwise, some requests may be truncated before reaching the maximum length or return empty results.
|
||||
|
||||
The example above is a set of configurations that runs stably while delivering relatively good performance. If you have further requirements for precision or performance, please continue reading the content below.
|
||||
### 2.2 Advanced: How to Achieve Better Performance
|
||||
|
||||
@@ -70,8 +61,8 @@ An example is a set of configurations that can run stably while also delivering
|
||||
|
||||
#### 2.2.2 Chunked Prefill
|
||||
- **Parameters:** `--enable-chunked-prefill`
|
||||
- **Description:** Enabling `chunked prefill` can **reduce peak GPU memory usage** and **improve service throughput**.
|
||||
- **Other relevant configurations**:
|
||||
- **Description:** Enabling `chunked prefill` can reduce peak GPU memory usage and improve service throughput. Since version 2.2 it is **enabled by default**; for versions prior to 2.2, you need to enable it manually (see the 2.1 best-practices documentation).
|
||||
- **Relevant configurations**:
|
||||
|
||||
`--max-num-batched-tokens`: Limit the maximum number of tokens per chunk, with a recommended setting of 384.
|
||||
|
||||
@@ -93,12 +84,7 @@ An example is a set of configurations that can run stably while also delivering
|
||||
- **Description:** Rejection sampling involves generating samples from a proposal distribution that is easy to sample from, thereby avoiding explicit sorting and achieving an effect of improving sampling speed, which can enhance inference performance.
|
||||
- **Recommendation:** This is a relatively aggressive optimization strategy that affects the results, and we are still conducting comprehensive validation of its impact. If you have high performance requirements and can accept potential compromises in results, you may consider enabling this strategy.
|
||||
|
||||
> **Attention Hyperparameter:**`FLAGS_max_partition_size=1024`
|
||||
- **Description:** The hyperparameters for the Append Attention (default) backend have been tested on commonly used datasets, and our results show that setting it to 1024 can significantly improve decoding speed, especially in long-text scenarios.
|
||||
- **Recommendation:** In the future, it will be modified to an automatic adjustment mechanism. If you have high performance requirements, you may consider enabling it.
|
||||
|
||||
## 3. FAQ
|
||||
**Note:** Deploying multimodal services requires adding parameters to the configuration `--enable-mm`.
|
||||
|
||||
### 3.1 Out of Memory
|
||||
If the service prompts "Out of Memory" during startup, please try the following solutions:
|
||||
|
151
docs/features/data_parallel_service.md
Normal file
@@ -0,0 +1,151 @@
|
||||
# Data Parallelism
|
||||
For MoE models, Expert Parallelism (EP) can be enabled together with Data Parallelism (DP): EP distributes expert workloads across devices, while DP enables parallel request processing.
|
||||
|
||||
## Data Distribution Strategy
|
||||
FastDeploy uses the splitwise scheduler to monitor the load status of each DP node and distribute incoming data accordingly.
|
||||
|
||||
The splitwise scheduler relies on Redis to store DP load status and distribute received data.
|
||||
|
||||
### Expert Parallelism + Hybrid Deployment
|
||||
FastDeploy provides the splitwise scheduler that monitors DP load status and schedules incoming data.
|
||||
The scheduling flow is shown below: a user sends a request to any server's IP and port, the scheduler obtains the load status of each DP via Redis, and the request is dispatched to a less-loaded DP for inference.
|
||||

|
||||
|
||||
#### Offline Inference
|
||||
```python
|
||||
prompts = [
|
||||
"Hello, my name is",
|
||||
"你好,请问今天是星期",
|
||||
"请写6个以数字开头的成语",
|
||||
"写一个300字的小说大纲,内容是李白穿越到现代,最后成为公司文职人员的故事",
|
||||
"我要采访一位科幻作家,创建一个包含5个问题的列表"
|
||||
]
|
||||
|
||||
sampling_params = SamplingParams(temperature=0.8, top_p=0.95, max_tokens=128)
|
||||
|
||||
llm = LLM(
|
||||
model="ERNIE-4_5-300B-A47B-FP8-Paddle",
|
||||
tensor_parallel_size=1,
|
||||
data_parallel_size=8,
|
||||
max_model_len=8192,
|
||||
num_gpu_blocks_override=1024,
|
||||
engine_worker_queue_port="6077,6078,6079,6080,6081,6082,6083,6084",
|
||||
enable_expert_parallel=True,
|
||||
scheduler_name="splitwise",
|
||||
scheduler_host="127.0.0.1",
|
||||
scheduler_topic="test",
|
||||
scheduler_port=6379
|
||||
)
|
||||
outputs = llm.generate(prompts, sampling_params)
|
||||
|
||||
for output in outputs:
|
||||
prompt = output.prompt
|
||||
generated_text = output.outputs.text
|
||||
print("generated_text: ", generated_text)
|
||||
print("\n")
|
||||
```
|
||||
|
||||
#### Online Inference
|
||||
```shell
|
||||
python -m fastdeploy.entrypoints.openai.api_server \
|
||||
--model ERNIE-4_5-300B-A47B-FP8-Paddle \
|
||||
--port 8184 --metrics-port 8185 \
|
||||
--engine-worker-queue-port "6077,6078,6079,6080,6081,6082,6083,6084" \
|
||||
--data-parallel-size 8 --tensor-parallel-size 1 \
|
||||
--enable-expert-parallel \
|
||||
--scheduler-name "splitwise" \
|
||||
--scheduler-host "127.0.0.1" \
|
||||
--scheduler-port 6379 \
|
||||
--scheduler-topic "test" \
|
||||
--scheduler-ttl 9000
|
||||
```
|
||||
|
||||
### User-Managed Scheduling
|
||||
FastDeploy provides multi_api_server, which lets users launch multiple API servers and manually choose which DP serves each request. In this case, users can plug in their own load-balancing logic for scheduling (currently only supported for online inference); a minimal client-side round-robin sketch is shown after the parameter description below.
|
||||
|
||||
#### Online Inference
|
||||

|
||||
|
||||
```shell
|
||||
export FD_ENABLE_MULTI_API_SERVER=1
|
||||
python -m fastdeploy.entrypoints.openai.multi_api_server \
|
||||
--ports "1811,1822,1833,1844,1855,1866,1877,1888" \
|
||||
--num-servers 8 \
|
||||
--metrics-ports "3101,3201,3301,3401,3501,3601,3701,3801" \
|
||||
--args --model ERNIE-4_5-300B-A47B-FP8-Paddle \
|
||||
--engine-worker-queue-port "25611,25621,25631,25641,25651,25661,25671,25681" \
|
||||
--tensor-parallel-size 1 \
|
||||
--data-parallel-size 8 \
|
||||
--max-model-len 12288 \
|
||||
--max-num-seqs 64 \
|
||||
--num-gpu-blocks-override 256 \
|
||||
--enable-expert-parallel
|
||||
```
|
||||
|
||||
### Parameter Description
|
||||
- num-servers: Number of API servers to launch
|
||||
- ports: Ports for API servers
|
||||
- args: Arguments for API servers
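A minimal client-side load-balancing sketch is shown below. It assumes the API servers launched above expose the OpenAI-compatible `/v1/chat/completions` route on the listed ports; the round-robin policy, model name, and payload are illustrative only and can be replaced with any custom scheduling logic.

```python
# Illustrative round-robin dispatcher across the API servers launched above.
import itertools
import requests

PORTS = [1811, 1822, 1833, 1844, 1855, 1866, 1877, 1888]
_next_port = itertools.cycle(PORTS)

def chat(prompt: str) -> str:
    port = next(_next_port)  # pick the next DP's API server in round-robin order
    resp = requests.post(
        f"http://127.0.0.1:{port}/v1/chat/completions",
        json={
            "model": "ERNIE-4_5-300B-A47B-FP8-Paddle",
            "messages": [{"role": "user", "content": prompt}],
        },
        timeout=600,
    )
    resp.raise_for_status()
    return resp.json()["choices"][0]["message"]["content"]

print(chat("Hello, my name is"))
```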
|
||||
|
||||
### Data Parallelism + Disaggregated Deployment
|
||||
Refer to [Disaggregated Deployment](disaggregated.md#multi-machine-disaggregated-deployment)
|
||||
|
||||
#### Online Inference
|
||||
For multi-machine deployment, ensure network cards support RDMA and all cluster nodes are interconnected.
|
||||
|
||||
**Note**:
|
||||
* `KVCACHE_RDMA_NICS` specifies RDMA network cards for the current machine, multiple cards should be separated by commas.
|
||||
* The repository provides an automatic RDMA network card detection script `bash scripts/get_rdma_nics.sh <device>`, where <device> can be `cpu` or `gpu`.
|
||||
|
||||
|
||||
**Prefill Instance**
|
||||
```bash
|
||||
export FD_LOG_DIR="log_prefill"
|
||||
export CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7
|
||||
echo "set RDMA NICS"
|
||||
export $(bash scripts/get_rdma_nics.sh gpu)
|
||||
echo "KVCACHE_RDMA_NICS ${KVCACHE_RDMA_NICS}"
|
||||
python -m fastdeploy.entrypoints.openai.api_server \
|
||||
--model ERNIE-4_5-300B-A47B-FP8-Paddle \
|
||||
--port 8180 --metrics-port 8181 \
|
||||
--engine-worker-queue-port "25611,25621,25631,25641,25651,25661,25671,25681" \
|
||||
--cache-queue-port 8183 \
|
||||
--tensor-parallel-size 1 \
|
||||
--data-parallel-size 4 \
|
||||
--enable-expert-parallel \
|
||||
--cache-transfer-protocol "rdma,ipc" \
|
||||
--rdma-comm-ports "7671,7672,7673,7674,7675,7676,7677,7678" \
|
||||
--pd-comm-port "2334" \
|
||||
--splitwise-role "prefill" \
|
||||
--scheduler-name "splitwise" \
|
||||
--scheduler-host "127.0.0.1" \
|
||||
--scheduler-port 6379 \
|
||||
--scheduler-topic "test" \
|
||||
--scheduler-ttl 9000
|
||||
```
|
||||
|
||||
**Decode Instance**
|
||||
```bash
|
||||
export FD_LOG_DIR="log_decode"
|
||||
export CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7
|
||||
echo "set RDMA NICS"
|
||||
export $(bash scripts/get_rdma_nics.sh gpu)
|
||||
echo "KVCACHE_RDMA_NICS ${KVCACHE_RDMA_NICS}"
|
||||
python -m fastdeploy.entrypoints.openai.api_server \
|
||||
--model ERNIE-4_5-300B-A47B-FP8-Paddle \
|
||||
--port 8184 --metrics-port 8185 \
|
||||
--engine-worker-queue-port "25611,25621,25631,25641,25651,25661,25671,25681" \
|
||||
--cache-queue-port 8187 \
|
||||
--tensor-parallel-size 1 \
|
||||
--data-parallel-size 4 \
|
||||
--enable-expert-parallel \
|
||||
--scheduler-name "splitwise" \
|
||||
--cache-transfer-protocol "rdma,ipc" \
|
||||
--rdma-comm-ports "7671,7672,7673,7674,7675,7676,7677,7678" \
|
||||
--pd-comm-port "2334" \
|
||||
--scheduler-host "127.0.0.1" \
|
||||
--scheduler-port 6379 \
|
||||
--scheduler-ttl 9000 \
|
||||
--scheduler-topic "test" \
|
||||
--splitwise-role "decode"
|
||||
```
|
@@ -72,6 +72,11 @@ Refer to the example code `offline_disaggregated_demo.py` in the `fastdeploy/dem
|
||||
### Multi-machine Disaggregated Deployment
|
||||
|
||||
#### Prerequisite: Redis
|
||||
|
||||
> **⚠️ NOTE**
|
||||
> **Redis requirement: version 6.2.0 or higher**
|
||||
> Versions below this may not support the required commands.
|
||||
>
|
||||
* Installation via `conda`
|
||||
|
||||
```bash
|
||||
@@ -103,14 +108,17 @@ sudo systemctl start redis
|
||||
For multi-machine deployment, confirm that the NIC supports RDMA and that all nodes in the cluster have network connectivity.
|
||||
|
||||
**Note**:
|
||||
* `KVCACHE_RDMA_NICS` specifies the RDMA NICs of the current machine, with multiple NICs separated by commas.
|
||||
* `KVCACHE_RDMA_NICS` specifies RDMA network cards for the current machine, multiple cards should be separated by commas.
|
||||
* The repository provides an automatic RDMA network card detection script `bash scripts/get_rdma_nics.sh <device>`, where <device> can be `cpu` or `gpu`.
|
||||
|
||||
**Prefill Instance**
|
||||
|
||||
```bash
|
||||
export FD_LOG_DIR="log_prefill"
|
||||
export CUDA_VISIBLE_DEVICES=0,1,2,3
|
||||
export KVCACHE_RDMA_NICS="mlx5_2,mlx5_3,mlx5_4,mlx5_5"
|
||||
echo "set RDMA NICS"
|
||||
export $(bash scripts/get_rdma_nics.sh gpu)
|
||||
echo "KVCACHE_RDMA_NICS ${KVCACHE_RDMA_NICS}"
|
||||
python -m fastdeploy.entrypoints.openai.api_server \
|
||||
--model ERNIE-4.5-300B-A47B-BF16 \
|
||||
--port 8180 --metrics-port 8181 \
|
||||
@@ -133,7 +141,9 @@ python -m fastdeploy.entrypoints.openai.api_server \
|
||||
```bash
|
||||
export FD_LOG_DIR="log_decode"
|
||||
export CUDA_VISIBLE_DEVICES=4,5,6,7
|
||||
export KVCACHE_RDMA_NICS="mlx5_2,mlx5_3,mlx5_4,mlx5_5"
|
||||
echo "set RDMA NICS"
|
||||
export $(bash scripts/get_rdma_nics.sh gpu)
|
||||
echo "KVCACHE_RDMA_NICS ${KVCACHE_RDMA_NICS}"
|
||||
python -m fastdeploy.entrypoints.openai.api_server \
|
||||
--model ERNIE-4.5-300B-A47B-BF16 \
|
||||
--port 8184 --metrics-port 8185 \
|
||||
|
@@ -44,7 +44,7 @@ CudaGrpah can be enabled by setting `--use-cudagraph` or `--graph-optimization-c
|
||||
The `graph_opt_level` parameter within `--graph-optimization-config` is used to configure the graph optimization level, with the following available options:
|
||||
+ `0`: Use Dynamic compute graph, default to 0
|
||||
+ `1`: Use static compute graph; during the initialization phase, the Paddle API is used to convert the dynamic graph into a static graph
|
||||
+ `2`: Base on Static compute graph, use the complier(CINN, Compiler Infrastructure for Neural Networks) of Paddle to compile and optimize
|
||||
+ `2`: Based on the static compute graph, use Paddle's compiler (CINN, Compiler Infrastructure for Neural Networks) to compile and optimize
|
||||
|
||||
In general, static graphs have lower Kernel Launch overhead than dynamic graphs, and it is recommended to use static graphs.
|
||||
For adapted models, FastDeploy's CudaGraph *can support both dynamic and static graphs* simultaneously.
|
||||
|
BIN
docs/features/images/no_scheduler_img.png
Normal file
After Width: | Height: | Size: 118 KiB |
BIN
docs/features/images/plas_inference_union.png
Normal file
After Width: | Height: | Size: 34 KiB |
BIN
docs/features/images/plas_training_distill.png
Normal file
After Width: | Height: | Size: 80 KiB |
BIN
docs/features/images/scheduler_img.png
Normal file
After Width: | Height: | Size: 105 KiB |
@@ -1,31 +0,0 @@
|
||||
# moba_sparse_attention
|
||||
|
||||
## Introduction
|
||||
|
||||
We propose Lite MoBA and improve it based on MoBA. Specifically, we still draw on the MoE structure to divide KV into multiple blocks, introduce a learnable MLP layer to adaptively select important blocks. We use Full Attention's 1D Max Pooling Attention Map as Ground Truth. Then, we employ KLDivLoss to distill and train the MLP layer weights. Lite MoBA can be directly applied to post - training, where only the weights of the MLP are learnable and the weights of the original model remain unchanged.
|
||||
|
||||
Compared to NSA or MoBA, our Lite MoBA is more scalable and pluggable, without the need to change traditional attention architectures or interfere with model weight training in the Pre - training and Post - training stages. It only requires a small amount of training on the MLP layer in the final stage of the model to achieve almost lossless accuracy. Since MoBA updates the weights of the entire model, even when Full Attention is automatically invoked for inputs shorter than BlockSize x BlockNum, it still cannot avoid the impact of model updates on the model's effectiveness in text processing. In contrast, our pluggable Lite MoBA can achieve Full Attention that is truly equivalent to that of the original model in short text scenarios.
|
||||
|
||||
Compared with MoBA, in terms of effectiveness, its use of Average Pooling to represent inter - block relationships appears relatively limited and has poor handling of outlier representations. Our ablation experiments also demonstrated that the effectiveness of Average Pooling is inferior to that of the learnable MLP. In terms of training performance, since only the MLP weights need to be updated and the model weights do not need to be updated, a large amount of video memory will be saved during training (which needs to be tested). In terms of inference performance, when the input length is 128K, Block Size = 1024, and Block Num = 16, the performance is improved by 322% compared to Flash Attention 3.
|
||||
|
||||
## Usage
|
||||
|
||||
```bash
|
||||
export FD_ATTENTION_BACKEND="MOBA_ATTN"
|
||||
|
||||
python -m fastdeploy.entrypoints.openai.api_server
|
||||
--model baidu/ERNIE-4.5-300B-A47B-Paddle \
|
||||
--port 8188 \
|
||||
--tensor-parallel-size 4 \
|
||||
--quantization wint4 \
|
||||
--enable-chunked-prefill \
|
||||
--max-num-batched-tokens 8192 \
|
||||
--max-model-len 131072 \
|
||||
--max-num-seqs 32 \
|
||||
--moba-attention-config '{"moba_encoder_top_k_left": 60, "moba_encoder_top_k_right": 80, "moba_decoder_top_k_left": 100, "moba_decoder_top_k_right": 120}'
|
||||
```
|
||||
## Environmental Variables Description
|
||||
|
||||
* Setting `FD_ATTENTION_BACKEND="MOBA_ATTN"` enables MOBA sparse attention.
|
||||
* `moba_encoder_top_k_left=60, moba_encoder_top_k_right=80` indicates that the range of top - k is between 80 and 100 when the encoder is sparse.
|
||||
* `moba_decoder_top_k_left=100, moba_decoder_top_k_right=100` indicates that the range of top - k is between 120 and 140 when the decoder is sparse.
|
219
docs/features/plas_attention.md
Normal file
@@ -0,0 +1,219 @@
|
||||
# PLAS
|
||||
|
||||
## Introduction
|
||||
|
||||
We propose **PLAS (Pluggable Lightweight Attention for Sparsity)**, an improvement over MoBA. Specifically, we adopt an MoE-inspired structure that partitions KV into multiple blocks and introduces a learnable MLP layer to adaptively select important blocks. PLAS can be directly applied during post-training, where only the MLP weights are learnable, and the original model weights remain unchanged.
|
||||
|
||||
Compared to NSA/MoBA, our PLAS offers greater scalability and pluggability. It does not require modifying the traditional attention architecture or interfering with model weight training during pre-training or post-training. Only a small amount of training for the MLP layer is needed at the final stage to achieve nearly lossless accuracy. Since NSA/MoBA updates the entire model weights, it inevitably affects performance on short texts—even though it automatically switches to full attention when the input length is shorter than BlockSize × Top-K. In contrast, our PLAS can achieve truly equivalent full attention to the original model in short-text scenarios.
|
||||
|
||||
In terms of training efficiency, the training cost is very low because only the MLP weight needs to be updated. For inference performance, when the input length is 128K, Block Size = 128, and Top-K = 55, PLAS achieves a **386% speedup** compared to Flash Attention 3.
|
||||
|
||||
## Method
|
||||
|
||||
### Training
|
||||
|
||||
Following the approaches of NSA and MoBA, we partition the KV into multiple blocks. During both the prefill and decode stages, instead of performing attention computation over all KV, we dynamically select the top-K blocks with the highest attention scores for each query token, thereby enabling efficient sparse attention computation.
|
||||
|
||||
<div align="center">
|
||||
<img src="images/plas_training_distill.png" alt="Attention Gate Module" width="60%">
|
||||
</div>
|
||||
|
||||
* **Attention Gate Module**: As illustrated in the figure above, to estimate the importance of each block with low computational overhead, we design a lightweight attention gate module. This module first compresses each K block via a MLP layer to generate a representative low-dimensional representation: $K_c^T=W_{kp}K^T$, where $W_{kp}$ denotes the MLP layer weights. Compared to directly applying mean pooling, the learnable MLP can more effectively capture semantic relationships and importance distributions among different tokens, thereby providing a refined representation of each block. After obtaining the compressed representation $K_c$, the importance of each query token with respect to each block is estimated via: $Softmax(Q\cdot K_c^T)$. To enhance the discriminative ability of the MLP layer, we use the full attention result after 1D max pooling $1DMaxPooling(Softmax(Q \cdot K^T))$ as the ground truth. By minimizing the distribution divergence between the two, the MLP layer is guided to learn feature representations that better align with the true attention distribution.
|
||||
* **Training Data**: Benefiting from the efficiency of both the model architecture and the training paradigm, our approach achieves near-lossless precision with only 1B tokens used for training. The training data is sourced from an internally constructed mixed corpus containing both long and short texts, thereby enhancing the module’s adaptability to varying sequence lengths.
|
||||
* **Other**: We observe that the final decode layer has a significant impact on the overall model accuracy. Therefore, during training, we exclude this layer from sparse attention computation and revert to full attention for this layer during inference. A minimal sketch of the gate-scoring computation described above follows this list.
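The NumPy sketch below restates the gate scoring and its distillation target. The shapes, the single-vector compression of each block, and the omission of causal masking and top-k selection are simplifying assumptions for illustration, not the actual PLAS kernels.

```python
# Sketch of the PLAS attention-gate scoring and its distillation target (illustrative).
import numpy as np

def softmax(x, axis=-1):
    x = x - x.max(axis=axis, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)

T, d, B = 1024, 128, 128                 # tokens, head dim, block size
num_blocks = T // B
Q = np.random.randn(T, d).astype(np.float32)
K = np.random.randn(T, d).astype(np.float32)
W_kp = np.random.randn(B, 1).astype(np.float32)   # learnable MLP compressing each K block

# Compress each K block into one representative d-dimensional vector.
K_blocks = K.reshape(num_blocks, B, d)
K_c = np.einsum("nbd,bo->nod", K_blocks, W_kp).squeeze(1)   # (num_blocks, d)

# Predicted per-block importance for every query token.
pred = softmax(Q @ K_c.T)                                    # (T, num_blocks)

# Ground truth: full attention followed by 1D max pooling within each block.
full = softmax(Q @ K.T)                                      # (T, T)
gt = full.reshape(T, num_blocks, B).max(axis=-1)
gt = gt / gt.sum(axis=-1, keepdims=True)

# KL divergence used to distill the MLP weights (only W_kp would be trained).
kl = np.sum(gt * (np.log(gt + 1e-9) - np.log(pred + 1e-9)), axis=-1).mean()
print("KL(gt || pred) =", float(kl))
```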
|
||||
|
||||
### Inference
|
||||
|
||||
During sparse attention computation, each query token may dynamically select different KV blocks, leading to highly irregular memory access patterns in HBM. Processing each query token separately is feasible, but it results in excessively fine-grained computation that cannot fully utilize the tensor cores, significantly reducing GPU efficiency.
|
||||
|
||||
<div align="center">
|
||||
<img src="images/plas_inference_union.png" alt="Token/Head Union" width="60%">
|
||||
</div>
|
||||
|
||||
To optimize performance in both the prefill and decode stages, we design a special joint strategy to adapt to their respective characteristics:
|
||||
|
||||
* **Prefill Token Union**: We observe that adjacent query tokens tend to select similar key blocks. Leveraging this locality, we take the union of the key blocks selected by every 128 consecutive query tokens and jointly compute sparse attention for these tokens.
|
||||
* **Decode Head Union**: Given the widespread adoption of GQA in modern models, we find that different heads within the same group often select overlapping key blocks. Thus, we combine the key blocks selected by all query heads within a group into a unified set and jointly calculate sparse attention. This way also reduces memory access overhead and further improves decoding efficiency.
|
||||
* **Top-K Selection**: Conventional top-k algorithms based on sorting or direct calls to the cub library introduce significant runtime overhead. To mitigate this, we implemented an approximate top-k selection algorithm based on binary search, which greatly reduces latency while maintaining accuracy; a minimal sketch of the idea is shown below.
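An illustrative Python sketch of the threshold-search idea follows; the real implementation is a CUDA kernel, and the function name and fixed iteration count are assumptions.

```python
# Approximate top-k block selection via binary search over a score threshold (sketch).
import numpy as np

def approx_topk_mask(scores, k, iters=20):
    """Return a boolean mask selecting roughly the k largest entries of `scores`."""
    lo, hi = float(scores.min()), float(scores.max())
    mid = hi
    for _ in range(iters):
        mid = 0.5 * (lo + hi)
        count = int((scores >= mid).sum())
        if count > k:     # too many blocks pass -> raise the threshold
            lo = mid
        elif count < k:   # too few blocks pass -> lower the threshold
            hi = mid
        else:
            break
    return scores >= mid

scores = np.random.rand(1024).astype(np.float32)  # per-block importance scores
mask = approx_topk_mask(scores, k=55)
print("selected blocks:", int(mask.sum()))
```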
|
||||
|
||||
## Evaluation

### Experiments

We evaluated the precision of full attention and sparse attention on LongBenchV2 and Ruler (with context lengths of 32K, 64K, and 128K).

| Model | FullAttention<br>LongBenchV2 | FullAttention<br>Ruler 32K | FullAttention<br>Ruler 64K | FullAttention<br>Ruler 128K | SparseAttention<br>LongBenchV2 | SparseAttention<br>Ruler 32K | SparseAttention<br>Ruler 64K | SparseAttention<br>Ruler 128K |
|:---|:---:|:---:|:---:|:---:|:---:|:---:|:---:|:---:|
| ERNIE-4.5-21B-A3B | 31.48 | 76.74 | 56.40 | 25.48 | 31.45 | 75.93 | 55.38 | 25.05 |
| ERNIE-4.5-300B-A47B | 41.02 | 94.70 | 83.56 | 58.18 | 41.05 | 94.50 | 82.32 | 57.85 |

### Performance

We selected a subset (longbook_sum_eng) from InfiniteBench as the performance evaluation dataset. For inputs exceeding 128K in length, we truncate the sequence by keeping the first 64K and the last 64K tokens.

| Model | Attention | QPS | Decode Speed (token/s) | Time to First Token (s) | Time per Output Token (ms) | End-to-End Latency (s) | Mean Input Length | Mean Output Length |
|:---|:---|:---:|:---:|:---:|:---:|:---:|:---:|:---:|
| ERNIE-4.5-21B-A3B | FullAttention | 0.101 | 13.32 | 8.082 | 87.05 | 61.400 | 113182.32 | 627.76 |
| ERNIE-4.5-21B-A3B | SparseAttention | 0.150 (+48%) | 18.12 (+36%) | 5.466 (-48%) | 66.35 (-31%) | 42.157 (-46%) | 113182.32 | 590.23 |
| ERNIE-4.5-300B-A47B | FullAttention | 0.066 | 5.07 | 13.812 | 206.70 | 164.704 | 113182.32 | 725.97 |
| ERNIE-4.5-300B-A47B | SparseAttention | 0.081 (+23%) | 6.75 (+33%) | 10.584 (-30%) | 154.84 (-34%) | 132.745 (-24%) | 113182.32 | 748.25 |

## Usage

```
export FD_ATTENTION_BACKEND="PLAS_ATTN"

python -m fastdeploy.entrypoints.openai.api_server \
--model baidu/ERNIE-4.5-300B-A47B-Paddle \
--port 8188 \
--tensor-parallel-size 4 \
--quantization wint4 \
--enable-chunked-prefill \
--max-num-batched-tokens 8192 \
--max-model-len 131072 \
--max-num-seqs 32 \
--plas-attention-config '{"plas_encoder_top_k_left": 50, "plas_encoder_top_k_right": 60, "plas_decoder_top_k_left": 100, "plas_decoder_top_k_right": 120}'
```
**Note**: If sparse attention is enabled, the system will automatically load the MLP weights from `plas_attention_mlp_weight.safetensors` in the weight directory. If the MLP weight file is not found, mean pooling will be applied to the key representations.
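
The note above mentions a mean-pooling fallback for the block (key) representations. Below is a minimal sketch of what such pooling looks like, assuming a simple `[seq_len, head_dim]` key layout; the names and shapes are illustrative only and do not reflect the engine's actual tensors.

```python
# Minimal sketch (not FastDeploy internals): the mean-pooling fallback described above.
# Each key block is represented by the mean of its key vectors.
import numpy as np

def block_mean_pool(keys: np.ndarray, block_size: int) -> np.ndarray:
    """keys: [seq_len, head_dim] -> block representations [num_blocks, head_dim]."""
    seq_len, head_dim = keys.shape
    num_blocks = (seq_len + block_size - 1) // block_size
    pad = num_blocks * block_size - seq_len
    if pad:
        keys = np.concatenate([keys, np.zeros((pad, head_dim), keys.dtype)], axis=0)
    blocks = keys.reshape(num_blocks, block_size, head_dim)
    # Average only over the real tokens in the (possibly padded) last block.
    counts = np.full((num_blocks, 1), block_size, dtype=keys.dtype)
    if pad:
        counts[-1, 0] = block_size - pad
    return blocks.sum(axis=1) / counts
```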

**Parameter Description:**

* Setting `FD_ATTENTION_BACKEND="PLAS_ATTN"` enables PLAS sparse attention.
* `plas_encoder_top_k_left=50, plas_encoder_top_k_right=60` means that when the encoder runs sparse attention, the top-k value is bounded to the range [50, 60].
* `plas_decoder_top_k_left=100, plas_decoder_top_k_right=120` means that when the decoder runs sparse attention, the top-k value is bounded to the range [100, 120].
@@ -18,13 +18,6 @@ This project implements an efficient **Speculative Decoding** inference framewor
|
||||
- ⏳ Coming Soon: Support Chunk-prefill
|
||||
- ⏳ Coming Soon: Multi-layer MTP Layer
|
||||
|
||||
- **Decoding with Hybrid MTP and Ngram Methods (Hybrid-MTP-with-Ngram)**
|
||||
|
||||
- Overview: A hybrid method combining MTP and Ngram. First, MTP generates N draft tokens, then Ngram matching is used to supplement additional draft tokens.
|
||||
|
||||
- Use Cases: Suitable when higher draft token coverage is required, leveraging both MTP’s generation capability and the efficiency of Ngram matching.
|
||||
|
||||
|
||||
---
|
||||
|
||||
### Coming Soon
|
||||
@@ -139,13 +132,7 @@ python -m fastdeploy.entrypoints.openai.api_server \
|
||||
--scheduler-password "scheduler_mtp" \
|
||||
--speculative-config '{"method": "mtp", "num_speculative_tokens": 1, "model": "${path_to_mtp_model}"}' &
|
||||
```
|
||||
## Decoding with Hybrid MTP and Ngram Methods
|
||||
|
||||
When starting the service, you only need to modify the `--speculative-config` option.
|
||||
For example, use MTP to generate two draft tokens, and then append three additional draft tokens from Ngram matching:
|
||||
```
|
||||
--speculative-config '{"method": "mtp", "num_model_steps": 2, "mtp_strategy": "with_ngram", "num_speculative_tokens": 5, "model": "'$model_path'/mtp"}'
|
||||
```
|
||||
## 🧠 Using Ngram-Based Decoding
|
||||
This method uses an n-gram sliding window to match the prompt and generated tokens to predict draft tokens. It is particularly effective in scenarios with high input-output overlap (e.g., code completion, document search).
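
As a rough sketch of the matching idea only (the function name, defaults, and hybrid note below are illustrative, not FastDeploy's implementation):

```python
# Illustrative n-gram draft matching: find the most recent earlier occurrence of the
# last n tokens in prompt + generated tokens and propose the tokens that followed it.
from typing import List

def ngram_draft(tokens: List[int], n: int = 3, num_drafts: int = 3) -> List[int]:
    if len(tokens) <= n:
        return []
    pattern = tokens[-n:]
    # Scan backwards so the most recent earlier match wins.
    for start in range(len(tokens) - n - 1, -1, -1):
        if tokens[start:start + n] == pattern:
            follow = tokens[start + n:start + n + num_drafts]
            if follow:
                return follow
    return []

# In the hybrid mode described above, drafts produced by MTP would simply be extended
# with ngram_draft(...) until num_speculative_tokens is reached.
print(ngram_draft([5, 6, 7, 8, 9, 5, 6, 7], n=3, num_drafts=2))  # -> [8, 9]
```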
|
||||
|
||||
|
@@ -330,3 +330,65 @@ ParsedChatCompletionMessage[Info](content='{"addr": "No.1 Century Avenue, Pudong
|
||||
Address: No.1 Century Avenue, Pudong New Area, Shanghai
|
||||
Height: 468
|
||||
```
|
||||
|
||||
### Offline Inference
|
||||
|
||||
Offline inference allows restricting the model's output format by pre-specified constraints. In `FastDeploy`, constraints can be specified through the `GuidedDecodingParams` class in `SamplingParams`. `GuidedDecodingParams` supports the following constraint types, with usage similar to online inference:
|
||||
|
||||
```python
|
||||
json: Optional[Union[str, dict]] = None
|
||||
regex: Optional[str] = None
|
||||
choice: Optional[List[str]] = None
|
||||
grammar: Optional[str] = None
|
||||
json_object: Optional[bool] = None
|
||||
structural_tag: Optional[str] = None
|
||||
```
|
||||
|
||||
The following example demonstrates how to use offline inference to generate structured JSON output:
|
||||
|
||||
```python
|
||||
from fastdeploy import LLM, SamplingParams
|
||||
from fastdeploy.engine.sampling_params import GuidedDecodingParams
|
||||
from pydantic import BaseModel
|
||||
from enum import Enum
|
||||
|
||||
class BookType(str, Enum):
|
||||
romance = "Romance"
|
||||
historical = "Historical"
|
||||
adventure = "Adventure"
|
||||
mystery = "Mystery"
|
||||
dystopian = "Dystopian"
|
||||
|
||||
class BookDescription(BaseModel):
|
||||
author: str
|
||||
title: str
|
||||
genre: BookType
|
||||
|
||||
# Constrained decoding parameters
|
||||
guided_decoding_params = GuidedDecodingParams(json=BookDescription.model_json_schema())
|
||||
|
||||
# Sampling parameters
|
||||
sampling_params = SamplingParams(
|
||||
top_p=0.95,
|
||||
max_tokens=6400,
|
||||
guided_decoding=guided_decoding_params,
|
||||
)
|
||||
|
||||
# Load model
|
||||
llm = LLM(model="ERNIE-4.5-0.3B", tensor_parallel_size=1, max_model_len=8192, guided_decoding_backend="auto")
|
||||
|
||||
outputs = llm.generate(
|
||||
prompts="Generate a JSON describing a literary work, including author, title and book type.",
|
||||
sampling_params=sampling_params,
|
||||
)
|
||||
|
||||
# Output results
|
||||
for output in outputs:
|
||||
print(output.outputs.text)
|
||||
```
|
||||
|
||||
Output:
|
||||
|
||||
```
|
||||
{"author": "George Orwell", "title": "1984", "genre": "Dystopian"}
|
||||
```
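
The other constraint types are used the same way. As a hedged illustration (reusing the classes shown above; the prompt and the option list here are made up), a `choice` constraint restricts the output to one of a fixed set of strings:

```python
from fastdeploy import LLM, SamplingParams
from fastdeploy.engine.sampling_params import GuidedDecodingParams

# Restrict the answer to a fixed set of options.
guided_decoding_params = GuidedDecodingParams(choice=["Positive", "Negative", "Neutral"])

sampling_params = SamplingParams(
    temperature=0.2,
    max_tokens=16,
    guided_decoding=guided_decoding_params,
)

llm = LLM(model="ERNIE-4.5-0.3B", tensor_parallel_size=1, max_model_len=8192, guided_decoding_backend="auto")

outputs = llm.generate(
    prompts="Classify the sentiment of: 'The service was quick and friendly.'",
    sampling_params=sampling_params,
)

for output in outputs:
    print(output.outputs.text)
```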
|
||||
|
@@ -62,7 +62,7 @@ python -m pip install paddlepaddle==3.1.1 -i https://www.paddlepaddle.org.cn/pac
|
||||
python -m pip install paddle-custom-gcu==3.1.1 -i https://www.paddlepaddle.org.cn/packages/stable/gcu/
|
||||
# For source compilation, refer to: https://github.com/PaddlePaddle/PaddleCustomDevice/blob/develop/backends/gcu/README_cn.md
|
||||
```
|
||||
For latest paddle verion on iluvatar. Refer to [PaddlePaddle Installation](https://www.paddlepaddle.org.cn/)
|
||||
For latest paddle version on iluvatar. Refer to [PaddlePaddle Installation](https://www.paddlepaddle.org.cn/)
|
||||
|
||||
6. Install FastDeploy and dependencies
|
||||
```bash
|
||||
|
@@ -25,9 +25,9 @@ Verified platform:
|
||||
```bash
|
||||
mkdir Work
|
||||
cd Work
|
||||
docker pull ccr-2vdh3abv-pub.cnc.bj.baidubce.com/paddlepaddle/fastdeploy-xpu:2.1.0
|
||||
docker pull ccr-2vdh3abv-pub.cnc.bj.baidubce.com/paddlepaddle/fastdeploy-xpu:2.2.0
|
||||
docker run --name fastdeploy-xpu --net=host -itd --privileged -v $PWD:/Work -w /Work \
|
||||
ccr-2vdh3abv-pub.cnc.bj.baidubce.com/paddlepaddle/fastdeploy-xpu:2.1.0 \
|
||||
ccr-2vdh3abv-pub.cnc.bj.baidubce.com/paddlepaddle/fastdeploy-xpu:2.2.0 \
|
||||
/bin/bash
|
||||
docker exec -it fastdeploy-xpu /bin/bash
|
||||
```
|
||||
@@ -37,7 +37,7 @@ docker exec -it fastdeploy-xpu /bin/bash
|
||||
### Install PaddlePaddle
|
||||
|
||||
```bash
|
||||
python -m pip install paddlepaddle-xpu==3.1.1 -i https://www.paddlepaddle.org.cn/packages/stable/xpu-p800/
|
||||
python -m pip install paddlepaddle-xpu==3.2.0 -i https://www.paddlepaddle.org.cn/packages/stable/xpu-p800/
|
||||
```
|
||||
|
||||
Alternatively, you can install the latest version of PaddlePaddle (Not recommended)
|
||||
@@ -49,7 +49,7 @@ python -m pip install --pre paddlepaddle-xpu -i https://www.paddlepaddle.org.cn/
|
||||
### Install FastDeploy (**Do NOT install via PyPI source**)
|
||||
|
||||
```bash
|
||||
python -m pip install fastdeploy-xpu==2.1.0 -i https://www.paddlepaddle.org.cn/packages/stable/fastdeploy-xpu-p800/ --extra-index-url https://mirrors.tuna.tsinghua.edu.cn/pypi/web/simple
|
||||
python -m pip install fastdeploy-xpu==2.2.0 -i https://www.paddlepaddle.org.cn/packages/stable/fastdeploy-xpu-p800/ --extra-index-url https://mirrors.tuna.tsinghua.edu.cn/pypi/web/simple
|
||||
```
|
||||
|
||||
Alternatively, you can install the latest version of FastDeploy (Not recommended)
|
||||
@@ -63,7 +63,7 @@ python -m pip install --pre fastdeploy-xpu -i https://www.paddlepaddle.org.cn/pa
|
||||
### Install PaddlePaddle
|
||||
|
||||
```bash
|
||||
python -m pip install paddlepaddle-xpu==3.1.1 -i https://www.paddlepaddle.org.cn/packages/stable/xpu-p800/
|
||||
python -m pip install paddlepaddle-xpu==3.2.0 -i https://www.paddlepaddle.org.cn/packages/stable/xpu-p800/
|
||||
```
|
||||
|
||||
Alternatively, you can install the latest version of PaddlePaddle (Not recommended)
|
||||
|
@@ -13,14 +13,14 @@ The following installation methods are available when your environment meets the
|
||||
**Notice**: The pre-built image only supports SM80/90 GPUs (e.g. H800/A800). If you are deploying on SM86/89 GPUs (L40/4090/L20), please reinstall ```fastdeploy-gpu``` after you create the container.
|
||||
|
||||
```shell
|
||||
docker pull ccr-2vdh3abv-pub.cnc.bj.baidubce.com/paddlepaddle/fastdeploy-cuda-12.6:2.1.0
|
||||
docker pull ccr-2vdh3abv-pub.cnc.bj.baidubce.com/paddlepaddle/fastdeploy-cuda-12.6:2.2.0
|
||||
```
|
||||
|
||||
## 2. Pre-built Pip Installation
|
||||
|
||||
First install paddlepaddle-gpu. For detailed instructions, refer to [PaddlePaddle Installation](https://www.paddlepaddle.org.cn/en/install/quick?docurl=/documentation/docs/en/develop/install/pip/linux-pip_en.html)
|
||||
```shell
|
||||
python -m pip install paddlepaddle-gpu==3.1.1 -i https://www.paddlepaddle.org.cn/packages/stable/cu126/
|
||||
python -m pip install paddlepaddle-gpu==3.2.0 -i https://www.paddlepaddle.org.cn/packages/stable/cu126/
|
||||
```
|
||||
|
||||
Then install fastdeploy. **Do not install from PyPI**. Use the following methods instead:
|
||||
@@ -58,7 +58,7 @@ docker build -f dockerfiles/Dockerfile.gpu -t fastdeploy:gpu .
|
||||
|
||||
First install paddlepaddle-gpu. For detailed instructions, refer to [PaddlePaddle Installation](https://www.paddlepaddle.org.cn/en/install/quick?docurl=/documentation/docs/en/develop/install/pip/linux-pip_en.html)
|
||||
```shell
|
||||
python -m pip install paddlepaddle-gpu==3.1.1 -i https://www.paddlepaddle.org.cn/packages/stable/cu126/
|
||||
python -m pip install paddlepaddle-gpu==3.2.0 -i https://www.paddlepaddle.org.cn/packages/stable/cu126/
|
||||
```
|
||||
|
||||
Then clone the source code and build:
|
||||
|
99
docs/get_started/quick_start_qwen.md
Normal file
@@ -0,0 +1,99 @@
|
||||
# Deploy QWEN3-0.6b in 10 Minutes
|
||||
|
||||
Before deployment, ensure your environment meets the following requirements:
|
||||
|
||||
- GPU Driver ≥ 535
|
||||
- CUDA ≥ 12.3
|
||||
- cuDNN ≥ 9.5
|
||||
- Linux X86_64
|
||||
- Python ≥ 3.10
|
||||
|
||||
This guide uses the lightweight QWEN3-0.6b model for demonstration, which can be deployed on most hardware configurations. Docker deployment is recommended.
|
||||
|
||||
For more information about how to install FastDeploy, refer to the [installation document](installation/README.md).
|
||||
|
||||
## 1. Launch Service
|
||||
After installing FastDeploy, execute the following command in the terminal to start the service. For the configuration method of the startup command, refer to [Parameter Description](../parameters.md)
|
||||
|
||||
> ⚠️ **Note:**
|
||||
> When using HuggingFace models (torch format), you need to enable `--load_choices "default_v1"`.
|
||||
|
||||
```
|
||||
export ENABLE_V1_KVCACHE_SCHEDULER=1
|
||||
python -m fastdeploy.entrypoints.openai.api_server \
|
||||
--model Qwen/QWEN3-0.6b \
|
||||
--port 8180 \
|
||||
--metrics-port 8181 \
|
||||
--engine-worker-queue-port 8182 \
|
||||
--max-model-len 32768 \
|
||||
--max-num-seqs 32 \
|
||||
--load_choices "default_v1"
|
||||
```
|
||||
|
||||
> 💡 Note: If the path given to ```--model``` does not exist as a subdirectory of the current directory, FastDeploy checks whether AIStudio provides a preset model under the specified name (such as ```Qwen/QWEN3-0.6b```) and, if so, downloads it automatically. The default download path is ```~/xx```. For instructions on configuring automatic model downloads, see [Model Download](../supported_models.md).
|
||||
```--max-model-len``` indicates the maximum number of tokens supported by the currently deployed service.
```--max-num-seqs``` indicates the maximum number of concurrent sequences supported by the currently deployed service.
|
||||
|
||||
**Related Documents**
|
||||
- [Service Deployment](../online_serving/README.md)
|
||||
- [Service Monitoring](../online_serving/metrics.md)
|
||||
|
||||
## 2. Request the Service
|
||||
After starting the service, the following output indicates successful initialization:
|
||||
|
||||
```shell
|
||||
api_server.py[line:91] Launching metrics service at http://0.0.0.0:8181/metrics
|
||||
api_server.py[line:94] Launching chat completion service at http://0.0.0.0:8180/v1/chat/completions
|
||||
api_server.py[line:97] Launching completion service at http://0.0.0.0:8180/v1/completions
|
||||
INFO: Started server process [13909]
|
||||
INFO: Waiting for application startup.
|
||||
INFO: Application startup complete.
|
||||
INFO: Uvicorn running on http://0.0.0.0:8180 (Press CTRL+C to quit)
|
||||
```
|
||||
|
||||
### Health Check
|
||||
|
||||
Verify service status (HTTP 200 indicates success):
|
||||
|
||||
```shell
|
||||
curl -i http://0.0.0.0:8180/health
|
||||
```
|
||||
|
||||
### cURL Request
|
||||
|
||||
Send requests to the service with the following command:
|
||||
|
||||
```shell
|
||||
curl -X POST "http://0.0.0.0:8180/v1/chat/completions" \
|
||||
-H "Content-Type: application/json" \
|
||||
-d '{
|
||||
"messages": [
|
||||
{"role": "user", "content": "Write me a poem about large language model."}
|
||||
],
|
||||
"stream": true
|
||||
}'
|
||||
```
|
||||
|
||||
### Python Client (OpenAI-compatible API)
|
||||
|
||||
FastDeploy's API is OpenAI-compatible. You can also use Python for requests:
|
||||
|
||||
```python
|
||||
import openai
|
||||
host = "0.0.0.0"
|
||||
port = "8180"
|
||||
client = openai.Client(base_url=f"http://{host}:{port}/v1", api_key="null")
|
||||
|
||||
response = client.chat.completions.create(
|
||||
model="null",
|
||||
messages=[
|
||||
{"role": "system", "content": "I'm a helpful AI assistant."},
|
||||
{"role": "user", "content": "Write me a poem about large language model."},
|
||||
],
|
||||
stream=True,
|
||||
)
|
||||
for chunk in response:
|
||||
if chunk.choices[0].delta:
|
||||
print(chunk.choices[0].delta.content, end='')
|
||||
print('\n')
|
||||
```
|
@@ -11,15 +11,39 @@
|
||||
|
||||
## Supported Models
|
||||
|
||||
| Model | Data Type | PD Disaggregation | Chunked Prefill | Prefix Caching | MTP | CUDA Graph | Maximum Context Length |
|
||||
| Model | Data Type |[PD Disaggregation](./features/disaggregated.md) | [Chunked Prefill](./features/chunked_prefill.md) | [Prefix Caching](./features/prefix_caching.md) | [MTP](./features/speculative_decoding.md) | [CUDA Graph](./features/graph_optimization.md) | Maximum Context Length |
|
||||
|:--- | :------- | :---------- | :-------- | :-------- | :----- | :----- | :----- |
|
||||
|ERNIE-4.5-300B-A47B | BF16/WINT4/WINT8/W4A8C8/WINT2/FP8 | ✅| ✅ | ✅|✅| WIP |128K |
|
||||
|ERNIE-4.5-300B-A47B-Base| BF16/WINT4/WINT8 | ✅| ✅ | ✅|❌| WIP | 128K |
|
||||
|ERNIE-4.5-VL-424B-A47B | BF16/WINT4/WINT8 | WIP | ✅ | WIP | ❌ | WIP |128K |
|
||||
|ERNIE-4.5-VL-28B-A3B | BF16/WINT4/WINT8 | ❌ | ✅ | WIP | ❌ | WIP |128K |
|
||||
|ERNIE-4.5-21B-A3B | BF16/WINT4/WINT8/FP8 | ❌ | ✅ | ✅ | ✅ | ✅|128K |
|
||||
|ERNIE-4.5-21B-A3B-Base | BF16/WINT4/WINT8/FP8 | ❌ | ✅ | ✅ | ❌ | ✅|128K |
|
||||
|ERNIE-4.5-0.3B | BF16/WINT8/FP8 | ❌ | ✅ | ✅ | ❌ | ✅| 128K |
|
||||
|ERNIE-4.5-300B-A47B|BF16/WINT4/WINT8/W4A8C8/WINT2/FP8|✅|✅|✅|✅|✅|128K|
|
||||
|ERNIE-4.5-300B-A47B-Base|BF16/WINT4/WINT8|✅|✅|✅|⛔|✅|128K|
|
||||
|ERNIE-4.5-VL-424B-A47B|BF16/WINT4/WINT8|🚧|✅|🚧|⛔|🚧|128K|
|
||||
|ERNIE-4.5-VL-28B-A3B|BF16/WINT4/WINT8|⛔|✅|🚧|⛔|🚧|128K|
|
||||
|ERNIE-4.5-21B-A3B|BF16/WINT4/WINT8/FP8|⛔|✅|✅|✅|✅|128K|
|
||||
|ERNIE-4.5-21B-A3B-Base|BF16/WINT4/WINT8/FP8|⛔|✅|✅|⛔|✅|128K|
|
||||
|ERNIE-4.5-0.3B|BF16/WINT8/FP8|⛔|✅|✅|⛔|✅|128K|
|
||||
|QWEN3-MOE|BF16/WINT4/WINT8/FP8|⛔|✅|✅|🚧|✅|128K|
|
||||
|QWEN3|BF16/WINT8/FP8|⛔|✅|✅|🚧|✅|128K|
|
||||
|QWEN-VL|BF16/WINT8/FP8|⛔|✅|✅|🚧|⛔|128K|
|
||||
|QWEN2|BF16/WINT8/FP8|⛔|✅|✅|🚧|✅|128K|
|
||||
|DEEPSEEK-V3|BF16/WINT4|⛔|✅|🚧|🚧|✅|128K|
|
||||
|DEEPSEEK-R1|BF16/WINT4|⛔|✅|🚧|🚧|✅|128K|
|
||||
|
||||
```
|
||||
✅ Supported 🚧 In Progress ⛔ No Plan
|
||||
```
|
||||
|
||||
## Supported Hardware
|
||||
|
||||
| Model | [NVIDIA GPU](./get_started/installation/nvidia_gpu.md) |[Kunlunxin XPU](./get_started/installation/kunlunxin_xpu.md) | Ascend NPU | [Hygon DCU](./get_started/installation/hygon_dcu.md) | [Iluvatar GPU](./get_started/installation/iluvatar_gpu.md) | [MetaX GPU](./get_started/installation/metax_gpu.md.md) | [Enflame GCU](./get_started/installation/Enflame_gcu.md) |
|
||||
|:------|---------|------------|----------|-------------|-----------|-------------|-------------|
|
||||
| ERNIE4.5-VL-424B-A47B | ✅ | 🚧 | 🚧 | ⛔ | ⛔ | ⛔ | ⛔ |
|
||||
| ERNIE4.5-300B-A47B | ✅ | ✅ | 🚧 | ✅ | ✅ | 🚧 | ✅ |
|
||||
| ERNIE4.5-VL-28B-A3B | ✅ | 🚧 | 🚧 | ⛔ | 🚧 | ⛔ | ⛔ |
|
||||
| ERNIE4.5-21B-A3B | ✅ | ✅ | 🚧 | ✅ | ✅ | ✅ | ✅ |
|
||||
| ERNIE4.5-0.3B | ✅ | ✅ | 🚧 | ✅ | ✅ | ✅ | ✅ |
|
||||
|
||||
```
|
||||
✅ Supported 🚧 In Progress ⛔ No Plan
|
||||
```
|
||||
|
||||
## Documentation
|
||||
|
||||
|
@@ -192,9 +192,6 @@ return_token_ids: Optional[bool] = None
|
||||
prompt_token_ids: Optional[List[int]] = None
|
||||
# Directly passes the token ID list of the prompt, skipping the text encoding step (default None means using text input).
|
||||
|
||||
max_streaming_response_tokens: Optional[int] = None
|
||||
# Maximum number of tokens returned at a time during streaming output (default None means no limit).
|
||||
|
||||
disable_chat_template: Optional[bool] = False
|
||||
# Whether to disable chat template rendering, using raw input directly (default False means template is enabled).
|
||||
|
||||
@@ -369,9 +366,6 @@ return_token_ids: Optional[bool] = None
|
||||
|
||||
prompt_token_ids: Optional[List[int]] = None
|
||||
# Directly passes the token ID list of the prompt, skipping the text encoding step (default None means using text input).
|
||||
|
||||
max_streaming_response_tokens: Optional[int] = None
|
||||
# Maximum number of tokens returned at a time during streaming output (default None means no limit).
|
||||
```
|
||||
|
||||
### Overview of Return Parameters
|
||||
|
71
docs/online_serving/graceful_shutdown_service.md
Normal file
@@ -0,0 +1,71 @@
|
||||
# Graceful Service Node Shutdown Solution
|
||||
|
||||
## 1. Core Objective
|
||||
Achieve graceful shutdown of service nodes, ensuring no in-flight user requests are lost during service termination while maintaining overall cluster availability.
|
||||
|
||||
## 2. Solution Overview
|
||||
This solution combines **Nginx reverse proxy**, **Gunicorn server**, **Uvicorn server**, and **FastAPI** working in collaboration to achieve the objective.
|
||||
|
||||

|
||||
|
||||
## 3. Component Introduction
|
||||
|
||||
### 1. Nginx: Traffic Entry Point and Load Balancer
|
||||
- **Functions**:
|
||||
- Acts as a reverse proxy, receiving all external client requests and distributing them to upstream Gunicorn worker nodes according to load balancing policies.
|
||||
- Actively monitors backend node health status through health check mechanisms.
|
||||
- Enables instantaneous removal of problematic nodes from the service pool through configuration management, achieving traffic switching.
|
||||
|
||||
### 2. Gunicorn: WSGI HTTP Server (Process Manager)
|
||||
- **Functions**:
|
||||
- Serves as the master process, managing multiple Uvicorn worker child processes.
|
||||
- Receives external signals (e.g., `SIGTERM`) and coordinates the graceful shutdown process for all child processes.
|
||||
- Daemonizes worker processes and automatically restarts them upon abnormal termination, ensuring service robustness.
|
||||
|
||||
### 3. Uvicorn: ASGI Server (Worker Process)
|
||||
- **Functions**:
|
||||
- Functions as a Gunicorn-managed worker, actually handling HTTP requests.
|
||||
- Runs the FastAPI application instance, processing specific business logic.
|
||||
- Implements the ASGI protocol, supporting asynchronous request processing for high performance.
|
||||
|
||||
---
|
||||
|
||||
## Advantages
|
||||
|
||||
1. **Nginx**:
|
||||
- Can quickly isolate faulty nodes, ensuring overall service availability.
|
||||
- Allows configuration updates without downtime using `nginx -s reload`, making it transparent to users.
|
||||
|
||||
2. **Gunicorn** (Compared to Uvicorn's native multi-worker mode):
|
||||
- **Mature Process Management**: Built-in comprehensive process spawning, recycling, and management logic, eliminating the need for custom implementation.
|
||||
- **Process Daemon Capability**: The Gunicorn Master automatically forks new Workers if they crash, whereas in Uvicorn's `--workers` mode, any crashed process is not restarted and requires an external daemon.
|
||||
- **Rich Configuration**: Offers numerous parameters for adjusting timeouts, number of workers, restart policies, etc.
|
||||
|
||||
3. **Uvicorn**:
|
||||
- Extremely fast, built on uvloop and httptools.
|
||||
- Natively supports graceful shutdown: upon receiving a shutdown signal, it stops accepting new connections and waits for existing requests to complete before exiting.
|
||||
|
||||
---
|
||||
|
||||
## Graceful Shutdown Procedure
|
||||
|
||||
When a specific node needs to be taken offline, the steps are as follows:
|
||||
|
||||
1. **Nginx Monitors Node Health Status**:
|
||||
- Monitors the node's health status by periodically sending health check requests to it.
|
||||
|
||||
2. **Removal from Load Balancing**:
|
||||
- Modify the Nginx configuration to mark the target node as `down` and reload the Nginx configuration.
|
||||
- Subsequently, all new requests will no longer be sent to the target node.
|
||||
|
||||
3. **Gunicorn Server**:
|
||||
- Monitors for stop signals. Upon receiving a stop signal (e.g., `SIGTERM`), it relays this signal to all Uvicorn child processes.
|
||||
|
||||
4. **Sending the Stop Signal**:
|
||||
- Send a `SIGTERM` signal to the Uvicorn process on the target node, triggering Uvicorn's graceful shutdown process.
|
||||
|
||||
5. **Waiting for Request Processing**:
|
||||
- Wait for a period slightly longer than `timeout_graceful_shutdown` before forcefully terminating the service, allowing the node sufficient time to complete processing all received requests.
|
||||
|
||||
6. **Shutdown Completion**:
|
||||
- The node has now processed all remaining requests and exited safely.
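
The procedure above maps onto a fairly standard Gunicorn configuration. The sketch below is illustrative only; the file name, module path, and values are assumptions, not settings shipped with this project.

```python
# gunicorn.conf.py -- illustrative values; tune per deployment.
bind = "0.0.0.0:8000"                           # address Nginx proxies to
workers = 4                                      # number of Uvicorn worker processes
worker_class = "uvicorn.workers.UvicornWorker"   # run the FastAPI app under Uvicorn
graceful_timeout = 120  # seconds workers get after SIGTERM to drain in-flight requests
timeout = 300           # hard-kill threshold for unresponsive workers
keepalive = 5           # seconds to hold idle keep-alive connections

# Launch (module path is an assumption): gunicorn your_app.main:app -c gunicorn.conf.py
# A graceful node shutdown then amounts to: mark the node `down` in Nginx, reload Nginx,
# and send SIGTERM to the Gunicorn master process.
```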
|
BIN
docs/online_serving/images/graceful_shutdown.png
Normal file
After Width: | Height: | Size: 180 KiB |
@@ -20,7 +20,12 @@ After FastDeploy is launched, it supports continuous monitoring of the FastDeplo
|
||||
| `fastdeploy:gpu_cache_usage_perc` | Gauge | GPU KV-cache usage rate | Percentage |
|
||||
| `fastdeploy:request_params_max_tokens` | Histogram | Distribution of max_tokens for requests | Count |
|
||||
| `fastdeploy:request_success_total` | Counter | Number of successfully processed requests | Count |
|
||||
|
||||
| `fastdeploy:cache_config_info` | Gauge | Information of the engine's CacheConfig | Count |
|
||||
| `fastdeploy:available_batch_size` | Gauge | Number of requests that can still be inserted during the Decode phase| Count |
|
||||
| `fastdeploy:hit_req_rate` | Gauge | Request-level prefix cache hit rate | Percentage |
|
||||
| `fastdeploy:hit_token_rate` | Gauge | Token-level prefix cache hit rate | Percentage |
|
||||
| `fastdeploy:cpu_hit_token_rate` | Gauge | Token-level CPU prefix cache hit rate | Percentage |
|
||||
| `fastdeploy:gpu_hit_token_rate` | Gauge | Token-level GPU prefix cache hit rate | Percentage |
|
||||
## Accessing Metrics
|
||||
|
||||
- Access URL: `http://localhost:8000/metrics`
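
For a quick manual check outside of a full Prometheus setup, the endpoint returns plain text and can be fetched and filtered with a few lines of Python (the URL below assumes the default metrics address above):

```python
# Quick, illustrative scrape of the FastDeploy metrics endpoint.
import requests

resp = requests.get("http://localhost:8000/metrics", timeout=5)
resp.raise_for_status()
for line in resp.text.splitlines():
    # Print only FastDeploy counters/gauges, skipping HELP/TYPE comment lines.
    if line.startswith("fastdeploy:"):
        print(line)
```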
|
||||
|
@@ -37,7 +37,7 @@ When using FastDeploy to deploy models (including offline inference and service
|
||||
| ```reasoning_parser``` | `str` | Specify the reasoning parser to extract reasoning content from model output |
|
||||
| ```use_cudagraph``` | `bool` | Whether to use CUDA Graph, default False. It is recommended to read [graph_optimization.md](./features/graph_optimization.md) carefully before enabling it. Custom all-reduce needs to be enabled at the same time in multi-GPU scenarios. |
| ```graph_optimization_config``` | `dict[str]` | Parameters related to computation graph optimization; the default value is '{"use_cudagraph":false, "graph_opt_level":0, "cudagraph_capture_sizes": null }'. For a detailed description, see [graph_optimization.md](./features/graph_optimization.md) |
|
||||
| ```disable_custom_all_reduce``` | `bool` | Disable Custom all-reduce, default: False |
|
||||
| ```enable_custom_all_reduce``` | `bool` | Enable Custom all-reduce, default: False |
|
||||
| ```splitwise_role``` | `str` | Whether to enable splitwise inference, default value: mixed, supported parameters: ["mixed", "decode", "prefill"] |
|
||||
| ```innode_prefill_ports``` | `str` | Internal engine startup ports for prefill instances (only required for single-machine PD separation), default: None |
|
||||
| ```guided_decoding_backend``` | `str` | Specify the guided decoding backend to use, supports `auto`, `xgrammar`, `off`, default: `off` |
|
||||
@@ -51,7 +51,7 @@ When using FastDeploy to deploy models (including offline inference and service
|
||||
| ```chat_template``` | `str` | Specify the template used for model concatenation, It supports both string input and file path input. The default value is None. If not specified, the model's default template will be used. |
|
||||
| ```tool_call_parser``` | `str` | Specify the function call parser to be used for extracting function call content from the model's output. |
|
||||
| ```tool_parser_plugin``` | `str` | Specify the file path of the tool parser to be registered, so as to register parsers that are not in the code repository. The code format within these parsers must adhere to the format used in the code repository. |
|
||||
| ```lm_head_fp32``` | `bool` | Specify the dtype of the lm_head layer as FP32. |
|
||||
| ```load_choices``` | `str` | By default, the "default" loader is used for weight loading. To load Torch weights or enable weight acceleration, "default_v1" must be used.|
|
||||
|
||||
## 1. Relationship between KVCache allocation, ```num_gpu_blocks_override``` and ```block_size```?
|
||||
|
||||
|
BIN
docs/quantization/images/wint2.png
Normal file
After Width: | Height: | Size: 81 KiB |
@@ -1,21 +1,101 @@
|
||||
# WINT2 Quantization
|
||||
|
||||
Weights are compressed offline using the CCQ (Convolutional Coding Quantization) method. The actual stored numerical type of weights is INT8, with 4 weights packed into each INT8 value, equivalent to 2 bits per weight. Activations are not quantized. During inference, weights are dequantized and decoded in real-time to BF16 numerical type, and calculations are performed using BF16 numerical type.
|
||||
Weights are compressed offline using the [CCQ (Convolutional Coding Quantization)](https://arxiv.org/pdf/2507.07145) method. The actual stored numerical type of weights is INT8, with 4 weights packed into each INT8 value, equivalent to 2 bits per weight. Activations are not quantized. During inference, weights are dequantized and decoded in real-time to BF16 numerical type, and calculations are performed using BF16 numerical type.
|
||||
- **Supported Hardware**: GPU
|
||||
- **Supported Architecture**: MoE architecture
|
||||
This method relies on the convolution algorithm to use overlapping bits to map 2-bit values to a larger numerical representation space, so that the model weight quantization retains more information of the original data while compressing the true value to an extremely low 2-bit size. The general principle can be seen in the figure below:
|
||||

|
||||
|
||||
CCQ WINT2 is generally used in resource-constrained and low-threshold scenarios. Taking ERNIE-4.5-300B-A47B as an example, weights are compressed to 89GB, supporting single-card deployment on 141GB H20.
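
To make the "4 weights per INT8" layout concrete, here is a small, self-contained sketch of packing and unpacking 2-bit codes. It deliberately ignores the convolutional coding itself, and the scale/zero-point mapping at the end is an assumption for illustration rather than CCQ's actual dequantization scheme.

```python
# Illustrative 2-bit pack/unpack only (not the CCQ convolutional codec).
import numpy as np

def pack_int2(codes: np.ndarray) -> np.ndarray:
    """codes: values in [0, 3]; four codes are packed into each uint8 byte."""
    codes = codes.reshape(-1, 4).astype(np.uint8)
    return (codes[:, 0] | (codes[:, 1] << 2) | (codes[:, 2] << 4) | (codes[:, 3] << 6)).astype(np.uint8)

def unpack_int2(packed: np.ndarray) -> np.ndarray:
    out = np.stack([(packed >> s) & 0b11 for s in (0, 2, 4, 6)], axis=1)
    return out.reshape(-1)

codes = np.random.randint(0, 4, size=16)
assert np.array_equal(unpack_int2(pack_int2(codes)), codes)

# At inference time the unpacked codes are mapped back to higher-precision values,
# e.g. with a per-group scale/zero-point (values here are made up):
scale, zero = 0.05, 1.5
weights = (unpack_int2(pack_int2(codes)).astype(np.float32) - zero) * scale
```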
|
||||
|
||||
## Run WINT2 Inference Service
|
||||
## Executing WINT2 Offline Inference
|
||||
- When executing TP2/TP4 models, you can change the `model_name_or_path` and `tensor_parallel_size` parameters.
|
||||
```
from fastdeploy import LLM, SamplingParams

model_name_or_path = "baidu/ERNIE-4.5-300B-A47B-2Bits-Paddle"
prompts = ["Analyze three poems by Li Bai"]

sampling_params = SamplingParams(temperature=0.7, top_p=0, max_tokens=128)
llm = LLM(model=model_name_or_path, tensor_parallel_size=1, use_cudagraph=True)
outputs = llm.generate(prompts, sampling_params)
print(outputs)
```
|
||||
|
||||
## Run WINT2 Inference Service
|
||||
- When executing TP2/TP4 models, you can change the `--model` and `tensor-parallel-size` parameters.
|
||||
```
|
||||
python -m fastdeploy.entrypoints.openai.api_server \
|
||||
--model baidu/ERNIE-4.5-300B-A47B-2Bits-Paddle \
|
||||
--port 8180 --engine-worker-queue-port 8181 \
|
||||
--cache-queue-port 8182 --metrics-port 8182 \
|
||||
--tensor-parallel-size 1 \
|
||||
--max-model-len 32768 \
|
||||
--max-num-seqs 32
|
||||
--model baidu/ERNIE-4.5-300B-A47B-2Bits-Paddle \
|
||||
--port 8180 \
|
||||
--metrics-port 8181 \
|
||||
--engine-worker-queue-port 8182 \
|
||||
--cache-queue-port 8183 \
|
||||
--tensor-parallel-size 1 \
|
||||
--max-model-len 32768 \
|
||||
--use-cudagraph \
|
||||
--enable-prefix-caching \
|
||||
--enable-chunked-prefill \
|
||||
--max-num-seqs 256
|
||||
```
|
||||
|
||||
## Request the Service
|
||||
After starting the service, the following output indicates successful initialization:
|
||||
|
||||
```shell
|
||||
api_server.py[line:91] Launching metrics service at http://0.0.0.0:8181/metrics
|
||||
api_server.py[line:94] Launching chat completion service at http://0.0.0.0:8180/v1/chat/completions
|
||||
api_server.py[line:97] Launching completion service at http://0.0.0.0:8180/v1/completions
|
||||
INFO: Started server process [13909]
|
||||
INFO: Waiting for application startup.
|
||||
INFO: Application startup complete.
|
||||
INFO: Uvicorn running on http://0.0.0.0:8180 (Press CTRL+C to quit)
|
||||
```
|
||||
|
||||
### Health Check
|
||||
|
||||
Verify service status (HTTP 200 indicates success):
|
||||
|
||||
```shell
|
||||
curl -i http://0.0.0.0:8180/health
|
||||
```
|
||||
|
||||
### cURL Request
|
||||
|
||||
Send requests to the service with the following command:
|
||||
|
||||
```shell
|
||||
curl -X POST "http://0.0.0.0:8180/v1/chat/completions" \
|
||||
-H "Content-Type: application/json" \
|
||||
-d '{
|
||||
"messages": [
|
||||
{"role": "user", "content": "Write me a poem about large language model."}
|
||||
],
|
||||
"stream": true
|
||||
}'
|
||||
```
|
||||
|
||||
### Python Client (OpenAI-compatible API)
|
||||
|
||||
FastDeploy's API is OpenAI-compatible. You can also use Python for requests:
|
||||
|
||||
```python
|
||||
import openai
|
||||
host = "0.0.0.0"
|
||||
port = "8180"
|
||||
client = openai.Client(base_url=f"http://{host}:{port}/v1", api_key="null")
|
||||
|
||||
response = client.chat.completions.create(
|
||||
model="null",
|
||||
messages=[
|
||||
{"role": "system", "content": "I'm a helpful AI assistant."},
|
||||
{"role": "user", "content": "Write me a poem about large language model."},
|
||||
],
|
||||
stream=True,
|
||||
)
|
||||
for chunk in response:
|
||||
if chunk.choices[0].delta:
|
||||
print(chunk.choices[0].delta.content, end='')
|
||||
print('\n')
|
||||
```
|
||||
|
||||
By specifying `--model baidu/ERNIE-4.5-300B-A47B-2Bits-Paddle`, the offline quantized WINT2 model can be automatically downloaded from AIStudio. In the config.json file of this model, there will be WINT2 quantization-related configuration information, so there's no need to set `--quantization` when starting the inference service.
|
||||
@@ -54,9 +134,7 @@ On the ERNIE-4.5-300B-A47B model, comparison of WINT2 vs WINT4 performance:
|
||||
|
||||
| Test Set | Dataset Size | WINT4 | WINT2 |
|
||||
|---------|---------|---------|---------|
|
||||
| IFEval |500|88.17 | 85.40 |
|
||||
|BBH|6511|94.43|92.02|
|
||||
|DROP|9536|91.17|89.97|
|
||||
|GSM8K|1319|96.21|95.98|
|
||||
|CMath|600|96.50|96.00|
|
||||
|CMMLU|11477|89.92|86.22|
|
||||
| IFEval |500|88.17 | 85.95 |
|
||||
|BBH|6511|94.43|90.06|
|
||||
|DROP|9536|91.17|89.32|
|
||||
|CMMLU|11477|89.92|86.55|
|
||||
|
BIN
docs/quantization/wint2.png
Normal file
After Width: | Height: | Size: 81 KiB |
@@ -2,9 +2,9 @@
|
||||
|
||||
FastDeploy currently supports the following models, which can be downloaded automatically during FastDeploy deployment. Specify the ``model`` parameter as the model name in the table below to automatically download model weights (all support resumable downloads). The following three download sources are supported:
|
||||
|
||||
- 1. Search for corresponding Paddle-version ERNIE models on [AIStudio/PaddlePaddle](https://aistudio.baidu.com/modelsoverview), e.g., `ERNIE-4.5-0.3B-Paddle`
|
||||
- 2. Download Paddle-version ERNIE models from [HuggingFace/baidu/models](https://huggingface.co/baidu/models), e.g., `baidu/ERNIE-4.5-0.3B-Paddle`
|
||||
- 3. Search for corresponding Paddle-version ERNIE models on [ModelScope/PaddlePaddle](https://www.modelscope.cn/models?name=PaddlePaddle&page=1&tabKey=task), e.g., `ERNIE-4.5-0.3B-Paddle`
|
||||
- [AIStudio](https://aistudio.baidu.com/modelsoverview)
|
||||
- [ModelScope](https://www.modelscope.cn/models)
|
||||
- [HuggingFace](https://huggingface.co/models)
|
||||
|
||||
When using automatic download, the default download source is AIStudio. Users can modify the default download source by setting the ``FD_MODEL_SOURCE`` environment variable, which can be set to "AISTUDIO", "MODELSCOPE" or "HUGGINGFACE". The default download path is ``~/`` (i.e., the user's home directory). Users can modify the default download path by setting the ``FD_MODEL_CACHE`` environment variable, e.g.:
|
||||
|
||||
@@ -13,25 +13,40 @@ export FD_MODEL_SOURCE=AISTUDIO # "AISTUDIO", "MODELSCOPE" or "HUGGINGFACE"
|
||||
export FD_MODEL_CACHE=/ssd1/download_models
|
||||
```
|
||||
|
||||
| Model Name | Context Length | Quantization | Minimum Deployment Resources | Notes |
|
||||
| :------------------------------------------ | :------------- | :----------- | :--------------------------- | :----------------------------------------------------------------------------------------- |
|
||||
| baidu/ERNIE-4.5-VL-424B-A47B-Paddle | 32K/128K | WINT4 | 4*80G GPU VRAM/1T RAM | Chunked Prefill required for 128K |
|
||||
| baidu/ERNIE-4.5-VL-424B-A47B-Paddle | 32K/128K | WINT8 | 8*80G GPU VRAM/1T RAM | Chunked Prefill required for 128K |
|
||||
| baidu/ERNIE-4.5-300B-A47B-Paddle | 32K/128K | WINT4 | 4*64G GPU VRAM/600G RAM | Chunked Prefill required for 128K |
|
||||
| baidu/ERNIE-4.5-300B-A47B-Paddle | 32K/128K | WINT8 | 8*64G GPU VRAM/600G RAM | Chunked Prefill required for 128K |
|
||||
| baidu/ERNIE-4.5-300B-A47B-2Bits-Paddle | 32K/128K | WINT2 | 1*141G GPU VRAM/600G RAM | Chunked Prefill required for 128K |
|
||||
| baidu/ERNIE-4.5-300B-A47B-W4A8C8-TP4-Paddle | 32K/128K | W4A8C8 | 4*64G GPU VRAM/160G RAM | Fixed 4-GPU setup, Chunked Prefill recommended |
|
||||
| baidu/ERNIE-4.5-300B-A47B-FP8-Paddle | 32K/128K | FP8 | 8*64G GPU VRAM/600G RAM | Chunked Prefill recommended, only supports PD Disaggragated Deployment with EP parallelism |
|
||||
| baidu/ERNIE-4.5-300B-A47B-Base-Paddle | 32K/128K | WINT4 | 4*64G GPU VRAM/600G RAM | Chunked Prefill recommended |
|
||||
| baidu/ERNIE-4.5-300B-A47B-Base-Paddle | 32K/128K | WINT8 | 8*64G GPU VRAM/600G RAM | Chunked Prefill recommended |
|
||||
| baidu/ERNIE-4.5-VL-28B-A3B-Paddle | 32K | WINT4 | 1*24G GPU VRAM/128G RAM | Chunked Prefill required |
|
||||
| baidu/ERNIE-4.5-VL-28B-A3B-Paddle | 128K | WINT4 | 1*48G GPU VRAM/128G RAM | Chunked Prefill required |
|
||||
| baidu/ERNIE-4.5-VL-28B-A3B-Paddle | 32K/128K | WINT8 | 1*48G GPU VRAM/128G RAM | Chunked Prefill required |
|
||||
| baidu/ERNIE-4.5-21B-A3B-Paddle | 32K/128K | WINT4 | 1*24G GPU VRAM/128G RAM | Chunked Prefill required for 128K |
|
||||
| baidu/ERNIE-4.5-21B-A3B-Paddle | 32K/128K | WINT8 | 1*48G GPU VRAM/128G RAM | Chunked Prefill required for 128K |
|
||||
| baidu/ERNIE-4.5-21B-A3B-Base-Paddle | 32K/128K | WINT4 | 1*24G GPU VRAM/128G RAM | Chunked Prefill required for 128K |
|
||||
| baidu/ERNIE-4.5-21B-A3B-Base-Paddle | 32K/128K | WINT8 | 1*48G GPU VRAM/128G RAM | Chunked Prefill required for 128K |
|
||||
| baidu/ERNIE-4.5-0.3B-Paddle | 32K/128K | BF16 | 1*6G/12G GPU VRAM/2G RAM | |
|
||||
| baidu/ERNIE-4.5-0.3B-Base-Paddle | 32K/128K | BF16 | 1*6G/12G GPU VRAM/2G RAM | |
|
||||
> ⭐ **Note**: Models marked with an asterisk can directly use **HuggingFace Torch weights** and support **FP8/WINT8/WINT4** as well as **BF16**. When running inference, you need to enable **`--load_choices "default_v1"`**.
|
||||
|
||||
> Example launch command using baidu/ERNIE-4.5-0.3B-PT:
|
||||
```
|
||||
python -m fastdeploy.entrypoints.openai.api_server \
|
||||
--model baidu/ERNIE-4.5-0.3B-PT \
|
||||
--port 8180 \
|
||||
--metrics-port 8181 \
|
||||
--engine-worker-queue-port 8182 \
|
||||
--max-model-len 32768 \
|
||||
--max-num-seqs 32 \
|
||||
--load_choices "default_v1"
|
||||
```
|
||||
|
||||
## Large Language Models
|
||||
|
||||
These models accept text input.
|
||||
|
||||
|Models|DataType|Example HF Model|
|
||||
|-|-|-|
|
||||
|⭐ERNIE|BF16\WINT4\WINT8\W4A8C8\WINT2\FP8|baidu/ERNIE-4.5-VL-424B-A47B-Paddle;<br>baidu/ERNIE-4.5-300B-A47B-Paddle<br> [quick start](./get_started/ernie-4.5.md)   [best practice](./best_practices/ERNIE-4.5-300B-A47B-Paddle.md);<br>baidu/ERNIE-4.5-300B-A47B-2Bits-Paddle;<br>baidu/ERNIE-4.5-300B-A47B-W4A8C8-TP4-Paddle;<br>baidu/ERNIE-4.5-300B-A47B-FP8-Paddle;<br>baidu/ERNIE-4.5-300B-A47B-Base-Paddle;<br>[baidu/ERNIE-4.5-21B-A3B-Paddle](./best_practices/ERNIE-4.5-21B-A3B-Paddle.md);<br>baidu/ERNIE-4.5-21B-A3B-Base-Paddle;<br>baidu/ERNIE-4.5-0.3B-Paddle<br> [quick start](./get_started/quick_start.md)   [best practice](./best_practices/ERNIE-4.5-0.3B-Paddle.md);<br>baidu/ERNIE-4.5-0.3B-Base-Paddle, etc.|
|
||||
|⭐QWEN3-MOE|BF16/WINT4/WINT8/FP8|Qwen/Qwen3-235B-A22B;<br>Qwen/Qwen3-30B-A3B, etc.|
|
||||
|⭐QWEN3|BF16/WINT8/FP8|Qwen/qwen3-32B;<br>Qwen/qwen3-14B;<br>Qwen/qwen3-8B;<br>Qwen/qwen3-4B;<br>Qwen/qwen3-1.7B;<br>[Qwen/qwen3-0.6B](./get_started/quick_start_qwen.md), etc.|
|
||||
|⭐QWEN2.5|BF16/WINT8/FP8|Qwen/qwen2.5-72B;<br>Qwen/qwen2.5-32B;<br>Qwen/qwen2.5-14B;<br>Qwen/qwen2.5-7B;<br>Qwen/qwen2.5-3B;<br>Qwen/qwen2.5-1.5B;<br>Qwen/qwen2.5-0.5B, etc.|
|
||||
|⭐QWEN2|BF16/WINT8/FP8|Qwen/qwen2-72B;<br>Qwen/qwen2-7B;<br>Qwen/qwen2-1.5B;<br>Qwen/qwen2-0.5B;<br>Qwen/QwQ-32B, etc.|
|
||||
|⭐DEEPSEEK|BF16/WINT4|unsloth/DeepSeek-V3.1-BF16;<br>unsloth/DeepSeek-V3-0324-BF16;<br>unsloth/DeepSeek-R1-BF16, etc.|
|
||||
|
||||
## Multimodal Language Models
|
||||
|
||||
These models accept multi-modal inputs (e.g., images and text).
|
||||
|
||||
|Models|DataType|Example HF Model|
|
||||
|-|-|-|
|
||||
| ERNIE-VL |BF16/WINT4/WINT8| baidu/ERNIE-4.5-VL-424B-A47B-Paddle<br> [quick start](./get_started/ernie-4.5-vl.md)   [best practice](./best_practices/ERNIE-4.5-VL-424B-A47B-Paddle.md) ;<br>baidu/ERNIE-4.5-VL-28B-A3B-Paddle<br> [quick start](./get_started/quick_start_vl.md)   [best practice](./best_practices/ERNIE-4.5-VL-28B-A3B-Paddle.md) ;|
|
||||
| QWEN-VL |BF16/WINT4/FP8| Qwen/Qwen2.5-VL-72B-Instruct;<br>Qwen/Qwen2.5-VL-32B-Instruct;<br>Qwen/Qwen2.5-VL-7B-Instruct;<br>Qwen/Qwen2.5-VL-3B-Instruct|
|
||||
|
||||
More models are being supported. You can submit requests for new model support via [Github Issues](https://github.com/PaddlePaddle/FastDeploy/issues).
|
||||
|
@@ -89,4 +89,4 @@ for chunk in response:
|
||||
print('\n')
|
||||
```
|
||||
|
||||
For detailed OpenAI protocol specifications, see [OpenAI Chat Compeltion API](https://platform.openai.com/docs/api-reference/chat/create). Differences from the standard OpenAI protocol are documented in [OpenAI Protocol-Compatible API Server](../online_serving/README.md).
|
||||
For detailed OpenAI protocol specifications, see [OpenAI Chat Completion API](https://platform.openai.com/docs/api-reference/chat/create). Differences from the standard OpenAI protocol are documented in [OpenAI Protocol-Compatible API Server](../online_serving/README.md).
|
||||
|
@@ -19,23 +19,24 @@ ERNIE-4.5-0.3B 各量化精度,在下列硬件上部署所需要的最小卡
|
||||
### 1.2 安装fastdeploy
|
||||
- 安装请参考[Fastdeploy Installation](../get_started/installation/README.md)完成安装。
|
||||
|
||||
- 模型下载,请参考[支持模型列表](../supported_models.md)。**请注意使用Fastdeploy部署需要Paddle后缀的模型**
|
||||
- 模型下载,请参考[支持模型列表](../supported_models.md)。
|
||||
|
||||
## 二、如何使用
|
||||
### 2.1 基础:启动服务
|
||||
通过下列命令启动服务
|
||||
```bash
|
||||
export ENABLE_V1_KVCACHE_SCHEDULER=1
|
||||
python -m fastdeploy.entrypoints.openai.api_server \
|
||||
--model baidu/ERNIE-4.5-0.3B-Paddle \
|
||||
--tensor-parallel-size 1 \
|
||||
--quantization wint4 \
|
||||
--max-model-len 32768 \
|
||||
--max-num-seqs 128
|
||||
--max-num-seqs 128 \
|
||||
--load_choices "default_v1"
|
||||
```
|
||||
其中:
|
||||
- `--quantization`: 表示模型采用的量化策略。不同量化策略,模型的性能和精度也会不同。可选值包括:`wint8` / `wint4` / `block_wise_fp8`(需要Hopper架构)。
|
||||
- `--max-model-len`:表示当前部署的服务所支持的最长Token数量。设置得越大,模型可支持的上下文长度也越大,但相应占用的显存也越多,可能影响并发数。
|
||||
- `--load_choices`: 表示loader的版本,"default_v1"表示启用v1版本的loader,具有更快的加载速度和更少的内存使用。
|
||||
|
||||
更多的参数含义与默认设置,请参见[FastDeploy参数说明](../parameters.md)。
|
||||
|
||||
@@ -43,16 +44,14 @@ python -m fastdeploy.entrypoints.openai.api_server \
|
||||
#### 2.2.1 评估应用场景,正确设置参数
|
||||
结合应用场景,评估平均输入长度、平均输出长度、最大上下文长度。例如,平均输入长度为1000,输出长度为30000,那么建议设置为 32768
|
||||
- 根据最大上下文长度,设置`max-model-len`
|
||||
- **启用服务管理全局 Block**
|
||||
```
|
||||
export ENABLE_V1_KVCACHE_SCHEDULER=1
|
||||
```
|
||||
|
||||
#### 2.2.2 Prefix Caching
|
||||
**原理:** Prefix Caching的核心思想是通过缓存输入序列的中间计算结果(KV Cache),避免重复计算,从而加速具有相同前缀的多个请求的响应速度。具体参考[prefix-cache](../features/prefix_caching.md)
|
||||
|
||||
**启用方式:**
|
||||
在启动参数下增加下列两行,其中`--enable-prefix-caching`表示启用前缀缓存,`--swap-space`表示在GPU缓存的基础上,额外开启CPU缓存,大小为GB,应根据机器实际情况调整。建议取值为`(机器总内存 - 模型大小) * 20%`。如果因为其他程序占用内存等原因导致服务启动失败,可以尝试减小`--swap-space`的值。
|
||||
自2.2版本开始(包括develop分支),Prefix Caching已经默认开启。
|
||||
|
||||
对于2.1及更早的版本,需要手动开启。其中`--enable-prefix-caching`表示启用前缀缓存,`--swap-space`表示在GPU缓存的基础上,额外开启CPU缓存,大小为GB,应根据机器实际情况调整。建议取值为`(机器总内存 - 模型大小) * 20%`。如果因为其他程序占用内存等原因导致服务启动失败,可以尝试减小`--swap-space`的值。
|
||||
```
|
||||
--enable-prefix-caching
|
||||
--swap-space 50
|
||||
@@ -61,7 +60,10 @@ export ENABLE_V1_KVCACHE_SCHEDULER=1
|
||||
#### 2.2.3 Chunked Prefill
|
||||
**原理:** 采用分块策略,将预填充(Prefill)阶段请求拆解为小规模子任务,与解码(Decode)请求混合批处理执行。可以更好地平衡计算密集型(Prefill)和访存密集型(Decode)操作,优化GPU资源利用率,减少单次Prefill的计算量和显存占用,从而降低显存峰值,避免显存不足的问题。 具体请参考[Chunked Prefill](../features/chunked_prefill.md)
|
||||
|
||||
**启用方式:** 在启动参数下增加即可
|
||||
**启用方式:**
|
||||
自2.2版本开始(包括develop分支),Chunked Prefill已经默认开启。
|
||||
|
||||
对于2.1及更早的版本,需要手动开启。
|
||||
```
|
||||
--enable-chunked-prefill
|
||||
```
|
||||
|
@@ -19,23 +19,24 @@ ERNIE-4.5-21B-A3B 各量化精度,在下列硬件上部署所需要的最小
|
||||
### 1.2 安装fastdeploy
|
||||
- 安装,请参考[Fastdeploy Installation](../get_started/installation/README.md)完成安装。
|
||||
|
||||
- 模型下载,请参考[支持模型列表](../supported_models.md)。**请注意使用Fastdeploy部署需要Paddle后缀的模型**
|
||||
- 模型下载,请参考[支持模型列表](../supported_models.md)。
|
||||
|
||||
## 二、如何使用
|
||||
### 2.1 基础:启动服务
|
||||
通过下列命令启动服务
|
||||
```bash
|
||||
export ENABLE_V1_KVCACHE_SCHEDULER=1
|
||||
python -m fastdeploy.entrypoints.openai.api_server \
|
||||
--model baidu/ERNIE-4.5-21B-A3B-Paddle \
|
||||
--tensor-parallel-size 1 \
|
||||
--quantization wint4 \
|
||||
--max-model-len 32768 \
|
||||
--max-num-seqs 128
|
||||
--max-num-seqs 128 \
|
||||
--load_choices "default_v1"
|
||||
```
|
||||
其中:
|
||||
- `--quantization`: 表示模型采用的量化策略。不同量化策略,模型的性能和精度也会不同。可选值包括:`wint8` / `wint4` / `block_wise_fp8`(需要Hopper架构)。
|
||||
- `--max-model-len`:表示当前部署的服务所支持的最长Token数量。设置得越大,模型可支持的上下文长度也越大,但相应占用的显存也越多,可能影响并发数。
|
||||
- `--load_choices`: 表示loader的版本,"default_v1"表示启用v1版本的loader,具有更快的加载速度和更少的内存使用。
|
||||
|
||||
更多的参数含义与默认设置,请参见[FastDeploy参数说明](../parameters.md)。
|
||||
|
||||
@@ -43,16 +44,14 @@ python -m fastdeploy.entrypoints.openai.api_server \
|
||||
#### 2.2.1 评估应用场景,正确设置参数
|
||||
结合应用场景,评估平均输入长度、平均输出长度、最大上下文长度。例如,平均输入长度为1000,输出长度为30000,那么建议设置为 32768
|
||||
- 根据最大上下文长度,设置`max-model-len`
|
||||
- **启用服务管理全局 Block**
|
||||
```
|
||||
export ENABLE_V1_KVCACHE_SCHEDULER=1
|
||||
```
|
||||
|
||||
#### 2.2.2 Prefix Caching
|
||||
**原理:** Prefix Caching的核心思想是通过缓存输入序列的中间计算结果(KV Cache),避免重复计算,从而加速具有相同前缀的多个请求的响应速度。具体参考[prefix-cache](../features/prefix_caching.md)
|
||||
|
||||
**启用方式:**
|
||||
在启动参数下增加下列两行,其中`--enable-prefix-caching`表示启用前缀缓存,`--swap-space`表示在GPU缓存的基础上,额外开启CPU缓存,大小为GB,应根据机器实际情况调整。建议取值为`(机器总内存 - 模型大小) * 20%`。如果因为其他程序占用内存等原因导致服务启动失败,可以尝试减小`--swap-space`的值。
|
||||
自2.2版本开始(包括develop分支),Prefix Caching已经默认开启。
|
||||
|
||||
对于2.1及更早的版本,需要手动开启。其中`--enable-prefix-caching`表示启用前缀缓存,`--swap-space`表示在GPU缓存的基础上,额外开启CPU缓存,大小为GB,应根据机器实际情况调整。建议取值为`(机器总内存 - 模型大小) * 20%`。如果因为其他程序占用内存等原因导致服务启动失败,可以尝试减小`--swap-space`的值。
|
||||
```
|
||||
--enable-prefix-caching
|
||||
--swap-space 50
|
||||
@@ -61,7 +60,10 @@ export ENABLE_V1_KVCACHE_SCHEDULER=1
|
||||
#### 2.2.3 Chunked Prefill
|
||||
**原理:** 采用分块策略,将预填充(Prefill)阶段请求拆解为小规模子任务,与解码(Decode)请求混合批处理执行。可以更好地平衡计算密集型(Prefill)和访存密集型(Decode)操作,优化GPU资源利用率,减少单次Prefill的计算量和显存占用,从而降低显存峰值,避免显存不足的问题。 具体请参考[Chunked Prefill](../features/chunked_prefill.md)
|
||||
|
||||
**启用方式:** 在启动参数下增加即可
|
||||
**启用方式:**
|
||||
自2.2版本开始(包括develop分支),Chunked Prefill已经默认开启。
|
||||
|
||||
对于2.1及更早的版本,需要手动开启。
|
||||
```
|
||||
--enable-chunked-prefill
|
||||
```
|
||||
@@ -78,7 +80,9 @@ export ENABLE_V1_KVCACHE_SCHEDULER=1
|
||||
|
||||
注:
|
||||
1. MTP当前暂不支持与Prefix Caching 、Chunked Prefill 、CUDAGraph同时使用。
|
||||
2. MTP当前暂不支持服务管理全局 Block, 即不要开启`export ENABLE_V1_KVCACHE_SCHEDULER=1`
|
||||
- 需要通过指定`export FD_DISABLE_CHUNKED_PREFILL=1` 关闭Chunked Prefill。
|
||||
- 指定`speculative-config`时,会自动关闭Prefix Caching功能。
|
||||
2. MTP当前暂不支持服务管理全局 Block, 指定`speculative-config`时,会自动关闭全局Block调度器。
|
||||
3. MTP当前暂不支持和拒绝采样同时使用,即不要开启`export FD_SAMPLING_CLASS=rejection`
|
||||
|
||||
#### 2.2.5 CUDAGraph
|
||||
@@ -111,7 +115,6 @@ export FD_SAMPLING_CLASS=rejection
|
||||
# prefill
|
||||
export CUDA_VISIBLE_DEVICES=0,1,2,3
|
||||
export INFERENCE_MSG_QUEUE_ID=1315
|
||||
export FLAGS_max_partition_size=2048
|
||||
export FD_ATTENTION_BACKEND=FLASH_ATTN
|
||||
export FD_LOG_DIR="prefill_log"
|
||||
|
||||
@@ -131,7 +134,6 @@ python -m fastdeploy.entrypoints.openai.api_server --model baidu/ERNIE-4.5-21B-A
|
||||
# decode
|
||||
export CUDA_VISIBLE_DEVICES=4,5,6,7
|
||||
export INFERENCE_MSG_QUEUE_ID=1215
|
||||
export FLAGS_max_partition_size=2048
|
||||
export FD_LOG_DIR="decode_log"
|
||||
|
||||
quant_type=block_wise_fp8
|
||||
|
@@ -16,23 +16,24 @@ ERNIE-4.5-300B-A47B各量化精度,在下列硬件上部署所需要的最小
|
||||
### 1.2 安装fastdeploy
|
||||
- 安装,请参考[Fastdeploy Installation](../get_started/installation/README.md)完成安装。
|
||||
|
||||
- 模型下载,请参考[支持模型列表](../supported_models.md)。**请注意使用Fastdeploy部署需要Paddle后缀的模型**
|
||||
- 模型下载,请参考[支持模型列表](../supported_models.md)。
|
||||
|
||||
## 二、如何使用
|
||||
### 2.1 基础:启动服务
|
||||
通过下列命令启动服务
|
||||
```bash
|
||||
export ENABLE_V1_KVCACHE_SCHEDULER=1
|
||||
python -m fastdeploy.entrypoints.openai.api_server \
|
||||
--model baidu/ERNIE-4.5-300B-A47B-Paddle \
|
||||
--tensor-parallel-size 8 \
|
||||
--quantization wint4 \
|
||||
--max-model-len 32768 \
|
||||
--max-num-seqs 128
|
||||
--max-num-seqs 128 \
|
||||
--load_choices "default_v1"
|
||||
```
|
||||
其中:
|
||||
- `--quantization`: 表示模型采用的量化策略。不同量化策略,模型的性能和精度也会不同。可选值包括:`wint8` / `wint4` / `block_wise_fp8`(需要Hopper架构)。
|
||||
- `--max-model-len`:表示当前部署的服务所支持的最长Token数量。设置得越大,模型可支持的上下文长度也越大,但相应占用的显存也越多,可能影响并发数。
|
||||
- `--load_choices`: 表示loader的版本,"default_v1"表示启用v1版本的loader,具有更快的加载速度和更少的内存使用。
|
||||
|
||||
更多的参数含义与默认设置,请参见[FastDeploy参数说明](../parameters.md)。
|
||||
|
||||
@@ -40,17 +41,14 @@ python -m fastdeploy.entrypoints.openai.api_server \
|
||||
#### 2.2.1 评估应用场景,正确设置参数
|
||||
结合应用场景,评估平均输入长度、平均输出长度、最大上下文长度
|
||||
- 根据最大上下文长度,设置`max-model-len`。例如,平均输入长度为1000,输出长度为30000,那么建议设置为 32768
|
||||
- **启用服务管理全局 Block**
|
||||
|
||||
```
|
||||
export ENABLE_V1_KVCACHE_SCHEDULER=1
|
||||
```
|
||||
|
||||
#### 2.2.2 Prefix Caching
|
||||
**原理:** Prefix Caching的核心思想是通过缓存输入序列的中间计算结果(KV Cache),避免重复计算,从而加速具有相同前缀的多个请求的响应速度。具体参考[prefix-cache](../features/prefix_caching.md)
|
||||
|
||||
**启用方式:**
|
||||
在启动参数下增加下列两行,其中`--enable-prefix-caching`表示启用前缀缓存,`--swap-space`表示在GPU缓存的基础上,额外开启CPU缓存,大小为GB,应根据机器实际情况调整。建议取值为`(机器总内存 - 模型大小) * 20%`。如果因为其他程序占用内存等原因导致服务启动失败,可以尝试减小`--swap-space`的值。
|
||||
自2.2版本开始(包括develop分支),Prefix Caching已经默认开启。
|
||||
|
||||
对于2.1及更早的版本,需要手动开启。其中`--enable-prefix-caching`表示启用前缀缓存,`--swap-space`表示在GPU缓存的基础上,额外开启CPU缓存,大小为GB,应根据机器实际情况调整。建议取值为`(机器总内存 - 模型大小) * 20%`。如果因为其他程序占用内存等原因导致服务启动失败,可以尝试减小`--swap-space`的值。
|
||||
```
|
||||
--enable-prefix-caching
|
||||
--swap-space 50
|
||||
@@ -59,7 +57,10 @@ export ENABLE_V1_KVCACHE_SCHEDULER=1
|
||||
#### 2.2.3 Chunked Prefill
|
||||
**原理:** 采用分块策略,将预填充(Prefill)阶段请求拆解为小规模子任务,与解码(Decode)请求混合批处理执行。可以更好地平衡计算密集型(Prefill)和访存密集型(Decode)操作,优化GPU资源利用率,减少单次Prefill的计算量和显存占用,从而降低显存峰值,避免显存不足的问题。 具体请参考[Chunked Prefill](../features/chunked_prefill.md)
|
||||
|
||||
**启用方式:** 在启动参数下增加即可
|
||||
**启用方式:**
|
||||
自2.2版本开始(包括develop分支),Chunked Prefill已经默认开启。
|
||||
|
||||
对于2.1及更早的版本,需要手动开启。
|
||||
```
|
||||
--enable-chunked-prefill
|
||||
```
|
||||
@@ -75,7 +76,9 @@ export ENABLE_V1_KVCACHE_SCHEDULER=1
|
||||
```
|
||||
注:
|
||||
1. MTP当前暂不支持与Prefix Caching 、Chunked Prefill 、CUDAGraph同时使用。
|
||||
2. MTP当前暂不支持服务管理全局 Block, 即不要开启`export ENABLE_V1_KVCACHE_SCHEDULER=1`
|
||||
- 需要通过指定`export FD_DISABLE_CHUNKED_PREFILL=1` 关闭Chunked Prefill。
|
||||
- 指定`speculative-config`时,会自动关闭Prefix Caching功能。
|
||||
2. MTP当前暂不支持服务管理全局 Block, 指定`speculative-config`时,会自动关闭全局Block调度器。
|
||||
3. MTP当前暂不支持和拒绝采样同时使用,即不要开启`export FD_SAMPLING_CLASS=rejection`
|
||||
|
||||
#### 2.2.5 W4A8C8量化
|
||||
@@ -88,6 +91,9 @@ export ENABLE_V1_KVCACHE_SCHEDULER=1
|
||||
--model baidu/ERNIE-4.5-300B-A47B-W4A8C8-TP4-Paddle
|
||||
```
|
||||
|
||||
注:
|
||||
- W4A8C8量化的模型不支持通过`--load_choices "default_v1"`载入。
|
||||
|
||||
#### 2.2.6 拒绝采样
|
||||
**原理:**
|
||||
拒绝采样即从一个易于采样的提议分布(proposal distribution)中生成样本,避免显式排序从而达到提升采样速度的效果,对小尺寸的模型有较明显的提升。
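下面用一段与具体实现无关的 Python 代码示意“从提议分布采样、命中截断集合才接受”的基本思路(为避免排序,这里假设用 min-p 风格的阈值确定截断集合;FastDeploy 通过 `FD_SAMPLING_CLASS=rejection` 启用的是 GPU 算子实现,此处仅作原理示意):

```python
import numpy as np

def rejection_sample(probs: np.ndarray, min_p: float = 0.05, seed: int = 0) -> int:
    """拒绝采样示意:从完整分布反复提议,落在截断集合内才接受。

    截断集合取 {i: p_i >= min_p * max(p)},无需对概率排序;
    被接受的样本服从“截断后再归一化”的目标分布(标准拒绝采样性质)。
    """
    rng = np.random.default_rng(seed)
    threshold = min_p * probs.max()
    while True:
        token = int(rng.choice(len(probs), p=probs))  # 从提议分布(原分布)采样
        if probs[token] >= threshold:                 # 命中截断集合则接受,否则重试
            return token

probs = np.array([0.5, 0.3, 0.15, 0.04, 0.01])
print(rejection_sample(probs))
```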
|
||||
|
@@ -17,15 +17,10 @@
|
||||
|
||||
安装流程参考文档 [FastDeploy GPU 安装](../get_started/installation/nvidia_gpu.md)
|
||||
|
||||
> ⚠️ 注意事项
|
||||
> - FastDeploy只支持Paddle格式的模型,注意下载Paddle后缀的模型
|
||||
> - 使用模型名称会自动下载模型,如果已经下载过模型,可以直接使用模型下载位置的绝对路径
|
||||
|
||||
## 二、如何使用
|
||||
### 2.1 基础:启动服务
|
||||
**示例1:** 4090上单卡部署32K上下文的服务
|
||||
```shell
|
||||
export ENABLE_V1_KVCACHE_SCHEDULER=1
|
||||
python -m fastdeploy.entrypoints.openai.api_server \
|
||||
--model baidu/ERNIE-4.5-VL-28B-A3B-Paddle \
|
||||
--port 8180 \
|
||||
@@ -37,14 +32,11 @@ python -m fastdeploy.entrypoints.openai.api_server \
|
||||
--limit-mm-per-prompt '{"image": 100, "video": 100}' \
|
||||
--reasoning-parser ernie-45-vl \
|
||||
--gpu-memory-utilization 0.9 \
|
||||
--enable-chunked-prefill \
|
||||
--max-num-batched-tokens 384 \
|
||||
--quantization wint4 \
|
||||
--enable-mm
|
||||
--quantization wint4
|
||||
```
|
||||
**示例2:** H800上双卡部署128K上下文的服务
|
||||
```shell
|
||||
export ENABLE_V1_KVCACHE_SCHEDULER=1
|
||||
python -m fastdeploy.entrypoints.openai.api_server \
|
||||
--model baidu/ERNIE-4.5-VL-28B-A3B-Paddle \
|
||||
--port 8180 \
|
||||
@@ -56,12 +48,9 @@ python -m fastdeploy.entrypoints.openai.api_server \
|
||||
--limit-mm-per-prompt '{"image": 100, "video": 100}' \
|
||||
--reasoning-parser ernie-45-vl \
|
||||
--gpu-memory-utilization 0.9 \
|
||||
--enable-chunked-prefill \
|
||||
--max-num-batched-tokens 384 \
|
||||
--quantization wint4 \
|
||||
--enable-mm
|
||||
--quantization wint4
|
||||
```
|
||||
> ⚠️ 2.1及以上版本需要通过环境变量开启新调度器 `ENABLE_V1_KVCACHE_SCHEDULER=1`,否则可能会有部分请求在达到最大长度前被截断或返回空结果。
|
||||
|
||||
示例是可以稳定运行的一组配置,同时也能得到比较好的性能。
|
||||
如果对精度、性能有进一步的要求,请继续阅读下面的内容。
|
||||
@@ -91,9 +80,9 @@ python -m fastdeploy.entrypoints.openai.api_server \
|
||||
|
||||
#### 2.2.2 Chunked Prefill
|
||||
- **参数:** `--enable-chunked-prefill`
|
||||
- **用处:** 开启 `chunked prefill` 可**降低显存峰值**并**提升服务吞吐**。
|
||||
- **用处:** 开启 `chunked prefill` 可降低显存峰值并提升服务吞吐。2.2版本已经**默认开启**,2.2之前需要手动开启,参考2.1的最佳实践文档。
|
||||
|
||||
- **其他相关配置**:
|
||||
- **相关配置**:
|
||||
|
||||
`--max-num-batched-tokens`:限制每个chunk的最大token数量。多模场景下每个chunk会向上取整保持图片的完整性,因此实际每次推理的总token数会大于该值。我们推荐设置为384。
|
||||
|
||||
@@ -115,12 +104,7 @@ python -m fastdeploy.entrypoints.openai.api_server \
|
||||
- **描述**:拒绝采样即从一个易于采样的提议分布(proposal distribution)中生成样本,避免显式排序从而达到提升采样速度的效果,可以提升推理性能。
|
||||
- **推荐**:这是一种可能影响效果的较为激进的优化策略,我们仍在全面验证其影响。如果对性能有较高要求,且可以接受对效果的影响,可以尝试开启。
|
||||
|
||||
> **Attention超参:**`FLAGS_max_partition_size=1024`
|
||||
- **描述**:Append Attention(默认)后端的超参。我们在常用数据集上的测试结果表明,设置为1024后可以大幅提升解码速度,尤其是在长文场景下。
|
||||
- **推荐**:未来会修改为自动调整的机制。如果对性能有较高要求可以尝试开启。
|
||||
|
||||
## 三、常见问题FAQ
|
||||
**注意:** 使用多模服务部署需要在配置中添加参数 `--enable-mm`。
|
||||
|
||||
### 3.1 显存不足(OOM)
|
||||
如果服务启动时提示显存不足,请尝试以下方法:
|
||||
|
@@ -15,10 +15,6 @@
|
||||
|
||||
安装流程参考文档 [FastDeploy GPU 安装](../get_started/installation/nvidia_gpu.md)
|
||||
|
||||
> ⚠️ 注意事项
|
||||
> - FastDeploy只支持Paddle格式的模型,注意下载Paddle后缀的模型
|
||||
> - 使用模型名称会自动下载模型,如果已经下载过模型,可以直接使用模型下载位置的绝对路径
|
||||
|
||||
## 二、如何使用
|
||||
### 2.1 基础:启动服务
|
||||
**示例1:** H800上8卡部署128K上下文的服务
|
||||
@@ -33,13 +29,10 @@ python -m fastdeploy.entrypoints.openai.api_server \
|
||||
--max-num-seqs 16 \
|
||||
--limit-mm-per-prompt '{"image": 100, "video": 100}' \
|
||||
--reasoning-parser ernie-45-vl \
|
||||
--gpu-memory-utilization 0.8 \
|
||||
--enable-chunked-prefill \
|
||||
--gpu-memory-utilization 0.85 \
|
||||
--max-num-batched-tokens 384 \
|
||||
--quantization wint4 \
|
||||
--enable-mm
|
||||
--quantization wint4
|
||||
```
|
||||
> ⚠️ 2.1及以上版本需要通过环境变量开启新调度器 `ENABLE_V1_KVCACHE_SCHEDULER=1`,否则可能会有部分请求在达到最大长度前被截断或返回空结果。
|
||||
|
||||
示例是可以稳定运行的一组配置,同时也能得到比较好的性能。
|
||||
如果对精度、性能有进一步的要求,请继续阅读下面的内容。
|
||||
@@ -68,9 +61,9 @@ python -m fastdeploy.entrypoints.openai.api_server \
|
||||
|
||||
#### 2.2.2 Chunked Prefill
|
||||
- **参数:** `--enable-chunked-prefill`
|
||||
- **用处:** 开启 `chunked prefill` 可**降低显存峰值**并**提升服务吞吐**。
|
||||
- **用处:** 开启 `chunked prefill` 可降低显存峰值并提升服务吞吐。2.2版本已经**默认开启**,2.2之前需要手动开启,参考2.1的最佳实践文档。
|
||||
|
||||
- **其他相关配置**:
|
||||
- **相关配置**:
|
||||
|
||||
`--max-num-batched-tokens`:限制每个chunk的最大token数量。多模场景下每个chunk会向上取整保持图片的完整性,因此实际每次推理的总token数会大于该值。推荐设置为384。
|
||||
|
||||
@@ -92,10 +85,6 @@ python -m fastdeploy.entrypoints.openai.api_server \
|
||||
- **描述**:拒绝采样即从一个易于采样的提议分布(proposal distribution)中生成样本,避免显式排序从而达到提升采样速度的效果,可以提升推理性能。
|
||||
- **推荐**:这是一种可能影响效果的较为激进的优化策略,我们仍在全面验证其影响。如果对性能有较高要求,且可以接受对效果的影响,可以尝试开启。
|
||||
|
||||
> **Attention超参:**`FLAGS_max_partition_size=1024`
|
||||
- **描述**:Append Attention(默认)后端的超参。我们在常用数据集上的测试结果表明,设置为1024后可以大幅提升解码速度,尤其是在长文场景下。
|
||||
- **推荐**:未来会修改为自动调整的机制。如果对性能有较高要求可以尝试开启。
|
||||
|
||||
## 三、常见问题FAQ
|
||||
**注意:** 使用多模服务部署需要在配置中添加参数 `--enable-mm`。
|
||||
|
||||
|
166
docs/zh/features/data_parallel_service.md
Normal file
@@ -0,0 +1,166 @@
|
||||
# 数据并行
|
||||
在 MoE 模型下,可将专家并行(EP)与数据并行(DP)相结合:EP 分摊专家负载,DP 实现请求的并行处理。
|
||||
|
||||
## 数据分发策略
|
||||
FastDeploy 通过 splitwise scheduler 感知各个 DP 的负载状态,对接收到的数据进行分发。
|
||||
|
||||
splitwise scheduler 依赖 Redis 存储各个 DP 的负载状态,并据此决定数据分发。
|
||||
|
||||
### 专家并行 + 混合式部署
|
||||
|
||||
FastDeploy 提供了splitwise scheduler,可以感知各个DP的负载状态,对接收到的数据进行调度。
|
||||
具体调度流程如下图:用户可随机请求任意一个 IP 与端口,服务通过 Redis 获取负载状态,将数据分发到负载较低的 DP 进行推理。
|
||||

|
||||
|
||||
|
||||
#### 离线推理
|
||||
```python
|
||||
|
||||
# 基于 FastDeploy 离线推理接口的示例(导入路径以官方离线推理文档为准)
from fastdeploy import LLM, SamplingParams

prompts = [
|
||||
"Hello, my name is",
|
||||
"你好,请问今天是星期",
|
||||
"请写6个以数字开头的成语",
|
||||
"写一个300字的小说大纲,内容是李白穿越到现代,最后成为公司文职人员的故事",
|
||||
"我要采访一位科幻作家,创建一个包含5个问题的列表"
|
||||
]
|
||||
|
||||
sampling_params = SamplingParams(temperature=0.8, top_p=0.95, max_tokens=128)
|
||||
|
||||
llm = LLM(
|
||||
model="ERNIE-4_5-300B-A47B-FP8-Paddle",
|
||||
tensor_parallel_size=1,
|
||||
data_parallel_size=8,
|
||||
max_model_len=8192,
|
||||
num_gpu_blocks_override=1024,
|
||||
engine_worker_queue_port="6077,6078,6079,6080,6081,6082,6083,6084",
|
||||
enable_expert_parallel=True,
|
||||
scheduler_name="splitwise",
|
||||
scheduler_host="127.0.0.1",
|
||||
scheduler_topic="test",
|
||||
scheduler_port=6379
|
||||
)
|
||||
outputs = llm.generate(prompts, sampling_params)
|
||||
|
||||
for output in outputs:
|
||||
prompt = output.prompt
|
||||
generated_text = output.outputs.text
|
||||
print("generated_text: ", generated_text)
|
||||
print("\n")
|
||||
|
||||
|
||||
```
|
||||
|
||||
#### 在线推理
|
||||
```shell
|
||||
python -m fastdeploy.entrypoints.openai.api_server \
|
||||
--model ERNIE-4_5-300B-A47B-FP8-Paddle \
|
||||
--port 8184 --metrics-port 8185 \
|
||||
--engine-worker-queue-port "6077,6078,6079,6080,6081,6082,6083,6084" \
|
||||
    --data-parallel-size 8 --tensor-parallel-size 1 \
|
||||
--enable-expert-parallel \
|
||||
--scheduler-name "splitwise" \
|
||||
--scheduler-host "127.0.0.1" \
|
||||
--scheduler-port 6379 \
|
||||
--scheduler-topic "test" \
|
||||
--scheduler-ttl 9000
|
||||
```
|
||||
|
||||
|
||||
### 用户自行调度
|
||||
FastDeploy 提供了 multi_api_server,可以拉起多个 api server,由用户自行选择 DP 发起请求;在这种情况下,用户可以自行添加负载均衡组件进行调度(目前该方式只支持在线推理)。参数说明之后给出了一个简单的客户端轮询示意。
|
||||
|
||||
|
||||
#### 在线推理
|
||||
|
||||

|
||||
|
||||
```shell
|
||||
export FD_ENABLE_MULTI_API_SERVER=1
|
||||
python -m fastdeploy.entrypoints.openai.multi_api_server \
|
||||
--ports "1811,1822,1833,1844,1855,1866,1877,1888" \
|
||||
--num-servers 8 \
|
||||
--metrics-ports "3101,3201,3301,3401,3501,3601,3701,3801" \
|
||||
--args --model ERNIE-4_5-300B-A47B-FP8-Paddle \
|
||||
--engine-worker-queue-port "25611,25621,25631,25641,25651,25661,25671,25681" \
|
||||
--tensor-parallel-size 1 \
|
||||
--data-parallel-size 8 \
|
||||
--max-model-len 12288 \
|
||||
--max-num-seqs 64 \
|
||||
--num-gpu-blocks-override 256 \
|
||||
--enable-expert-parallel
|
||||
```
|
||||
|
||||
### 参数说明
|
||||
- num-servers: 指定拉起的api server 的数量
|
||||
- ports: 指定拉起的api server 的端口
|
||||
- args: 指定拉起的api server 的参数
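在这种自行调度的模式下,最简单的做法是在客户端侧做轮询(round-robin)负载均衡。下面是一个示意性的 Python 客户端(端口与模型名沿用上面的示例,仅演示思路;生产环境可替换为 nginx 等专门的负载均衡组件):

```python
import itertools

import requests

# 与上例 --ports 参数保持一致的 8 个 api server 端口
PORTS = [1811, 1822, 1833, 1844, 1855, 1866, 1877, 1888]
_port_cycle = itertools.cycle(PORTS)  # 简单轮询,每次请求切换到下一个 api server

def chat(prompt: str) -> str:
    port = next(_port_cycle)
    resp = requests.post(
        f"http://127.0.0.1:{port}/v1/chat/completions",
        json={
            "model": "ERNIE-4_5-300B-A47B-FP8-Paddle",
            "messages": [{"role": "user", "content": prompt}],
        },
        timeout=600,
    )
    return resp.json()["choices"][0]["message"]["content"]

print(chat("你好,请介绍一下数据并行。"))
```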
|
||||
|
||||
|
||||
|
||||
### 数据并行 + 分离式部署
|
||||
|
||||
具体可以参考[分离式部署](disaggregated.md#多机分离式部署)
|
||||
|
||||
#### 在线推理
|
||||
|
||||
多机部署时需要确认当前网卡是否支持RDMA,并且需要集群中所有节点网络互通。
|
||||
|
||||
**注意**:
|
||||
* `KVCACHE_RDMA_NICS` 指定当前机器的RDMA网卡,多个网卡用逗号隔开。
|
||||
* 仓库中提供了自动检测RDMA网卡的脚本 `bash scripts/get_rdma_nics.sh <device>`,其中 `<device>` 可以是 `cpu` 或 `gpu`。
|
||||
|
||||
**prefill 实例**
|
||||
|
||||
```bash
|
||||
export FD_LOG_DIR="log_prefill"
|
||||
export CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7
|
||||
echo "set RDMA NICS"
|
||||
export $(bash scripts/get_rdma_nics.sh gpu)
|
||||
echo "KVCACHE_RDMA_NICS ${KVCACHE_RDMA_NICS}"
|
||||
|
||||
python -m fastdeploy.entrypoints.openai.api_server \
|
||||
--model ERNIE-4_5-300B-A47B-FP8-Paddle \
|
||||
--port 8180 --metrics-port 8181 \
|
||||
--engine-worker-queue-port "25611,25621,25631,25641,25651,25661,25671,25681" \
|
||||
--cache-queue-port 8183 \
|
||||
--tensor-parallel-size 1 \
|
||||
--data-parallel-size 4 \
|
||||
--enable-expert-parallel \
|
||||
--cache-transfer-protocol "rdma,ipc" \
|
||||
--rdma-comm-ports "7671,7672,7673,7674,7675,7676,7677,7678" \
|
||||
--pd-comm-port "2334" \
|
||||
--splitwise-role "prefill" \
|
||||
--scheduler-name "splitwise" \
|
||||
--scheduler-host "127.0.0.1" \
|
||||
--scheduler-port 6379 \
|
||||
--scheduler-topic "test" \
|
||||
--scheduler-ttl 9000
|
||||
```
|
||||
|
||||
**decode 实例**
|
||||
|
||||
```bash
|
||||
export FD_LOG_DIR="log_decode"
|
||||
export CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7
|
||||
echo "set RDMA NICS"
|
||||
export $(bash scripts/get_rdma_nics.sh gpu)
|
||||
echo "KVCACHE_RDMA_NICS ${KVCACHE_RDMA_NICS}"
|
||||
python -m fastdeploy.entrypoints.openai.api_server \
|
||||
--model ERNIE-4_5-300B-A47B-FP8-Paddle \
|
||||
--port 8184 --metrics-port 8185 \
|
||||
--engine-worker-queue-port "25611,25621,25631,25641,25651,25661,25671,25681" \
|
||||
--cache-queue-port 8187 \
|
||||
--tensor-parallel-size 1 \
|
||||
--data-parallel-size 4 \
|
||||
--enable-expert-parallel \
|
||||
--scheduler-name "splitwise" \
|
||||
--cache-transfer-protocol "rdma,ipc" \
|
||||
--rdma-comm-ports "7671,7672,7673,7674,7675,7676,7677,7678" \
|
||||
--pd-comm-port "2334" \
|
||||
--scheduler-host "127.0.0.1" \
|
||||
--scheduler-port 6379 \
|
||||
    --scheduler-ttl 9000 \
|
||||
--scheduler-topic "test" \
|
||||
--splitwise-role "decode"
|
||||
```
|
||||
|
@@ -75,6 +75,10 @@ python -m fastdeploy.entrypoints.openai.api_server \
|
||||
#### 前置依赖 Redis
|
||||
* 使用`conda`安装
|
||||
|
||||
> **⚠️ 注意**
|
||||
> **Redis 版本要求:6.2.0 及以上**
|
||||
> 低于此版本可能不支持所需的命令。
|
||||
|
||||
```bash
|
||||
# 安装
|
||||
conda install redis
|
||||
@@ -106,13 +110,17 @@ sudo systemctl start redis
|
||||
|
||||
**注意**:
|
||||
* `KVCACHE_RDMA_NICS` 指定当前机器的RDMA网卡,多个网卡用逗号隔开。
|
||||
* 仓库中提供了自动检测RDMA网卡的脚本 `bash scripts/get_rdma_nics.sh <device>`,其中 `<device>` 可以是 `cpu` 或 `gpu`。
|
||||
|
||||
**prefill 实例**
|
||||
|
||||
```bash
|
||||
|
||||
export FD_LOG_DIR="log_prefill"
|
||||
export CUDA_VISIBLE_DEVICES=0,1,2,3
|
||||
export KVCACHE_RDMA_NICS="mlx5_2,mlx5_3,mlx5_4,mlx5_5"
|
||||
echo "set RDMA NICS"
|
||||
export $(bash scripts/get_rdma_nics.sh gpu)
|
||||
echo "KVCACHE_RDMA_NICS ${KVCACHE_RDMA_NICS}"
|
||||
python -m fastdeploy.entrypoints.openai.api_server \
|
||||
--model ERNIE-4.5-300B-A47B-BF16 \
|
||||
--port 8180 --metrics-port 8181 \
|
||||
@@ -127,6 +135,7 @@ python -m fastdeploy.entrypoints.openai.api_server \
|
||||
--scheduler-name "splitwise" \
|
||||
--scheduler-host "127.0.0.1" \
|
||||
--scheduler-port 6379 \
|
||||
--scheduler-topic "test" \
|
||||
--scheduler-ttl 9000
|
||||
```
|
||||
|
||||
@@ -135,7 +144,9 @@ python -m fastdeploy.entrypoints.openai.api_server \
|
||||
```bash
|
||||
export FD_LOG_DIR="log_decode"
|
||||
export CUDA_VISIBLE_DEVICES=4,5,6,7
|
||||
export KVCACHE_RDMA_NICS="mlx5_2,mlx5_3,mlx5_4,mlx5_5"
|
||||
echo "set RDMA NICS"
|
||||
export $(bash scripts/get_rdma_nics.sh gpu)
|
||||
echo "KVCACHE_RDMA_NICS ${KVCACHE_RDMA_NICS}"
|
||||
python -m fastdeploy.entrypoints.openai.api_server \
|
||||
--model ERNIE-4.5-300B-A47B-BF16 \
|
||||
--port 8184 --metrics-port 8185 \
|
||||
@@ -150,6 +161,7 @@ python -m fastdeploy.entrypoints.openai.api_server \
|
||||
--scheduler-host "127.0.0.1" \
|
||||
--scheduler-port 6379 \
|
||||
    --scheduler-ttl 9000 \
|
||||
--scheduler-topic "test" \
|
||||
--splitwise-role "decode"
|
||||
```
|
||||
|
||||
@@ -168,5 +180,6 @@ python -m fastdeploy.entrypoints.openai.api_server \
|
||||
* --scheduler-host: 连接的redis地址
|
||||
* --scheduler-port: 连接的redis端口
|
||||
* --scheduler-ttl: 指定redis的ttl时间,单位为秒
|
||||
* --scheduler-topic: 指定redis的topic
|
||||
* --pd-comm-port: 指定pd通信的端口
|
||||
* --rdma-comm-ports: 指定RDMA通信的端口,多个端口用逗号隔开,数量与卡数一致
|
||||
|
BIN
docs/zh/features/images/no_scheduler_img.png
Normal file
After Width: | Height: | Size: 118 KiB |
BIN
docs/zh/features/images/plas_inference_union.png
Normal file
After Width: | Height: | Size: 34 KiB |
BIN
docs/zh/features/images/plas_training_distill.png
Normal file
After Width: | Height: | Size: 80 KiB |
BIN
docs/zh/features/images/scheduler_img.png
Normal file
After Width: | Height: | Size: 105 KiB |
223
docs/zh/features/plas_attention.md
Normal file
@@ -0,0 +1,223 @@
|
||||
# PLAS
|
||||
|
||||
## 介绍
|
||||
|
||||
我们提出了**PLAS(Pluggable Lightweight Attention for Sparsity)**,这是对 MoBA 的改进。具体来说,我们采用了受 MoE 启发的结构,将 KV 划分为多个块,并引入了一个可学习的 MLP 层来自适应地选择重要块。PLAS 可以直接在训练后应用,此时只有 MLP 权重可学习,而原始模型权重保持不变。
|
||||
|
||||
与 NSA/MoBA 相比,我们的 PLAS 具有更高的可扩展性和可插拔性。它无需修改传统的注意力架构,也无需在训练前或训练后干扰模型权重训练。最终阶段只需对 MLP 层进行少量训练即可实现几乎无损的准确率。由于 NSA/MoBA 会更新整个模型权重,因此不可避免地会影响短文本的性能——即使它在输入长度小于 BlockSize × Top-K 时会自动切换到完全注意力机制。相比之下,我们的 PLAS 在短文本场景下可以实现与原始模型真正等同的完全注意力机制。
|
||||
|
||||
在训练效率方面,由于仅需更新 MLP 权重,训练成本极低。在推理性能方面,当输入长度为 128K、Block Size = 128、Top-K = 55 时,PLAS 相比 Flash Attention 3 实现了**386% 的加速**。
|
||||
|
||||
## 方法
|
||||
|
||||
### 训练
|
||||
|
||||
借鉴 NSA 和 MoBA 的方法,我们将键值对 (KV) 划分为多个块。在预填充和解码阶段,我们不再对所有键值进行注意力计算,而是动态地为每个查询 token 选择注意力得分最高的前 K 个块,从而实现高效的稀疏注意力计算。
|
||||
|
||||
<div align="center">
|
||||
<img src="images/plas_training_distill.png" alt="Attention Gate Module" width="60%">
|
||||
</div>
|
||||
|
||||
* **Attention Gate Module**: 如上图所示,为了以较低的计算开销估计每个块的重要性,我们设计了一个轻量级的注意力门模块。该模块首先通过一个 MLP 层压缩每个 K 块,生成一个具有代表性的低维表示: $K_c^T=W_{kp}K^T$ ,其中 $W_{kp}$ 表示 MLP 层的权重。与直接应用均值池化相比,可学习的 MLP 可以更有效地捕捉不同 token 之间的语义关系和重要性分布,从而提供每个块的精细表示。在获得压缩表示 $K_c$ 之后,通过以下公式估计每个查询 token 相对于每个块的重要性:$Softmax(Q\cdot K_c^T)$。为了增强 MLP 层的判别能力,我们使用一维最大池化后的完整注意力结果 $1DMaxPooling(Softmax(Q \cdot K^T))$ 作为 ground truth,通过最小化两者之间的分布差异,引导 MLP 层学习更符合真实注意力分布的特征表示。(块重要性打分与块选择的简化代码示意见本节列表之后。)
|
||||
|
||||
* **Training Data**: 得益于模型架构和训练范式的高效性,我们的方法仅使用 10 亿个 token 进行训练,便实现了近乎无损的精度。训练数据源自内部构建的包含长文本和短文本的混合语料库,从而增强了模块对不同序列长度的适应性。
|
||||
|
||||
* **Other**: 我们观察到,最终的解码层对模型整体准确率有显著影响。因此,在训练过程中,我们将该层排除在稀疏注意力计算之外,并在推理过程中将其恢复为完全注意力。
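下面用 numpy 给出上述块重要性打分与 top-k 块选择的一个简化示意(单个 query、单 head;`W_kp` 的形状与取值均为示意性假设,实际实现为融合的 GPU 算子):

```python
import numpy as np

def select_topk_blocks(q, K, W_kp, block_size=128, top_k=4):
    """按可学习 MLP 压缩出的块表示,为单个 query 估计各 Key 块的重要性并取 top-k。

    q:    (d,)                当前 query 向量
    K:    (seq_len, d)        Key 矩阵
    W_kp: (block_size * d, d) 示意性的 MLP 权重,把一个块内的全部 Key 压缩为一个 d 维代表
    """
    n_blocks = K.shape[0] // block_size
    blocks = K[: n_blocks * block_size].reshape(n_blocks, block_size * K.shape[1])
    K_c = blocks @ W_kp                 # (n_blocks, d),对应 K_c^T = W_kp K^T
    scores = K_c @ q                    # 近似 Q · K_c^T
    scores = np.exp(scores - scores.max())
    scores /= scores.sum()              # softmax 得到各块的重要性
    return np.argsort(scores)[-top_k:]  # 返回重要性最高的 top-k 块下标

rng = np.random.default_rng(0)
d, seq_len = 64, 2048
print(select_topk_blocks(rng.normal(size=d), rng.normal(size=(seq_len, d)),
                         rng.normal(size=(128 * d, d))))
```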
|
||||
|
||||
### 推理优化
|
||||
|
||||
在稀疏注意力计算过程中,每个查询 token 可能会动态选择不同的 KV 块,导致 HBM 的内存访问模式非常不规则。简单地对每个查询 token 进行单独处理是可行的,但这会导致计算粒度过细,无法充分利用张量核,从而显著降低 GPU 的计算效率。
|
||||
|
||||
<div align="center">
|
||||
<img src="images/plas_inference_union.png" alt="Token/Head Union" width="60%">
|
||||
</div>
|
||||
|
||||
为了优化预填充和解码阶段的性能,我们设计了一种特殊的联合策略来适应各自的特点:
|
||||
|
||||
* **Prefill Token Union**: 我们观察到相邻的查询 token 倾向于选择相似的 Key 块。利用这种局部性,我们取连续 128 个查询 token 所选 Key 块的并集,并为这些 token 联合计算稀疏注意力。
|
||||
|
||||
* **Decode Head Union**: 鉴于 GQA 在现代模型中的广泛应用,我们发现同一组内的不同查询头经常选择重叠的 Key 块。因此,我们将同一组内所有查询头选择的 Key 块合并为一个统一的集合,并联合计算稀疏注意力。这种方式减少了内存访问开销,进一步提高了解码效率。
|
||||
|
||||
* **Top-K Selection**: 传统的 Top-k 算法基于排序或直接调用 Cub 库,会带来显著的运行时开销。为了缓解这个问题,我们实现了一个基于二分查找的近似 Top-k 选择算法,在保持准确率的同时显著降低了延迟,带来了明显的性能提升。(下方给出阈值二分思路的简化示意。)
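下面用 numpy 示意“阈值二分”的近似 Top-k 思路(通过统计高于阈值的元素个数来逼近 k,避免全量排序;仅为原理示意,并非实际 CUDA 实现):

```python
import numpy as np

def approx_topk(scores: np.ndarray, k: int, iters: int = 20) -> np.ndarray:
    """二分查找一个阈值,使高于阈值的元素个数约等于 k,从而避免全量排序。"""
    lo, hi = float(scores.min()), float(scores.max())
    for _ in range(iters):
        mid = (lo + hi) / 2
        if int((scores >= mid).sum()) > k:  # 选中过多:阈值偏低,抬高下界
            lo = mid
        else:                               # 选中过少:阈值偏高,压低上界
            hi = mid
    return np.nonzero(scores >= lo)[0]      # 近似 top-k 的下标(个数约等于 k)

scores = np.random.default_rng(0).normal(size=1000)
print(len(approx_topk(scores, k=55)))
```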
|
||||
|
||||
## 评估
|
||||
|
||||
### 实验
|
||||
|
||||
我们在 LongBenchV2 和 Ruler(上下文长度分别为 32K、64K 和 128K)上评估了全注意力和稀疏注意力的精度。
|
||||
|
||||
<table style="border-collapse: collapse; width: 100%;">
|
||||
<tr>
|
||||
<td rowspan="4" style="border: 1px solid #dcdde0; padding: 8px; text-align: center; vertical-align: middle;">
|
||||
<strong>Model</strong>
|
||||
</td>
|
||||
<td colspan="8" style="border: 1px solid #dcdde0; padding: 8px; text-align: center; vertical-align: middle;">
|
||||
<strong>Precision</strong>
|
||||
</td>
|
||||
</tr>
|
||||
<tr>
|
||||
<td colspan="4" style="border: 1px solid #dcdde0; padding: 8px; text-align: center; vertical-align: middle;">
|
||||
<strong>FullAttention</strong>
|
||||
</td>
|
||||
<td colspan="4" style="border: 1px solid #dcdde0; padding: 8px; text-align: center; vertical-align: middle;">
|
||||
<strong>SparseAttention</strong>
|
||||
</td>
|
||||
</tr>
|
||||
<tr>
|
||||
<td rowspan="2" style="border: 1px solid #dcdde0; padding: 8px; text-align: center; vertical-align: middle;">
|
||||
<strong>LongBenchV2</strong>
|
||||
</td>
|
||||
<td colspan="3" style="border: 1px solid #dcdde0; padding: 8px; text-align: center; vertical-align: middle;">
|
||||
<strong>Ruler</strong>
|
||||
</td>
|
||||
<td rowspan="2" style="border: 1px solid #dcdde0; padding: 8px; text-align: center; vertical-align: middle;">
|
||||
<strong>LongBenchV2</strong>
|
||||
</td>
|
||||
<td colspan="3" style="border: 1px solid #dcdde0; padding: 8px; text-align: center; vertical-align: middle;">
|
||||
<strong>Ruler</strong>
|
||||
</td>
|
||||
</tr>
|
||||
<tr>
|
||||
<td style="border: 1px solid #dcdde0; padding: 8px; text-align: center; vertical-align: middle;">
|
||||
<strong>32K</strong>
|
||||
</td>
|
||||
<td style="border: 1px solid #dcdde0; padding: 8px; text-align: center; vertical-align: middle;">
|
||||
<strong>64K</strong>
|
||||
</td>
|
||||
<td style="border: 1px solid #dcdde0; padding: 8px; text-align: center; vertical-align: middle;">
|
||||
<strong>128K</strong>
|
||||
</td>
|
||||
<td style="border: 1px solid #dcdde0; padding: 8px; text-align: center; vertical-align: middle;">
|
||||
<strong>32K</strong>
|
||||
</td>
|
||||
<td style="border: 1px solid #dcdde0; padding: 8px; text-align: center; vertical-align: middle;">
|
||||
<strong>64K</strong>
|
||||
</td>
|
||||
<td style="border: 1px solid #dcdde0; padding: 8px; text-align: center; vertical-align: middle;">
|
||||
<strong>128K</strong>
|
||||
</td>
|
||||
</tr>
|
||||
<tr>
|
||||
<td style="border: 1px solid #dcdde0; padding: 8px; text-align: center; vertical-align: middle;">
|
||||
<strong>ERNIE-4.5-21B-A3B</strong>
|
||||
</td>
|
||||
<td style="border: 1px solid #dcdde0; padding: 8px; text-align: center; vertical-align: middle;">31.48</td>
|
||||
<td style="border: 1px solid #dcdde0; padding: 8px; text-align: center; vertical-align: middle;">76.74</td>
|
||||
<td style="border: 1px solid #dcdde0; padding: 8px; text-align: center; vertical-align: middle;">56.40</td>
|
||||
<td style="border: 1px solid #dcdde0; padding: 8px; text-align: center; vertical-align: middle;">25.48</td>
|
||||
<td style="border: 1px solid #dcdde0; padding: 8px; text-align: center; vertical-align: middle;">31.45</td>
|
||||
<td style="border: 1px solid #dcdde0; padding: 8px; text-align: center; vertical-align: middle;">75.93</td>
|
||||
<td style="border: 1px solid #dcdde0; padding: 8px; text-align: center; vertical-align: middle;">55.38</td>
|
||||
<td style="border: 1px solid #dcdde0; padding: 8px; text-align: center; vertical-align: middle;">25.05</td>
|
||||
</tr>
|
||||
<tr>
|
||||
<td style="border: 1px solid #dcdde0; padding: 8px; text-align: center; vertical-align: middle;">
|
||||
<strong>ERNIE-4.5-300B-A47B</strong>
|
||||
</td>
|
||||
<td style="border: 1px solid #dcdde0; padding: 8px; text-align: center; vertical-align: middle;">41.02</td>
|
||||
<td style="border: 1px solid #dcdde0; padding: 8px; text-align: center; vertical-align: middle;">94.70</td>
|
||||
<td style="border: 1px solid #dcdde0; padding: 8px; text-align: center; vertical-align: middle;">83.56</td>
|
||||
<td style="border: 1px solid #dcdde0; padding: 8px; text-align: center; vertical-align: middle;">58.18</td>
|
||||
<td style="border: 1px solid #dcdde0; padding: 8px; text-align: center; vertical-align: middle;">41.05</td>
|
||||
<td style="border: 1px solid #dcdde0; padding: 8px; text-align: center; vertical-align: middle;">94.50</td>
|
||||
<td style="border: 1px solid #dcdde0; padding: 8px; text-align: center; vertical-align: middle;">82.32</td>
|
||||
<td style="border: 1px solid #dcdde0; padding: 8px; text-align: center; vertical-align: middle;">57.85</td>
|
||||
</tr>
|
||||
</table>
|
||||
|
||||
### 性能
|
||||
|
||||
我们从 InfiniteBench 中选择了一个子集 (longbook_sum_eng) 作为性能评估数据集。对于长度超过 128K 的输入,我们截断序列,保留前 64K 和后 64K 个 token。
|
||||
|
||||
<table style="border-collapse: collapse; width: 100%;">
|
||||
<tr>
|
||||
<td style="border: 1px solid #dcdde0; padding: 8px; text-align: center; vertical-align: middle;"></td>
|
||||
<td style="border: 1px solid #dcdde0; padding: 8px; text-align: center; vertical-align: middle;"></td>
|
||||
<td style="border: 1px solid #dcdde0; padding: 8px; text-align: center; vertical-align: middle;"><strong>QPS</strong></td>
|
||||
<td style="border: 1px solid #dcdde0; padding: 8px; text-align: center; vertical-align: middle;"><strong>Decode Speed (token/s)</strong></td>
|
||||
<td style="border: 1px solid #dcdde0; padding: 8px; text-align: center; vertical-align: middle;"><strong>Time to First Token (s)</strong></td>
|
||||
<td style="border: 1px solid #dcdde0; padding: 8px; text-align: center; vertical-align: middle;"><strong>Time per Output Token (ms)</strong></td>
|
||||
<td style="border: 1px solid #dcdde0; padding: 8px; text-align: center; vertical-align: middle;"><strong>End-to-End Latency(s)</strong></td>
|
||||
<td style="border: 1px solid #dcdde0; padding: 8px; text-align: center; vertical-align: middle;"><strong>Mean Input<br>Length</strong></td>
|
||||
<td style="border: 1px solid #dcdde0; padding: 8px; text-align: center; vertical-align: middle;"><strong>Mean Output Length</strong></td>
|
||||
</tr>
|
||||
<tr>
|
||||
<td rowspan="2" style="border: 1px solid #dcdde0; padding: 8px; text-align: center; vertical-align: middle;">
|
||||
<strong>ERNIE-4.5-21B-A3B</strong>
|
||||
</td>
|
||||
<td style="border: 1px solid #dcdde0; padding: 8px; text-align: center; vertical-align: middle;">
|
||||
<strong>FullAttention</strong>
|
||||
</td>
|
||||
<td style="border: 1px solid #dcdde0; padding: 8px; text-align: center; vertical-align: middle;">0.101</td>
|
||||
<td style="border: 1px solid #dcdde0; padding: 8px; text-align: center; vertical-align: middle;">13.32</td>
|
||||
<td style="border: 1px solid #dcdde0; padding: 8px; text-align: center; vertical-align: middle;">8.082</td>
|
||||
<td style="border: 1px solid #dcdde0; padding: 8px; text-align: center; vertical-align: middle;">87.05</td>
|
||||
<td style="border: 1px solid #dcdde0; padding: 8px; text-align: center; vertical-align: middle;">61.400</td>
|
||||
<td style="border: 1px solid #dcdde0; padding: 8px; text-align: center; vertical-align: middle;">113182.32</td>
|
||||
<td style="border: 1px solid #dcdde0; padding: 8px; text-align: center; vertical-align: middle;">627.76</td>
|
||||
</tr>
|
||||
<tr>
|
||||
<td style="border: 1px solid #dcdde0; padding: 8px; text-align: center; vertical-align: middle;">
|
||||
<strong>SparseAttention</strong>
|
||||
</td>
|
||||
<td style="border: 1px solid #dcdde0; padding: 8px; text-align: center; vertical-align: middle;">0.150(+48%)</td>
|
||||
<td style="border: 1px solid #dcdde0; padding: 8px; text-align: center; vertical-align: middle;">18.12(+36%)</td>
|
||||
<td style="border: 1px solid #dcdde0; padding: 8px; text-align: center; vertical-align: middle;">5.466(-48%)</td>
|
||||
<td style="border: 1px solid #dcdde0; padding: 8px; text-align: center; vertical-align: middle;">66.35(-31%)</td>
|
||||
<td style="border: 1px solid #dcdde0; padding: 8px; text-align: center; vertical-align: middle;">42.157(-46%)</td>
|
||||
<td style="border: 1px solid #dcdde0; padding: 8px; text-align: center; vertical-align: middle;">113182.32</td>
|
||||
<td style="border: 1px solid #dcdde0; padding: 8px; text-align: center; vertical-align: middle;">590.23</td>
|
||||
</tr>
|
||||
<tr>
|
||||
<td rowspan="2" style="border: 1px solid #dcdde0; padding: 8px; text-align: center; vertical-align: middle;">
|
||||
<strong>ERNIE-4.5-300B-A47B</strong>
|
||||
</td>
|
||||
<td style="border: 1px solid #dcdde0; padding: 8px; text-align: center; vertical-align: middle;">
|
||||
<strong>FullAttention</strong>
|
||||
</td>
|
||||
<td style="border: 1px solid #dcdde0; padding: 8px; text-align: center; vertical-align: middle;">0.066</td>
|
||||
<td style="border: 1px solid #dcdde0; padding: 8px; text-align: center; vertical-align: middle;">5.07</td>
|
||||
<td style="border: 1px solid #dcdde0; padding: 8px; text-align: center; vertical-align: middle;">13.812</td>
|
||||
<td style="border: 1px solid #dcdde0; padding: 8px; text-align: center; vertical-align: middle;">206.70</td>
|
||||
<td style="border: 1px solid #dcdde0; padding: 8px; text-align: center; vertical-align: middle;">164.704</td>
|
||||
<td style="border: 1px solid #dcdde0; padding: 8px; text-align: center; vertical-align: middle;">113182.32</td>
|
||||
<td style="border: 1px solid #dcdde0; padding: 8px; text-align: center; vertical-align: middle;">725.97</td>
|
||||
</tr>
|
||||
<tr>
|
||||
<td style="border: 1px solid #dcdde0; padding: 8px; text-align: center; vertical-align: middle;">
|
||||
<strong>SparseAttention</strong>
|
||||
</td>
|
||||
<td style="border: 1px solid #dcdde0; padding: 8px; text-align: center; vertical-align: middle;">0.081(+23%)</td>
|
||||
<td style="border: 1px solid #dcdde0; padding: 8px; text-align: center; vertical-align: middle;">6.75(+33%)</td>
|
||||
<td style="border: 1px solid #dcdde0; padding: 8px; text-align: center; vertical-align: middle;">10.584(-30%)</td>
|
||||
<td style="border: 1px solid #dcdde0; padding: 8px; text-align: center; vertical-align: middle;">154.84(-34%)</td>
|
||||
<td style="border: 1px solid #dcdde0; padding: 8px; text-align: center; vertical-align: middle;">132.745(-24%)</td>
|
||||
<td style="border: 1px solid #dcdde0; padding: 8px; text-align: center; vertical-align: middle;">113182.32</td>
|
||||
<td style="border: 1px solid #dcdde0; padding: 8px; text-align: center; vertical-align: middle;">748.25</td>
|
||||
</tr>
|
||||
</table>
|
||||
|
||||
## 使用方式
|
||||
|
||||
```
|
||||
export FD_ATTENTION_BACKEND="PLAS_ATTN"
|
||||
|
||||
python -m fastdeploy.entrypoints.openai.api_server \
|
||||
--model baidu/ERNIE-4.5-300B-A47B-Paddle \
|
||||
--port 8188 \
|
||||
--tensor-parallel-size 4 \
|
||||
--quantization wint4 \
|
||||
--enable-chunked-prefill \
|
||||
--max-num-batched-tokens 8192 \
|
||||
--max-model-len 131072 \
|
||||
--max-num-seqs 32 \
|
||||
--plas-attention-config '{"plas_encoder_top_k_left": 50, "plas_encoder_top_k_right": 60, "plas_decoder_top_k_left": 100, "plas_decoder_top_k_right": 120}'
|
||||
```
|
||||
|
||||
**Note**: 如果启用了稀疏注意力机制,系统将自动从权重目录中的`plas_attention_mlp_weight.safetensors`文件加载 MLP 权重。如果未找到 MLP 权重文件,则将对 Key 表示应用均值池化。
|
||||
|
||||
**Parameter Description:**
|
||||
|
||||
* `FD_ATTENTION_BACKEND="PLAS_ATTN"`:启用 PLAS 稀疏注意力(sparse attention)。
|
||||
* `plas_encoder_top_k_left=50, plas_encoder_top_k_right=60`:表示 encoder 阶段 top-k 的取值范围在 50 到 60 之间。
|
||||
* `plas_decoder_top_k_left=100, plas_decoder_top_k_right=120`:表示 decoder 阶段 top-k 的取值范围在 100 到 120 之间。
|
@@ -1,6 +1,6 @@
|
||||
# 采样策略
|
||||
|
||||
采样策略用于决定如何从模型的输出概率分布中选择下一个token。FastDeploy目前支持 Top-p 、 Top-k_Top-p 和 Min-p Samping 多种采样策略。
|
||||
采样策略用于决定如何从模型的输出概率分布中选择下一个token。FastDeploy目前支持 Top-p 、 Top-k_Top-p 和 Min-p Sampling 多种采样策略。
|
||||
|
||||
1. Top-p 采样
|
||||
|
||||
|