Mirror of https://github.com/PaddlePaddle/FastDeploy.git
[GCU] Support gcu platform (#2702)
baseline: e7fa57ebae
Co-authored-by: yongqiangma <xing.wo@163.com>
This commit changes build.sh (8 lines) and, as shown below, the custom-ops setup script, the Enflame S60 GCU documentation (English and Chinese), and new GCU backend modules under fastdeploy/model_executor/layers/backends/gcu/.
@@ -113,6 +113,14 @@ function copy_ops(){
        return
    fi

+   is_gcu=`$python -c "import paddle; print(paddle.is_compiled_with_custom_device('gcu'))"`
+   if [ "$is_gcu" = "True" ]; then
+       DEVICE_TYPE="gcu"
+       cp -r ${OPS_TMP_DIR}/${WHEEL_NAME}/* ../fastdeploy/model_executor/ops/gcu
+       echo -e "gcu ops have been copied to fastdeploy"
+       return
+   fi
+
    DEVICE_TYPE="cpu"
    cp -r ./${OPS_TMP_DIR_BASE}/${WHEEL_BASE_NAME}/* ../fastdeploy/model_executor/ops/base
    cd ../../../../
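The new branch keys off the probe PaddlePaddle exposes for custom devices; a quick sanity check you can run before invoking build.sh (a minimal sketch, assuming paddle and the paddle-custom-gcu plugin from the docs below are installed):

```python
# Minimal sketch: mirrors the probe build.sh runs when choosing DEVICE_TYPE.
import paddle

if paddle.is_compiled_with_custom_device("gcu"):
    print("GCU custom device detected; build.sh copies ops into fastdeploy/model_executor/ops/gcu")
else:
    print("No GCU custom device; build.sh falls back to the cpu path")
```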
@@ -501,6 +501,17 @@ elif paddle.is_compiled_with_custom_device("iluvatar_gpu"):
        ],
    ),
)
+elif paddle.is_compiled_with_custom_device("gcu"):
+    setup(
+        name="fastdeploy_ops",
+        ext_modules=CppExtension(
+            sources=[
+                "gpu_ops/save_with_output_msg.cc",
+                "gpu_ops/get_output.cc",
+                "gpu_ops/get_output_msg_with_topk.cc",
+            ]
+        ),
+    )
else:
    use_bf16 = envs.FD_CPU_USE_BF16 == "True"
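Once this extension is built and build.sh has copied the wheel contents, the GCU backends added later in this commit import the custom operators from the new ops package, for example:

```python
# Import path used by the GCU attention backends introduced in this commit;
# it only resolves after building on a machine with the GCU toolchain installed.
from fastdeploy.model_executor.ops.gcu import (fused_rotary_embedding,
                                               mem_efficient_attention,
                                               flash_attn_var_len)
```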
@@ -1,8 +1,8 @@
- # Running ERNIE-4.5-21B-A3B with FastDeploy
+ # Running ERNIE 4.5 Series Models with FastDeploy

The Enflame S60 ([Learn about Enflame](https://www.enflame-tech.com/)) is a next-generation AI inference accelerator card designed for large-scale deployment in data centers. It meets the demands of large language models (LLMs), search/advertising/recommendation systems, and traditional models. Characterized by broad model coverage, user-friendliness, and high portability, it is widely applicable to mainstream inference scenarios such as image and text generation applications, search and recommendation systems, and text/image/speech recognition.

- FastDeploy has deeply adapted and optimized the ernie-4_5-21b-a3b-bf16-paddle model for the Enflame S60, achieving a unified inference interface between GCU and GPU. This allows seamless migration of inference tasks without code modifications.
+ FastDeploy has deeply adapted and optimized the ERNIE 4.5 Series Models for the Enflame S60, achieving a unified inference interface between GCU and GPU. This allows seamless migration of inference tasks without code modifications.

## 🚀 Quick Start 🚀
@@ -27,15 +27,15 @@ lspci | grep S60
3b:00.0 Processing accelerators: Shanghai Enflame Technology Co. Ltd S60 [Enflame] (rev 01)
3c:00.0 Processing accelerators: Shanghai Enflame Technology Co. Ltd S60 [Enflame] (rev 01)
```
- ### 1. Environment Setup (Estimated time: 5–10 minutes)
+ ### 1. Environment Setup (Estimated time: 5-10 minutes)
1. Pull the Docker image
```bash
# Note: This image only contains the Paddle development environment, not precompiled PaddlePaddle packages
- docker pull ccr-2vdh3abv-pub.cnc.bj.baidubce.com/device/paddle-gcu:topsrider3.4.623-ubuntu20-x86_64-gcc84
+ docker pull ccr-2vdh3abv-pub.cnc.bj.baidubce.com/device/paddle-gcu:topsrider3.5.102-ubuntu20-x86_64-gcc84
```
2. Start the container
```bash
- docker run --name paddle-gcu-llm -v /home:/home -v /work:/work --network=host --ipc=host -it --privileged ccr-2vdh3abv-pub.cnc.bj.baidubce.com/device/paddle-gcu:topsrider3.4.623-ubuntu20-x86_64-gcc84 /bin/bash
+ docker run --name paddle-gcu-llm -v /home:/home -v /work:/work --network=host --ipc=host -it --privileged ccr-2vdh3abv-pub.cnc.bj.baidubce.com/device/paddle-gcu:topsrider3.5.102-ubuntu20-x86_64-gcc84 /bin/bash
```
3. Obtain and install drivers<br/>
**Full software packages are preloaded in the Docker container. Copy them to an external directory, e.g., ```/home/workspace/deps/```**
@@ -67,25 +67,31 @@ python -m pip install paddle-custom-gcu==3.1.0 -i https://www.paddlepaddle.org.c
7. Install FastDeploy and dependencies
```bash
python -m pip install fastdeploy -i https://www.paddlepaddle.org.cn/packages/stable/gcu/ --extra-index-url https://mirrors.tuna.tsinghua.edu.cn/pypi/web/simplels
apt install python3.10-distutils
# For source compilation, refer to the following steps
git clone https://github.com/PaddlePaddle/FastDeploy
cd FastDeploy
python -m pip install -r requirements.txt --extra-index-url https://mirrors.tuna.tsinghua.edu.cn/pypi/web/simplels
bash build.sh 1
```
- ### 2. Data Preparation (Estimated time: 2–5 minutes)
+ ### 2. Data Preparation (Estimated time: 2-5 minutes)
Use a trained model for inference on GSM8K dataset:
```bash
mkdir -p /home/workspace/benchmark/ && cd /home/workspace/benchmark/
wget https://raw.githubusercontent.com/openai/grade-school-math/master/grade_school_math/data/test.jsonl
```
- Place model weights in a directory, e.g., ```/work/models/ernie-4_5-21b-a3b-bf16-paddle/```
- ### 3. Inference (Estimated time: 2–5 minutes)
+ Place model weights in a directory, e.g., ```/work/models/ERNIE-4.5-300B-A47B-Paddle/```
+ ### 3. Inference (Estimated time: 2-5 minutes)
Start the inference service:
```bash
python -m fastdeploy.entrypoints.openai.api_server \
-     --model "/work/models/ernie-4_5-21b-a3b-bf16-paddle/" \
+     --model "/work/models/ERNIE-4.5-300B-A47B-Paddle/" \
    --port 8188 \
    --metrics-port 8200 \
-     --tensor-parallel-size 4 \
-     --max-model-len 8192 \
-     --num-gpu-blocks-override 1024
+     --tensor-parallel-size 8 \
+     --max-model-len 32768 \
+     --num-gpu-blocks-override 4096 \
+     --max-num-batched-tokens 32768 \
+     --quantization "wint4"
```
Query the model service:
```bash
@@ -93,13 +99,13 @@ curl -X POST "http://0.0.0.0:8188/v1/chat/completions" \
-H "Content-Type: application/json" \
-d '{
  "messages": [
-     {"role": "user", "content": "The largest ocean is"}
+     {"role": "user", "content": "Where is Beijing?"}
  ]
}'
```
Successful execution returns inference results, e.g.:
```json
- {"id":"chatcmpl-5cd96f3b-eff3-4dc0-8aa2-8b5d7b7b86f2","object":"chat.completion","created":1751167862,"model":"default","choices":[{"index":0,"message":{"role":"assistant","content":"3. **Pacific Ocean**: The Pacific Ocean is the largest and deepest of the world's oceans. It covers an area of approximately 181,344,000 square kilometers, which is more than 30% of the Earth's surface. It is located between the Americas to the west and east, and Asia and Australia to the north and south. The Pacific Ocean is known for its vastness, diverse marine life, and numerous islands.\n\nIn summary, the largest ocean in the world is the Pacific Ocean.","reasoning_content":null,"tool_calls":null},"finish_reason":"stop"}],"usage":{"prompt_tokens":11,"total_tokens":127,"completion_tokens":116,"prompt_tokens_details":{"cached_tokens":0}}}
+ {"id":"chatcmpl-20f1210d-6943-4110-ad2d-c76ba11604ad","object":"chat.completion","created":1751621261,"model":"default","choices":[{"index":0,"message":{"role":"assistant","content":"Beijing is the capital city of the People's Republic of China, located in the northern part of the country. It is situated in the North China Plain, bordered by the mountains to the west, north, and northeast. Beijing serves as China's political, cultural, and international exchange center, playing a crucial role in the nation's development and global interactions.","reasoning_content":null,"tool_calls":null},"finish_reason":"stop"}],"usage":{"prompt_tokens":11,"total_tokens":88,"completion_tokens":77,"prompt_tokens_details":{"cached_tokens":0}}}
```
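Because the service speaks the OpenAI-compatible chat completions protocol, it can also be queried from Python; a minimal sketch, assuming the `requests` package is installed and the server from the previous step is running on port 8188:

```python
# Minimal Python client sketch for the same endpoint queried with curl above.
import requests

resp = requests.post(
    "http://0.0.0.0:8188/v1/chat/completions",
    headers={"Content-Type": "application/json"},
    json={"messages": [{"role": "user", "content": "Where is Beijing?"}]},
    timeout=600,
)
print(resp.json()["choices"][0]["message"]["content"])
```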
### 4. Accuracy Testing (Estimated time: 60–180 minutes)
Place the accuracy script ```bench_gsm8k.py``` in ```/home/workspace/benchmark/``` and modify sampling parameters, e.g.:
@@ -120,10 +126,10 @@ data = {
Run accuracy tests:
```bash
cd /home/workspace/benchmark/
- python -u bench_gsm8k.py --port 8188 --num-questions 1319 --num-shots 5 --parallel 2
+ python -u bench_gsm8k.py --port 8188 --num-questions 1319 --num-shots 5 --parallel 8
```
Upon completion, accuracy results are saved in ```result.jsonl```, e.g.:
```json
- {"task": "gsm8k", "backend": "paddlepaddle", "num_gpus": 1, "latency": 365.548, "accuracy": 0.967, "num_requests": 30, "other": {"num_questions": 30, "parallel": 2}}
+ {"task": "gsm8k", "backend": "paddlepaddle", "num_gpus": 1, "latency": 13446.01, "accuracy": 0.956, "num_requests": 1319, "other": {"num_questions": 1319, "parallel": 8}}
```
@@ -1,8 +1,8 @@
- # Running the ERNIE-4.5-21B-A3B model on the Enflame S60 with FastDeploy
+ # Running ERNIE 4.5 series models on the Enflame S60 with FastDeploy

The Enflame S60 ([learn about Enflame](https://www.enflame-tech.com/)) is a next-generation AI inference accelerator card for large-scale data center deployment. It meets the needs of large language models, search/advertising/recommendation systems, and traditional models, and features broad model coverage, ease of use, and easy migration and deployment. It is widely applicable to mainstream inference scenarios such as image and text generation, search and recommendation, and text, image, and speech recognition.

- FastDeploy has deeply adapted and optimized the ernie-4_5-21b-a3b-bf16-paddle model for the Enflame S60, unifying the GCU inference entry point with the GPU one, so inference tasks can be migrated without code changes.
+ FastDeploy has deeply adapted and optimized the ERNIE 4.5 series models for the Enflame S60, unifying the GCU inference entry point with the GPU one, so inference tasks can be migrated without code changes.

## 🚀 Quick Start 🚀

@@ -30,11 +30,11 @@ lspci | grep S60
1. Pull the image
```bash
# Note: this image only provides the Paddle development environment; it does not include a precompiled PaddlePaddle package
- docker pull ccr-2vdh3abv-pub.cnc.bj.baidubce.com/device/paddle-gcu:topsrider3.4.623-ubuntu20-x86_64-gcc84
+ docker pull ccr-2vdh3abv-pub.cnc.bj.baidubce.com/device/paddle-gcu:topsrider3.5.102-ubuntu20-x86_64-gcc84
```
2. Start the container with a command like the following
```bash
- docker run --name paddle-gcu-llm -v /home:/home -v /work:/work --network=host --ipc=host -it --privileged ccr-2vdh3abv-pub.cnc.bj.baidubce.com/device/paddle-gcu:topsrider3.4.623-ubuntu20-x86_64-gcc84 /bin/bash
+ docker run --name paddle-gcu-llm -v /home:/home -v /work:/work --network=host --ipc=host -it --privileged ccr-2vdh3abv-pub.cnc.bj.baidubce.com/device/paddle-gcu:topsrider3.5.102-ubuntu20-x86_64-gcc84 /bin/bash
```
3. Obtain and install the drivers<br/>
**The full software packages are preloaded inside the Docker container and need to be copied to a directory outside Docker, e.g. ```/home/workspace/deps/```**
@@ -63,10 +63,14 @@ python -m pip install paddlepaddle==3.1.0a0 -i https://www.paddlepaddle.org.cn/p
python -m pip install paddle-custom-gcu==3.1.0 -i https://www.paddlepaddle.org.cn/packages/stable/gcu/
# To build and install from source, see https://github.com/PaddlePaddle/PaddleCustomDevice/blob/develop/backends/gcu/README_cn.md
```
- 7. Install FastDeploy and dependencies<br/>
+ 7. Install FastDeploy <br/>
```bash
python -m pip install fastdeploy -i https://www.paddlepaddle.org.cn/packages/stable/gcu/ --extra-index-url https://mirrors.tuna.tsinghua.edu.cn/pypi/web/simplels
apt install python3.10-distutils
# To build and install from source, follow the steps below
git clone https://github.com/PaddlePaddle/FastDeploy
cd FastDeploy
python -m pip install -r requirements.txt --extra-index-url https://mirrors.tuna.tsinghua.edu.cn/pypi/web/simplels
bash build.sh 1
```
### 2. Data Preparation (this will take about 2-5 minutes)
Use a trained model to run inference on GSM8K
@@ -74,17 +78,19 @@ apt install python3.10-distutils
mkdir -p /home/workspace/benchmark/ && cd /home/workspace/benchmark/
wget https://raw.githubusercontent.com/openai/grade-school-math/master/grade_school_math/data/test.jsonl
```
- Prepare the model and weights and place them in a local directory, e.g. ```/work/models/ernie-4_5-21b-a3b-bf16-paddle/```
+ Prepare the model and weights and place them in a local directory, e.g. ```/work/models/ERNIE-4.5-300B-A47B-Paddle/```
### 3. Inference (this will take about 2-5 minutes)
Run the following command to start the inference service
```bash
python -m fastdeploy.entrypoints.openai.api_server \
-     --model "/work/models/ernie-4_5-21b-a3b-bf16-paddle/" \
+     --model "/work/models/ERNIE-4.5-300B-A47B-Paddle/" \
    --port 8188 \
    --metrics-port 8200 \
-     --tensor-parallel-size 4 \
-     --max-model-len 8192 \
-     --num-gpu-blocks-override 1024
+     --tensor-parallel-size 8 \
+     --max-model-len 32768 \
+     --num-gpu-blocks-override 4096 \
+     --max-num-batched-tokens 32768 \
+     --quantization "wint4"
```
Query the model service with the following command
```bash
@@ -92,13 +98,13 @@ curl -X POST "http://0.0.0.0:8188/v1/chat/completions" \
-H "Content-Type: application/json" \
-d '{
  "messages": [
-     {"role": "user", "content": "The largest ocean is"}
+     {"role": "user", "content": "Where is Beijing?"}
  ]
}'
```
After a successful run, the generated inference result can be inspected, for example:
```json
- {"id":"chatcmpl-5cd96f3b-eff3-4dc0-8aa2-8b5d7b7b86f2","object":"chat.completion","created":1751167862,"model":"default","choices":[{"index":0,"message":{"role":"assistant","content":"3. **Pacific Ocean**: The Pacific Ocean is the largest and deepest of the world's oceans. It covers an area of approximately 181,344,000 square kilometers, which is more than 30% of the Earth's surface. It is located between the Americas to the west and east, and Asia and Australia to the north and south. The Pacific Ocean is known for its vastness, diverse marine life, and numerous islands.\n\nIn summary, the largest ocean in the world is the Pacific Ocean.","reasoning_content":null,"tool_calls":null},"finish_reason":"stop"}],"usage":{"prompt_tokens":11,"total_tokens":127,"completion_tokens":116,"prompt_tokens_details":{"cached_tokens":0}}}
+ {"id":"chatcmpl-20f1210d-6943-4110-ad2d-c76ba11604ad","object":"chat.completion","created":1751621261,"model":"default","choices":[{"index":0,"message":{"role":"assistant","content":"Beijing is the capital city of the People's Republic of China, located in the northern part of the country. It is situated in the North China Plain, bordered by the mountains to the west, north, and northeast. Beijing serves as China's political, cultural, and international exchange center, playing a crucial role in the nation's development and global interactions.","reasoning_content":null,"tool_calls":null},"finish_reason":"stop"}],"usage":{"prompt_tokens":11,"total_tokens":88,"completion_tokens":77,"prompt_tokens_details":{"cached_tokens":0}}}
```
### 4. Accuracy Testing (this will take about 60-180 minutes)
Place the accuracy script ```bench_gsm8k.py``` in ```/home/workspace/benchmark/``` and modify the sampling parameters, e.g.:
@@ -119,10 +125,10 @@ data = {
Run the following command to start the accuracy test
```bash
cd /home/workspace/benchmark/
- python -u bench_gsm8k.py --port 8188 --num-questions 1319 --num-shots 5 --parallel 2
+ python -u bench_gsm8k.py --port 8188 --num-questions 1319 --num-shots 5 --parallel 8
```
- After a successful run, the accuracy results are written to ```result.jsonl``` in the current directory; sample output (partial dataset, for illustration only):
+ After a successful run, the accuracy results are written to ```result.jsonl``` in the current directory; sample output:
```json
- {"task": "gsm8k", "backend": "paddlepaddle", "num_gpus": 1, "latency": 365.548, "accuracy": 0.967, "num_requests": 30, "other": {"num_questions": 30, "parallel": 2}}
+ {"task": "gsm8k", "backend": "paddlepaddle", "num_gpus": 1, "latency": 13446.01, "accuracy": 0.956, "num_requests": 1319, "other": {"num_questions": 1319, "parallel": 8}}
```
@@ -19,7 +19,7 @@ from typing import Optional

import paddle
from paddle import nn
- from paddle.incubate.nn.functional import fused_bias_act
+ from paddle.incubate.nn.functional import fused_bias_act, swiglu

from fastdeploy.config import FDConfig
from fastdeploy.platforms import current_platform
@@ -66,6 +66,8 @@ class SiluAndMul(nn.Layer):
        if current_platform.is_cuda() or current_platform.is_xpu(
        ) or current_platform.is_iluvatar():
            self.forward = self.forward_cuda
+        elif current_platform.is_gcu():
+            self.forward = self.forward_gcu
        else:
            raise NotImplementedError

@@ -123,3 +125,18 @@ class SiluAndMul(nn.Layer):
                quant_max_bound=self.quant_max_bound,
                quant_min_bound=self.quant_min_bound,
            )

+    def forward_gcu(self, x):
+        """
+        Forward propagation of the custom activation layer.
+
+        Args:
+            x (Tensor): Input tensor to the activation layer.
+
+        Returns:
+            Tensor: Output tensor.
+        """
+        out = swiglu(x)
+        if self.bias is not None:
+            out = out + self.bias
+        return out
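For reference, `swiglu` performs the split-and-activate that the CUDA path gets from `fused_bias_act`: it splits the input in half along the last axis and multiplies `silu` of the first half by the second half. A rough equivalence sketch, illustrative only and assuming a PaddlePaddle build that provides `paddle.incubate.nn.functional.swiglu`:

```python
# Illustrative check: swiglu(x) should equal silu(x1) * x2 for x = concat([x1, x2], axis=-1).
import paddle
from paddle.incubate.nn.functional import swiglu

x = paddle.randn([4, 8], dtype="float32")
x1, x2 = paddle.split(x, num_or_sections=2, axis=-1)
print(paddle.allclose(swiglu(x), paddle.nn.functional.silu(x1) * x2))  # expected: True
```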
@@ -16,14 +16,24 @@
all backends methods
"""

- from .xpu import *
- from .npu import *
+ from fastdeploy.platforms import current_platform

__all__ = []
- from . import npu
- if hasattr(npu, '__all__'):
-     __all__.extend(npu.__all__)
-
- from . import xpu
- if hasattr(xpu, '__all__'):
-     __all__.extend(xpu.__all__)
+
+ if current_platform.is_xpu():
+     from . import xpu
+     from .xpu import *
+     if hasattr(xpu, '__all__'):
+         __all__.extend(xpu.__all__)
+
+ if current_platform.is_npu():
+     from . import npu
+     from .npu import *
+     if hasattr(npu, '__all__'):
+         __all__.extend(npu.__all__)
+
+ if current_platform.is_gcu():
+     from . import gcu
+     from .gcu import *
+     if hasattr(gcu, '__all__'):
+         __all__.extend(gcu.__all__)
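With this change a wildcard import of the backends package only re-exports the symbols for the platform the process is actually running on. A hypothetical session on a GCU host (illustrative, not part of the commit):

```python
# Hypothetical interactive session on an Enflame S60 host.
from fastdeploy.platforms import current_platform
from fastdeploy.model_executor.layers.backends import *  # noqa: F401,F403

print(current_platform.is_gcu())           # True on a GCU machine
print("GCUFlashAttnBackend" in globals())  # True: pulled in via the is_gcu() branch
```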
fastdeploy/model_executor/layers/backends/gcu/__init__.py (new file, 31 lines)
@@ -0,0 +1,31 @@
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""
gcu backend methods
"""

from .attention.flash_attn_backend import GCUFlashAttnBackend
from .attention.mem_efficient_attn_backend import GCUMemEfficientAttnBackend
from .moe.fused_moe_method_gcu_backend import (GCUFusedMoeMethod,
                                               GCUWeightOnlyMoEMethod)
from .quantization.weight_only import GCUWeightOnlyLinearMethod

__all__ = [
    'GCUFlashAttnBackend',
    'GCUMemEfficientAttnBackend',
    'GCUFusedMoeMethod',
    'GCUWeightOnlyMoEMethod',
    'GCUWeightOnlyLinearMethod',
]
@@ -0,0 +1,21 @@
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from .flash_attn_backend import GCUFlashAttnBackend
from .mem_efficient_attn_backend import GCUMemEfficientAttnBackend

__all__ = [
    "GCUFlashAttnBackend",
    "GCUMemEfficientAttnBackend",
]
@@ -0,0 +1,287 @@
"""
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""

from __future__ import annotations

import os
from dataclasses import dataclass, field
from typing import TYPE_CHECKING, List, Optional

import paddle

import numpy as np


if TYPE_CHECKING:
    from paddle._typing.dtype_like import _DTypeLiteral

from fastdeploy.config import FDConfig
from fastdeploy.model_executor.layers.attention.attention import Attention
from fastdeploy.model_executor.layers.attention.base_attention_backend import (
    AttentionBackend, AttentionMetadata)
from fastdeploy.worker.forward_meta import ForwardMeta, ForwardMode

from fastdeploy.model_executor.ops.gcu import (fused_rotary_embedding,
                                               mem_efficient_attention,
                                               flash_attn_var_len)
from paddleformers.utils.log import logger


@dataclass
class GCUFlashAttnMetadata(AttentionMetadata):
    """
    GCUFlashAttnMetadata
    """
    forward_mode: ForwardMode = ForwardMode.MIXED

    _dtype: _DTypeLiteral = paddle.bfloat16

    seq_lens_encoder: Optional[paddle.Tensor] = None
    seq_lens_decoder: Optional[paddle.Tensor] = None
    seq_lens_this_time: Optional[paddle.Tensor] = None
    cum_offsets: Optional[paddle.Tensor] = None
    padding_offset: Optional[paddle.Tensor] = None

    cu_seqlens_q: Optional[paddle.Tensor] = None
    cu_seqlens_k: Optional[paddle.Tensor] = None
    caches: Optional[paddle.Tensor] = None

    block_tables: Optional[paddle.Tensor] = None
    rotary_embs: Optional[paddle.Tensor] = None
    attn_mask: Optional[paddle.Tensor] = None

    pre_caches_length: int = 0


class GCUFlashAttnBackend(AttentionBackend):
    """
    GCUFlashAttnBackend backend implementation.
    """

    def __init__(self, fd_config: FDConfig, kv_num_heads: int, num_heads: int,
                 head_dim: int):
        """
        GCUFlashAttnBackend __init__
        """
        super().__init__()
        self.attention_metadata: GCUFlashAttnMetadata = None
        self.block_size = fd_config.parallel_config.block_size
        self.max_seq_len = fd_config.parallel_config.max_model_len
        self.max_num_seqs = fd_config.parallel_config.max_num_seqs

        self.causal = getattr(fd_config.model_config, "causal", True)

        self.rank = fd_config.parallel_config.tensor_parallel_rank
        self.kv_num_heads = kv_num_heads
        self.num_heads = num_heads
        self.head_dim = head_dim
        self.scaling = 1.0 / (self.head_dim**0.5)
        self.num_layers = fd_config.model_config.num_layers
        self.position_ids_base = paddle.arange(self.max_seq_len)

        # TODO(zhengjun): Need to adapt the allocation logic and
        # temporarily allocate according to fixed size
        self.all_block_tables: List[List[int]] = None
        self.all_slot_mapping: List[List[int]] = None

        self.rotary_embs = None
        self.enable_monitor: bool = bool(os.getenv("FD_GCU_ATTN_MONITOR", False))

    def init_attention_metadata(self, forward_meta: ForwardMeta):
        """Initialize attention metadata so that all layers in the forward pass can reuse it."""
        metadata = GCUFlashAttnMetadata()

        metadata.forward_mode = forward_meta.forward_mode

        metadata._dtype = paddle.get_default_dtype()

        metadata.seq_lens_encoder = forward_meta.seq_lens_encoder
        metadata.seq_lens_decoder = forward_meta.seq_lens_decoder
        metadata.seq_lens_this_time = forward_meta.seq_lens_this_time
        metadata.cum_offsets = forward_meta.cum_offsets
        metadata.padding_offset = forward_meta.padding_offset

        metadata.cu_seqlens_q = forward_meta.cu_seqlens_q
        metadata.cu_seqlens_k = forward_meta.cu_seqlens_k
        metadata.caches = forward_meta.caches

        # metadata.block_tables = forward_meta.block_tables
        metadata.rotary_embs = forward_meta.rotary_embs
        metadata.attn_mask = forward_meta.attn_mask  # not init

        metadata.pre_caches_length = forward_meta.pre_caches_length  # not inited

        self.attention_metadata = metadata

        if self.rotary_embs is None:
            self.rotary_embs = metadata.rotary_embs.reshape((-1, self.head_dim))

        # some info for attention
        self.seq_lens_this_time_list = forward_meta.seq_lens_this_time.tolist()  # List[int]
        self.seq_lens_encoder_list = forward_meta.seq_lens_encoder.tolist()  # List[List[int]]
        self.seq_lens_decoder_list = forward_meta.seq_lens_decoder.tolist()  # List[List[int]]
        self.seq_lens_sum = np.sum(self.seq_lens_this_time_list)
        self.max_seq_len_this_time = np.max(self.seq_lens_this_time_list)

        num_seqs = forward_meta.seq_lens_this_time.shape[0]

        self.is_decoder = all(x[0] == 0 for x in self.seq_lens_encoder_list)
        self.is_all_prefill = all(x[0] == 0 for x in self.seq_lens_decoder_list)

        # block_tables and slot_mapping
        if self.all_slot_mapping is None:
            max_num_blocks_per_seq = (self.max_seq_len + self.block_size - 1) // self.block_size
            total_blocks = max_num_blocks_per_seq * self.max_num_seqs
            self.all_block_tables = np.arange(0, total_blocks, dtype=np.int32).reshape((self.max_num_seqs, max_num_blocks_per_seq)).tolist()
            self.all_slot_mapping = np.arange(0, total_blocks * self.block_size, dtype=np.int32).reshape((self.max_num_seqs, -1)).tolist()

        block_tables = []
        slot_mapping = []
        cache_slot_range = []
        cache_lens = []
        position_ids = []
        for seq_idx in range(num_seqs):
            cache_len = None
            if self.seq_lens_encoder_list[seq_idx][0] != 0:  # prefill
                cache_len = 0
            elif self.seq_lens_decoder_list[seq_idx][0] != 0:  # decode
                cache_len = self.seq_lens_decoder_list[seq_idx][0]
            # else: no request occupies this seq_idx

            if cache_len is not None:
                lens_this_time = self.seq_lens_this_time_list[seq_idx]
                start = cache_len
                end = start + lens_this_time
                slot_mapping.extend(self.all_slot_mapping[seq_idx][start:end])
                cache_slot_range.extend(self.all_slot_mapping[seq_idx][0:end])
                cache_lens.append(end)
                block_tables.append(self.all_block_tables[seq_idx])
                position_ids.extend(self.position_ids_base[start:end])

        self.block_tables = paddle.to_tensor(block_tables, dtype="int32")
        self.slot_mapping = paddle.to_tensor(slot_mapping, dtype="int32")
        self.cache_slot_range = paddle.to_tensor(cache_slot_range, dtype="int32")
        self.position_ids = paddle.to_tensor(position_ids, dtype="int32")
        self.position_ids = self.position_ids.reshape_((1, -1))

        if self.enable_monitor:
            logger.info(f"[FD_DEBUG] init_attention_metadata, position_ids:\n{self.position_ids}")

        cu_query_lens_data = [0]
        for seq_idx in range(num_seqs):
            if self.seq_lens_this_time_list[seq_idx] != 0:
                cu_query_lens_data.append(self.seq_lens_this_time_list[seq_idx])
        cu_query_lens = np.array(cu_query_lens_data, dtype=np.int32).cumsum(axis=0)

        self.cu_query_lens = paddle.to_tensor(cu_query_lens, dtype="int32")
        self.seqused_k = paddle.to_tensor(cache_lens, dtype="int32")
        self.max_seqlen_q = self.max_seq_len_this_time
        self.max_seqlen_k = np.max(cache_lens)

    def get_attntion_meta(self):
        """get_attntion_meta"""
        return self.attention_metadata

    def get_kv_cache_shape(
        self,
        max_num_blocks: int,
    ):
        """
        Calculate kv cache shape
        """
        # [total_tokens, kv_num_heads, head_dim]
        return (max_num_blocks * self.block_size,
                self.kv_num_heads,
                self.head_dim)

    @paddle.no_grad()
    def forward_mixed(
        self,
        q: paddle.Tensor,
        k: paddle.Tensor,
        v: paddle.Tensor,
        qkv: paddle.Tensor,
        compressed_kv: paddle.Tensor,
        k_pe: paddle.Tensor,
        layer: Attention,
        forward_meta: ForwardMeta,
    ) -> paddle.Tensor:
        """Run a forward for mixed."""
        token_num = qkv.shape[0]
        q_size = self.num_heads * self.head_dim
        kv_size = self.kv_num_heads * self.head_dim
        num_or_sections = [q_size, kv_size, kv_size]
        query, key, value = paddle.split(qkv, num_or_sections=num_or_sections, axis=-1)

        query = query.reshape_((1, -1, self.num_heads, self.head_dim))
        key = key.reshape_((1, -1, self.kv_num_heads, self.head_dim))

        # 1. Rope
        if self.rotary_embs.dtype != query.dtype:
            self.rotary_embs = paddle.cast(self.rotary_embs, query.dtype)

        query, key = fused_rotary_embedding(
            query,
            key,
            self.rotary_embs,
            self.position_ids,
            layer.use_neox_rotary_style
        )

        # 2. Save kv cache
        # shape: [total_tokens, kv_num_heads, head_dim]
        key = key.reshape_((-1, self.kv_num_heads, self.head_dim))
        value = value.reshape_((-1, self.kv_num_heads, self.head_dim))
        key_caches = forward_meta.caches[2 * layer.layer_id]
        value_caches = forward_meta.caches[2 * layer.layer_id + 1]
        key_caches[self.slot_mapping, :, :] = key
        value_caches[self.slot_mapping, :, :] = value

        # 3. calc attn
        query = query.reshape_((-1, self.num_heads, self.head_dim))
        key_caches = key_caches.reshape((-1, self.block_size, self.kv_num_heads, self.head_dim))
        value_caches = value_caches.reshape((-1, self.block_size, self.kv_num_heads, self.head_dim))
        res = flash_attn_var_len(
            query=query,
            key=key_caches,
            value=value_caches,
            cu_seqlens_q=self.cu_query_lens,
            cu_seqlens_k=None,
            seqused_k=self.seqused_k,
            leftpad_k=None,
            block_table=self.block_tables,
            alibi_slopes=None,
            max_seqlen_q=self.max_seqlen_q,
            max_seqlen_k=self.max_seqlen_k,
            p_dropout=0.0,
            softmax_scale=self.scaling,
            zero_tensors=False,
            is_causal=self.causal,
            window_size_left=-1,
            window_size_right=-1,
            softcap=0.0,
            return_softmax=False,
        )
        res = res.reshape_((token_num, -1))
        return res
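As the TODO above notes, this backend does not yet consume the scheduler's block tables; it statically carves the KV cache into one contiguous range of blocks per sequence slot. A small worked example of that layout, with illustrative numbers only:

```python
# Illustrative: max_seq_len=8, block_size=4, max_num_seqs=2.
import numpy as np

max_seq_len, block_size, max_num_seqs = 8, 4, 2
blocks_per_seq = (max_seq_len + block_size - 1) // block_size   # 2
total_blocks = blocks_per_seq * max_num_seqs                     # 4

all_block_tables = np.arange(total_blocks, dtype=np.int32).reshape(max_num_seqs, blocks_per_seq)
all_slot_mapping = np.arange(total_blocks * block_size, dtype=np.int32).reshape(max_num_seqs, -1)
print(all_block_tables.tolist())  # [[0, 1], [2, 3]]
print(all_slot_mapping.tolist())  # [[0, 1, 2, 3, 4, 5, 6, 7], [8, 9, 10, 11, 12, 13, 14, 15]]

# A decode step for sequence slot 1 with 3 tokens already cached writes the new
# token's K/V into slot all_slot_mapping[1][3] == 11, i.e. block 2, offset 3.
```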
@@ -0,0 +1,357 @@
"""
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""

from __future__ import annotations

import os
from dataclasses import dataclass, field
from typing import TYPE_CHECKING, List, Optional

import paddle

import numpy as np
import math


if TYPE_CHECKING:
    from paddle._typing.dtype_like import _DTypeLiteral

from fastdeploy.config import FDConfig
from fastdeploy.model_executor.layers.attention.attention import Attention
from fastdeploy.model_executor.layers.attention.base_attention_backend import (
    AttentionBackend, AttentionMetadata)
from fastdeploy.worker.forward_meta import ForwardMeta, ForwardMode

from fastdeploy.model_executor.ops.gcu import (fused_rotary_embedding,
                                               mem_efficient_attention,
                                               flash_attn_var_len)
from paddleformers.utils.log import logger


@dataclass
class GCUMemEfficientAttnMetadata(AttentionMetadata):
    """
    GCUMemEfficientAttnMetadata
    """
    forward_mode: ForwardMode = ForwardMode.MIXED
    _dtype: _DTypeLiteral = paddle.bfloat16

    seq_lens_encoder: Optional[paddle.Tensor] = None
    seq_lens_decoder: Optional[paddle.Tensor] = None
    seq_lens_this_time: Optional[paddle.Tensor] = None
    cum_offsets: Optional[paddle.Tensor] = None
    padding_offset: Optional[paddle.Tensor] = None

    cu_seqlens_q: Optional[paddle.Tensor] = None
    cu_seqlens_k: Optional[paddle.Tensor] = None
    caches: Optional[paddle.Tensor] = None

    block_tables: Optional[paddle.Tensor] = None
    rotary_embs: Optional[paddle.Tensor] = None
    attn_mask: Optional[paddle.Tensor] = None

    pre_caches_length: int = 0


class GCUMemEfficientAttnBackend(AttentionBackend):
    """
    GCUMemEfficientAttnBackend backend implementation.
    """

    def __init__(self, fd_config: FDConfig, kv_num_heads: int, num_heads: int,
                 head_dim: int):
        """
        GCUMemEfficientAttnBackend __init__
        """
        super().__init__()
        self.attention_metadata: GCUMemEfficientAttnMetadata = None
        self.block_size = fd_config.parallel_config.block_size
        self.max_seq_len = fd_config.parallel_config.max_model_len
        self.max_num_seqs = fd_config.parallel_config.max_num_seqs

        self.causal = getattr(fd_config.model_config, "causal", True)

        self.rank = fd_config.parallel_config.tensor_parallel_rank
        self.kv_num_heads = kv_num_heads
        self.num_heads = num_heads
        self.head_dim = head_dim
        self.scaling = 1.0 / (self.head_dim**0.5)
        self.num_layers = fd_config.model_config.num_layers
        self.position_ids_base = paddle.arange(self.max_seq_len)

        # TODO(zhengjun): Need to adapt the allocation logic and
        # temporarily allocate according to fixed size
        self.all_block_tables: List[List[int]] = None
        self.all_slot_mapping: List[List[int]] = None

        self.rotary_embs = None
        self.use_paddle_native_sdpa = False

    def init_attention_metadata(self, forward_meta: ForwardMeta):
        """Initialize attention metadata so that all layers in the forward pass can reuse it."""
        metadata = GCUMemEfficientAttnMetadata()

        metadata.forward_mode = forward_meta.forward_mode

        metadata._dtype = paddle.get_default_dtype()

        metadata.seq_lens_encoder = forward_meta.seq_lens_encoder
        metadata.seq_lens_decoder = forward_meta.seq_lens_decoder
        metadata.seq_lens_this_time = forward_meta.seq_lens_this_time
        metadata.cum_offsets = forward_meta.cum_offsets
        metadata.padding_offset = forward_meta.padding_offset

        metadata.cu_seqlens_q = forward_meta.cu_seqlens_q
        metadata.cu_seqlens_k = forward_meta.cu_seqlens_k
        metadata.caches = forward_meta.caches

        # metadata.block_tables = forward_meta.block_tables
        metadata.rotary_embs = forward_meta.rotary_embs
        metadata.attn_mask = forward_meta.attn_mask  # not init

        metadata.pre_caches_length = forward_meta.pre_caches_length  # not inited

        self.attention_metadata = metadata

        if self.rotary_embs is None:
            self.rotary_embs = metadata.rotary_embs.reshape((-1, self.head_dim))

        # some info for attention
        self.seq_lens_this_time_list = forward_meta.seq_lens_this_time.tolist()  # List[int]
        self.seq_lens_encoder_list = forward_meta.seq_lens_encoder.tolist()  # List[List[int]]
        self.seq_lens_decoder_list = forward_meta.seq_lens_decoder.tolist()  # List[List[int]]
        self.seq_lens_sum = np.sum(self.seq_lens_this_time_list)
        self.max_seq_len_this_time = np.max(self.seq_lens_this_time_list)

        num_seqs = forward_meta.seq_lens_this_time.shape[0]

        self.is_decoder = all(x[0] == 0 for x in self.seq_lens_encoder_list)
        self.is_all_prefill = all(x[0] == 0 for x in self.seq_lens_decoder_list)

        # block_tables and slot_mapping
        if self.all_slot_mapping is None:
            max_num_blocks_per_seq = (self.max_seq_len + self.block_size - 1) // self.block_size
            total_blocks = max_num_blocks_per_seq * self.max_num_seqs
            self.all_block_tables = np.arange(0, total_blocks, dtype=np.int32).reshape((self.max_num_seqs, max_num_blocks_per_seq)).tolist()
            self.all_slot_mapping = np.arange(0, total_blocks * self.block_size, dtype=np.int32).reshape((self.max_num_seqs, -1)).tolist()

        block_tables = []
        slot_mapping = []
        cache_slot_range = []
        cache_lens = []
        query_lens = []
        cached_kv_lens = []
        cached_kv_slot_range = []
        position_ids = []
        for seq_idx in range(num_seqs):
            cache_len = None
            if self.seq_lens_encoder_list[seq_idx][0] != 0:  # prefill
                cache_len = 0
            elif self.seq_lens_decoder_list[seq_idx][0] != 0:  # decode
                cache_len = self.seq_lens_decoder_list[seq_idx][0]
            # else: no request occupies this seq_idx

            if cache_len is not None:
                lens_this_time = self.seq_lens_this_time_list[seq_idx]
                start = cache_len
                end = start + lens_this_time
                slot_mapping.extend(self.all_slot_mapping[seq_idx][start:end])
                cache_slot_range.extend(self.all_slot_mapping[seq_idx][0:end])
                cache_lens.append(end)
                block_tables.append(self.all_block_tables[seq_idx])
                position_ids.extend(self.position_ids_base[start:end])
                query_lens.append(lens_this_time)
                cached_kv_lens.append(end)
                cached_kv_slot_range.append([self.all_slot_mapping[seq_idx][0], self.all_slot_mapping[seq_idx][end]])

        self.block_tables = paddle.to_tensor(block_tables, dtype="int32")
        self.slot_mapping = paddle.to_tensor(slot_mapping, dtype="int32")
        self.cache_slot_range = paddle.to_tensor(cache_slot_range, dtype="int32")
        self.position_ids = paddle.to_tensor(position_ids, dtype="int32")
        self.position_ids = self.position_ids.reshape_((1, -1))

        logger.info(f"[FD_DEBUG] init_attention_metadata, self.position_ids:\n{self.position_ids}")

        cu_query_lens_data = [0]
        for seq_idx in range(num_seqs):
            if self.seq_lens_this_time_list[seq_idx] != 0:
                cu_query_lens_data.append(self.seq_lens_this_time_list[seq_idx])
        cu_query_lens = np.array(cu_query_lens_data, dtype=np.int32).cumsum(axis=0)

        self.cu_query_lens = paddle.to_tensor(cu_query_lens, dtype="int32")
        self.seqused_k = paddle.to_tensor(cache_lens, dtype="int32")
        self.max_seqlen_q = self.max_seq_len_this_time
        self.max_seqlen_k = np.max(cache_lens)

        self.query_lens = query_lens
        self.cached_kv_lens = cached_kv_lens
        self.cached_kv_slot_range = cached_kv_slot_range

    def get_attntion_meta(self):
        """get_attntion_meta"""
        return self.attention_metadata

    def get_kv_cache_shape(
        self,
        max_num_blocks: int,
    ):
        """
        Calculate kv cache shape
        """
        # [total_tokens, kv_num_heads, head_dim]
        return (max_num_blocks * self.block_size,
                self.kv_num_heads,
                self.head_dim)

    @paddle.no_grad()
    def forward_mixed(
        self,
        q: paddle.Tensor,
        k: paddle.Tensor,
        v: paddle.Tensor,
        qkv: paddle.Tensor,
        compressed_kv: paddle.Tensor,
        k_pe: paddle.Tensor,
        layer: Attention,
        forward_meta: ForwardMeta,
    ) -> paddle.Tensor:
        """Run a forward for mixed."""
        token_num = qkv.shape[0]
        q_size = self.num_heads * self.head_dim
        kv_size = self.kv_num_heads * self.head_dim
        num_or_sections = [q_size, kv_size, kv_size]
        query, key, value = paddle.split(qkv, num_or_sections=num_or_sections, axis=-1)

        query = query.reshape_((1, -1, self.num_heads, self.head_dim))
        key = key.reshape_((1, -1, self.kv_num_heads, self.head_dim))

        # 1. Rope
        if self.rotary_embs.dtype != query.dtype:
            self.rotary_embs = paddle.cast(self.rotary_embs, query.dtype)

        query, key = fused_rotary_embedding(
            query,
            key,
            self.rotary_embs,
            self.position_ids,
            layer.use_neox_rotary_style
        )

        # 2. Save kv cache
        # shape: [total_tokens, kv_num_heads, head_dim]
        key = key.reshape_((-1, self.kv_num_heads, self.head_dim))
        value = value.reshape_((-1, self.kv_num_heads, self.head_dim))
        key_caches = forward_meta.caches[2 * layer.layer_id]
        value_caches = forward_meta.caches[2 * layer.layer_id + 1]
        key_caches[self.slot_mapping, :, :] = key
        value_caches[self.slot_mapping, :, :] = value

        # 3. calc attn
        query = query.reshape_((-1, self.num_heads, self.head_dim))

        q_start = 0
        result = paddle.empty_like(query)
        for idx in range(len(self.query_lens)):
            q_end = q_start + self.query_lens[idx]
            kv_start = self.cached_kv_slot_range[idx][0]
            kv_end = self.cached_kv_slot_range[idx][1]

            q_ = query[q_start:q_end, :, :]
            k_ = key_caches[kv_start:kv_end, :, :]
            v_ = value_caches[kv_start:kv_end, :, :]

            if self.use_paddle_native_sdpa:
                res = self.native_sdpa_impl(
                    q_, k_, v_
                )
            else:
                res = mem_efficient_attention(
                    query=q_.unsqueeze(0),
                    key=k_.unsqueeze(0),
                    value=v_.unsqueeze(0),
                    attn_mask=None,
                    dropout=0.0,
                    softmax_scale=self.scaling,
                    mask_mode=1,
                    seqlens=[0],
                    causal=self.causal,
                )
            result[q_start:q_end, :, :] = res
            q_start = q_end
        result = result.reshape_((token_num, -1))
        return result

    def get_triangle_upper_mask(self, shape, dtype):
        # [batch_size, 1, q_seq_len, kv_seq_len]
        shape[1] = 1
        q_seq_len = shape[2]
        kv_seq_len = shape[3]
        paddle_dtype = dtype  # paddle.base.data_feeder.convert_dtype(dtype)
        mask = paddle.full(shape, paddle.finfo(paddle_dtype).min, dtype=paddle_dtype)
        mask = paddle.triu(mask, diagonal=kv_seq_len - q_seq_len + 1)
        return mask

    def native_sdpa_impl(self, query, key, value):
        # input shape: [num_tokens, num_heads, head_dim] -> [1, num_tokens, num_heads, head_dim]
        q = query.unsqueeze(0)
        k = key.unsqueeze(0)
        v = value.unsqueeze(0)
        batch, q_seq_len, heads, head_dim = q.shape
        kv_seq_len = k.shape[1]

        # [batch_size, seq_len, num_heads, head_dim] -> [batch_size, num_heads, seq_len, head_dim]
        q = paddle.transpose(q, [0, 2, 1, 3])
        k = paddle.transpose(k, [0, 2, 1, 3])
        v = paddle.transpose(v, [0, 2, 1, 3])

        # GQA
        if q.shape[1] != k.shape[1]:
            kv_head = k.shape[1]

            k = k.reshape([batch, kv_head, 1, kv_seq_len, head_dim])
            k = paddle.tile(k, [1, 1, heads // kv_head, 1, 1])
            k = k.reshape([batch, heads, kv_seq_len, head_dim])

            v = v.reshape([batch, kv_head, 1, kv_seq_len, head_dim])
            v = paddle.tile(v, [1, 1, heads // kv_head, 1, 1])
            v = v.reshape([batch, heads, kv_seq_len, head_dim])

        # matmul and divide by sqrt(head_dim)
        attn_weights = paddle.matmul(q / math.sqrt(head_dim), k.transpose([0, 1, 3, 2]))

        attention_mask = self.get_triangle_upper_mask(
            [batch, 1, q_seq_len, kv_seq_len], q.dtype
        )
        attn_weights = attn_weights + attention_mask
        attn_weights = paddle.nn.functional.softmax(
            attn_weights, axis=-1, dtype="float32"
        ).astype(q.dtype)

        attn_output = paddle.matmul(attn_weights, v)
        attn_output = attn_output.transpose([0, 2, 1, 3])
        return attn_output.squeeze(0)
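The fallback `native_sdpa_impl` path builds an additive causal mask whose disallowed positions are set to the dtype minimum; the `diagonal = kv_seq_len - q_seq_len + 1` offset aligns the last query row with the last cached key. A tiny illustration, assuming paddle and purely for reference:

```python
# Illustrative: 2 query tokens attending over 4 cached keys.
import paddle

q_seq_len, kv_seq_len = 2, 4
mask = paddle.full([1, 1, q_seq_len, kv_seq_len],
                   float(paddle.finfo(paddle.float32).min), dtype=paddle.float32)
mask = paddle.triu(mask, diagonal=kv_seq_len - q_seq_len + 1)
print((mask == 0).astype("int32").numpy())
# [[[[1 1 1 0]      first query token may attend to keys 0..2
#    [1 1 1 1]]]]   last query token may attend to all keys
```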
@@ -0,0 +1,16 @@
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
gcu moe
"""
@@ -0,0 +1,402 @@
|
||||
|
||||
"""
|
||||
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""
|
||||
|
||||
|
||||
import multiprocessing
|
||||
import os
|
||||
|
||||
import numpy as np
|
||||
import paddle
|
||||
from paddle import nn
|
||||
from paddleformers.utils.log import logger
|
||||
|
||||
from fastdeploy.model_executor.layers.moe.fused_moe_backend_base import \
|
||||
MoEMethodBase
|
||||
from fastdeploy.model_executor.layers.utils import (CpuGuard,
|
||||
create_and_set_parameter,
|
||||
get_tensor)
|
||||
from fastdeploy.model_executor.ops.gcu import (invoke_fused_moe_kernel,
|
||||
moe_align_block_size,
|
||||
topk_softmax,
|
||||
weight_quantize_custom_rtn,
|
||||
weight_quantize_rtn)
|
||||
|
||||
|
||||
class GCUFusedMoeMethod(MoEMethodBase):
|
||||
"""
|
||||
Use GCU to compute Fused MoE.
|
||||
"""
|
||||
def __init__(self, quant_config):
|
||||
super().__init__(quant_config)
|
||||
self.group_size = -1
|
||||
|
||||
|
||||
def create_weights(self, layer: nn.Layer, state_dict):
|
||||
"""
|
||||
Paddle gcu create weight process.
|
||||
"""
|
||||
# bf16
|
||||
ffn1_weights, ffn2_weights = layer.extract_moe_ffn_weights(state_dict)
|
||||
stacked_ffn1_weights = paddle.stack(ffn1_weights, axis=0)
|
||||
stacked_ffn2_weights = paddle.stack(ffn2_weights, axis=0)
|
||||
for idx, weight_tensor in enumerate(
|
||||
[stacked_ffn1_weights, stacked_ffn2_weights]):
|
||||
# shape [E, K, N] -> [E, N, K]
|
||||
weight_tensor = paddle.transpose(weight_tensor, [0, 2, 1])
|
||||
weight_name = self.added_weight_attrs[idx]
|
||||
setattr(
|
||||
layer, weight_name,
|
||||
layer.create_parameter(
|
||||
shape=weight_tensor.shape,
|
||||
dtype=weight_tensor.dtype,
|
||||
default_initializer=paddle.nn.initializer.Constant(0),
|
||||
))
|
||||
getattr(layer, weight_name).set_value(weight_tensor)
|
||||
|
||||
|
||||
@paddle.no_grad()
|
||||
def compute_ffn(
|
||||
self,
|
||||
layer: nn.Layer,
|
||||
x: paddle.Tensor,
|
||||
gate_out: paddle.Tensor,
|
||||
enable_quant = False
|
||||
) -> paddle.Tensor:
|
||||
"""
|
||||
Paddle gcu compute Fused MoE.
|
||||
"""
|
||||
token_num, hidden_size = x.shape
|
||||
top_k = layer.top_k
|
||||
moe_intermediate_size = layer.moe_intermediate_size
|
||||
num_experts = layer.num_local_experts
|
||||
|
||||
topk_weights = paddle.empty([token_num, top_k], dtype=gate_out.dtype)
|
||||
topk_indices = paddle.empty([token_num, top_k], dtype="int32")
|
||||
token_expert_indices = paddle.empty([token_num, top_k], dtype="int32",)
|
||||
topk_softmax(topk_weights, topk_indices, token_expert_indices, gate_out, norm_topk_prob=True)
|
||||
|
||||
config = {
|
||||
"BLOCK_SIZE_M": 32,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 1,
|
||||
}
|
||||
|
||||
block_size = config["BLOCK_SIZE_M"]
|
||||
max_num_tokens_padded = np.prod(topk_indices.shape) + num_experts * (block_size - 1)
|
||||
max_num_m_blocks = max_num_tokens_padded // block_size
|
||||
sorted_token_ids = paddle.empty([max_num_tokens_padded], dtype="int32")
|
||||
expert_ids = paddle.zeros(shape=[max_num_m_blocks], dtype="int32")
|
||||
num_tokens_post_pad = paddle.empty([1], dtype="int32")
|
||||
|
||||
sorted_token_ids, expert_ids, num_tokens_post_pad = moe_align_block_size(
|
||||
sorted_token_ids,
|
||||
expert_ids,
|
||||
num_tokens_post_pad,
|
||||
topk_indices,
|
||||
num_experts,
|
||||
block_size,
|
||||
)
|
||||
|
||||
intermediate_cache1 = paddle.empty(
|
||||
[token_num, top_k, moe_intermediate_size * 2],
|
||||
dtype=x.dtype,
|
||||
)
|
||||
|
||||
ffn1_B_scale = layer.moe_ffn1_weight_scale if enable_quant else None
|
||||
ffn1_B_zeros = layer.moe_ffn1_weight_zeros if enable_quant else None
|
||||
|
||||
invoke_fused_moe_kernel(
|
||||
x, # input
|
||||
layer.moe_ffn1_weight, # weight
|
||||
intermediate_cache1, # output
|
||||
None, # A_scale
|
||||
ffn1_B_scale, # B_scale
|
||||
ffn1_B_zeros, # B_zp
|
||||
topk_weights,
|
||||
topk_indices,
|
||||
sorted_token_ids,
|
||||
expert_ids,
|
||||
num_tokens_post_pad,
|
||||
False, # mul_routed_weight
|
||||
top_k,
|
||||
config,
|
||||
enable_quant, # use_int4_w4a16
|
||||
[0, self.group_size], # block_shape
|
||||
)
|
||||
|
||||
intermediate_cache2 = paddle.empty(
|
||||
(token_num, top_k, moe_intermediate_size),
|
||||
dtype=x.dtype,
|
||||
)
|
||||
|
||||
intermediate_cache2 = paddle.incubate.nn.functional.swiglu(
|
||||
intermediate_cache1)
|
||||
|
||||
intermediate_cache2 = intermediate_cache2.reshape([-1, moe_intermediate_size])
|
||||
|
||||
intermediate_cache3 = paddle.empty(
|
||||
(token_num, top_k, hidden_size),
|
||||
dtype=x.dtype,
|
||||
)
|
||||
|
||||
ffn2_B_scale = layer.moe_ffn2_weight_scale if enable_quant else None
|
||||
ffn2_B_zeros = layer.moe_ffn2_weight_zeros if enable_quant else None
|
||||
|
||||
invoke_fused_moe_kernel(
|
||||
intermediate_cache2, # input
|
||||
layer.moe_ffn2_weight, # weight
|
||||
intermediate_cache3, # output
|
||||
None, # A_scale
|
||||
ffn2_B_scale, # B_scale
|
||||
ffn2_B_zeros, # B_zp
|
||||
topk_weights,
|
||||
topk_indices,
|
||||
sorted_token_ids,
|
||||
expert_ids,
|
||||
num_tokens_post_pad,
|
||||
True, # mul_routed_weight
|
||||
1,
|
||||
config,
|
||||
enable_quant, # use_int4_w4a16
|
||||
[0, self.group_size], # block_shape
|
||||
)
|
||||
|
||||
intermediate_cache3.reshape_([token_num, top_k, hidden_size])
|
||||
fused_moe_out = intermediate_cache3.sum(axis=1)
|
||||
fused_moe_out = fused_moe_out.reshape_([token_num, hidden_size])
|
||||
|
||||
if layer.tp_size > 1:
|
||||
from fastdeploy.distributed.communication_op import \
|
||||
tensor_model_parallel_all_reduce
|
||||
tensor_model_parallel_all_reduce(fused_moe_out)
|
||||
|
||||
return fused_moe_out
|
||||
|
||||
|
||||
def apply(
|
||||
self,
|
||||
layer: nn.Layer,
|
||||
x: paddle.Tensor,
|
||||
gate_out: paddle.Tensor,
|
||||
) -> paddle.Tensor:
|
||||
"""
|
||||
Paddle gcu compute Fused MoE.
|
||||
"""
|
||||
return self.compute_ffn(layer, x, gate_out, enable_quant=False)
|
||||
|
||||
|
||||
def apply_ep_prefill(
|
||||
self,
|
||||
layer: nn.Layer,
|
||||
x: paddle.Tensor,
|
||||
gate_out: paddle.Tensor,
|
||||
) -> paddle.Tensor:
|
||||
"""
|
||||
Apply the EP prefill method.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
def apply_ep_decode(
|
||||
self,
|
||||
layer: nn.Layer,
|
||||
x: paddle.Tensor,
|
||||
gate_out: paddle.Tensor,
|
||||
) -> paddle.Tensor:
|
||||
"""
|
||||
Apply the EP decoder method.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
def apply_tp(
|
||||
self,
|
||||
layer: nn.Layer,
|
||||
x: paddle.Tensor,
|
||||
gate_out: paddle.Tensor,
|
||||
) -> paddle.Tensor:
|
||||
"""
|
||||
Paddle Cutlass compute Fused MoE.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
class GCUWeightOnlyMoEMethod(GCUFusedMoeMethod):
|
||||
"""
|
||||
weight only for moe
|
||||
"""
|
||||
|
||||
def __init__(self, quant_config):
|
||||
super().__init__(quant_config)
|
||||
self.quant_config = quant_config
|
||||
self.moe_quant_type = self.quant_config.algo
|
||||
self.pack_num = 1
|
||||
|
||||
assert self.quant_config.algo == "weight_only_int4", \
|
||||
"GCUWeightOnlyMoEMethod only support weight_only_int4, but got:{self.quant_config.algo}"
|
||||
|
||||
self.added_qzeros_attrs = [
|
||||
"moe_ffn1_weight_zeros", "moe_ffn2_weight_zeros"
|
||||
]
|
||||
self.group_size = 64
|
||||
|
||||
self.quant_multi_process_group_size = int(
|
||||
os.getenv("FD_MOE_QUANT_MULTI_PROCESS_GROUP_SIZE", 8)
|
||||
)
|
||||
logger.info(f"GCUWeightOnlyMoEMethod quant_multi_process_group_size: {self.quant_multi_process_group_size}")
|
||||
|
||||
|
||||
def process_prequanted_weights(self, layer: nn.Layer, state_dict):
|
||||
"""
|
||||
Paddle gcu process prequanted weights.
|
||||
"""
|
||||
ffn1_expert_weight_key = layer.weight_key_map.get(
|
||||
"ffn1_expert_weight_key", None)
|
||||
ffn2_expert_weight_key = layer.weight_key_map.get(
|
||||
"ffn2_expert_weight_key", None)
|
||||
ffn1_expert_weight_scale_key = layer.weight_key_map.get(
|
||||
"ffn1_expert_weight_scale_key", None)
|
||||
ffn2_expert_weight_scale_key = layer.weight_key_map.get(
|
||||
"ffn2_expert_weight_scale_key", None)
|
||||
|
||||
ffn1_weights, ffn2_weights = layer.load_experts_weight(
|
||||
state_dict, ffn1_expert_weight_key, ffn2_expert_weight_key)
|
||||
# self.check(layer, ffn1_weights, ffn2_weights)
|
||||
ffn1_weight_scale = []
|
||||
ffn2_weight_scale = []
|
||||
for i in range(layer.num_experts):
|
||||
expert_idx = layer.expert_id_offset + i
|
||||
ffn1_weight_scale.append(
|
||||
get_tensor(
|
||||
state_dict.pop(
|
||||
ffn1_expert_weight_scale_key.format(expert_idx))))
|
||||
ffn2_weight_scale.append(
|
||||
get_tensor(
|
||||
state_dict.pop(
|
||||
ffn2_expert_weight_scale_key.format(expert_idx))))
|
||||
|
||||
ffn1_weight = paddle.stack(ffn1_weights, axis=0)
|
||||
ffn2_weight = paddle.stack(ffn2_weights, axis=0)
|
||||
ffn1_weight_scale = paddle.stack(ffn1_weight_scale, axis=0)
|
||||
ffn2_weight_scale = paddle.stack(ffn2_weight_scale, axis=0)
|
||||
|
||||
name_tensor_map = {
|
||||
"moe_ffn1_weight": ffn1_weight,
|
||||
"moe_ffn2_weight": ffn2_weight,
|
||||
"moe_ffn1_weight_scale": ffn1_weight_scale,
|
||||
"moe_ffn2_weight_scale": ffn2_weight_scale
|
||||
}
|
||||
for name, tensor in name_tensor_map.items():
|
||||
create_and_set_parameter(layer, name, tensor)
|
||||
|
||||
|
||||
@paddle.no_grad()
|
||||
def create_weights(self, layer: nn.Layer, state_dict):
|
||||
"""
|
||||
Paddle cutlass create weight process.
|
||||
"""
|
||||
ffn1_weights, ffn2_weights = layer.extract_moe_ffn_weights(state_dict)
|
||||
self.check(layer, ffn1_weights, ffn2_weights)
|
||||
|
||||
|
||||
def quant_worker(p_group_idx, shared_dict, weights, moe_quant_type, group_size):
|
||||
with CpuGuard():
|
||||
p_group_size = len(weights)
|
||||
for group_j in range(p_group_size):
|
||||
# weight shape [K, N] -> [N/2, K] -> [N, K/2]
|
||||
quant_weight, scale = weight_quantize_custom_rtn(
|
||||
weights[group_j],
|
||||
moe_quant_type,
|
||||
group_size # group_size
|
||||
)
|
||||
shared_dict[p_group_size * p_group_idx + group_j] = (
|
||||
quant_weight, scale
|
||||
)
|
||||
|
||||
|
||||
for idx, weight_tensor in enumerate([ffn1_weights, ffn2_weights]):
|
||||
weight_name = self.added_weight_attrs[idx]
|
||||
scale_name = self.added_scale_attrs[idx]
|
||||
zeros_name = self.added_qzeros_attrs[idx]
|
||||
|
||||
if self.quant_multi_process_group_size > 0:
|
||||
process_group_size = self.quant_multi_process_group_size
|
||||
process_group_num = layer.num_local_experts // process_group_size
|
||||
grouped_weights_num = process_group_num * process_group_size
|
||||
remain_weights_start_idx = grouped_weights_num
|
||||
|
||||
weight_list = [None] * grouped_weights_num
|
||||
weight_scale_list = [None] * grouped_weights_num
|
||||
|
||||
with multiprocessing.Manager() as manager:
|
||||
shared_dict = manager.dict({})
|
||||
processes = []
|
||||
|
||||
for i in range(process_group_num):
|
||||
w = []
|
||||
for j in range(process_group_size):
|
||||
w.append(weight_tensor[process_group_size * i + j].to("cpu"))
|
||||
|
||||
p = multiprocessing.Process(
|
||||
target=quant_worker,
|
||||
args=(i, shared_dict, w, self.moe_quant_type, self.group_size)
|
||||
)
|
||||
p.start()
|
||||
processes.append(p)
|
||||
|
||||
for p in processes:
|
||||
p.join()
|
||||
|
||||
dict_ = dict(shared_dict)
|
||||
|
||||
for k, v in dict_.items():
|
||||
weight_list[k] = v[0].to(ffn1_weights[0].place)
|
||||
weight_scale_list[k] = v[1].to(ffn1_weights[0].place)
|
||||
else:
|
||||
remain_weights_start_idx = 0
|
||||
|
||||
if remain_weights_start_idx < layer.num_local_experts:
|
||||
for i in range(remain_weights_start_idx, layer.num_local_experts):
|
||||
# weight shape [K, N] -> [N/2, K] -> [N, K/2]
|
||||
quant_weight, scale = weight_quantize_rtn(
|
||||
weight_tensor[i],
|
||||
self.moe_quant_type,
|
||||
self.group_size # group_size
|
||||
)
|
||||
weight_list.append(quant_weight)
|
||||
weight_scale_list.append(scale)
|
||||
quanted_weight = paddle.stack(weight_list, axis=0)
|
||||
create_and_set_parameter(layer, weight_name, quanted_weight)
|
||||
|
||||
quanted_weight_scale = paddle.stack(weight_scale_list, axis=0)
|
||||
create_and_set_parameter(layer, scale_name, quanted_weight_scale)
|
||||
|
||||
quanted_weight_zeros = quanted_weight_scale * 8
|
||||
create_and_set_parameter(layer, zeros_name, quanted_weight_zeros)
|
||||
|
||||
|
||||
def apply(
|
||||
self,
|
||||
layer: nn.Layer,
|
||||
x: paddle.Tensor,
|
||||
gate_out: paddle.Tensor,
|
||||
) -> paddle.Tensor:
|
||||
"""
|
||||
Paddle gcu compute Fused MoE.
|
||||
"""
|
||||
return self.compute_ffn(layer, x, gate_out, enable_quant=True)
|
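For context on the `FD_MOE_QUANT_MULTI_PROCESS_GROUP_SIZE` switch used above: when it is positive, expert weights are quantized in parallel by spawning one worker process per group of experts, and any experts that do not fill a complete group fall back to the serial `weight_quantize_rtn` path. A small worked example of that grouping arithmetic (the expert counts below are illustrative assumptions, not values from this diff):

```python
# Illustrative numbers only; they are not taken from the FastDeploy sources.
num_local_experts = 30          # experts owned by this rank (hypothetical)
process_group_size = 8          # FD_MOE_QUANT_MULTI_PROCESS_GROUP_SIZE

process_group_num = num_local_experts // process_group_size    # 3 worker processes
grouped_weights_num = process_group_num * process_group_size   # 24 experts quantized in parallel
remain_weights_start_idx = grouped_weights_num                 # experts 24..29 use the serial path

print(process_group_num, grouped_weights_num, num_local_experts - remain_weights_start_idx)
# 3 24 6
```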
@@ -0,0 +1,21 @@
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
gcu quantization
"""
from .weight_only import GCUWeightOnlyLinearMethod

__all__ = [
    "GCUWeightOnlyLinearMethod",
]
@@ -0,0 +1,90 @@
"""
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""

import paddle

from fastdeploy.model_executor.layers.quantization.weight_only import (
    WeightOnlyConfig, WeightOnlyLinearMethod)
from fastdeploy.model_executor.layers.utils import get_tensor
from fastdeploy.model_executor.ops.gcu import linear_quant, weight_quantize_rtn


class GCUWeightOnlyLinearMethod(WeightOnlyLinearMethod):
    """
    Weight only quantization method for linear layer on GCU
    """

    def __init__(
        self,
        quant_config: WeightOnlyConfig,
    ) -> None:
        super().__init__(quant_config)
        self.quant_config = quant_config
        self.group_size = -1


    def create_weights(self, layer):
        # The scale shape should be equal to the output dim of weight using Per-Channel Quantization.
        linear_weight_scale_shape = [layer.linear_weight_shape[1]]

        layer.linear_weight_shape.reverse()
        if self.quant_config.name() == "wint4":
            layer.linear_weight_shape[0] //= 2
        layer.weight_dtype = "int8"
        layer.linear_weight_scale = layer.create_parameter(
            shape=linear_weight_scale_shape,
            dtype=layer._dtype,
            is_bias=False,
        )


    def process_prequanted_weights(self, layer, state_dict) -> None:
        """
        Process pre-quantized weights before applying them to the model
        Args:
            layer: The layer that owns the weights
            quant_weight: The quantized weights
            weight_scale: The scale of the quantized weights
        """
        quant_weight = get_tensor(state_dict.pop(layer.weight_key))
        weight_scale = get_tensor(state_dict.pop(layer.weight_scale_key))
        layer.linear_weight.set_value(quant_weight)
        layer.linear_weight_scale.set_value(
            weight_scale.astype(paddle.get_default_dtype()))


    def process_loaded_weights(self, layer, weight) -> None:
        quanted_weight_tensor, weight_scale_tensor = weight_quantize_rtn(
            weight,
            self.quant_config.algo,
            self.group_size,  # group_size
        )

        layer.linear_weight.set_value(quanted_weight_tensor)
        layer.linear_weight_scale.set_value(
            weight_scale_tensor.astype(paddle.get_default_dtype()))


    @paddle.no_grad()
    def apply(self, layer, x):
        linear_out = linear_quant(
            lhs=x,
            rhs=layer.linear_weight,
            scale=layer.linear_weight_scale,
            bias=None,
            group_size=self.group_size,
        )
        return linear_out
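Putting `GCUWeightOnlyLinearMethod` in one picture: the weight is quantized once at load time with `weight_quantize_rtn` (per-channel, `group_size=-1`) and consumed at run time by `linear_quant`. A minimal sketch of that flow, assuming the GCU custom ops imported above; the tensor shapes and the `"weight_only_int8"` algorithm string are illustrative assumptions, not values taken from this diff:

```python
import paddle
from fastdeploy.model_executor.ops.gcu import linear_quant, weight_quantize_rtn

# Hypothetical sizes: a [K, N] weight and an [M, K] activation batch.
weight = paddle.randn([4096, 2048], dtype="bfloat16")
x = paddle.randn([8, 4096], dtype="bfloat16")

# Load-time step (mirrors process_loaded_weights): per-channel RTN quantization.
qweight, scale = weight_quantize_rtn(weight, "weight_only_int8", -1)

# Run-time step (mirrors apply): quantized matmul with the stored per-channel scale.
out = linear_quant(lhs=x, rhs=qweight, scale=scale, bias=None, group_size=-1)
```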
@@ -58,7 +58,7 @@ class LinearBase(nn.Layer):
        """
        super().__init__()
        if current_platform.is_cuda() or current_platform.is_xpu(
        ) or current_platform.is_iluvatar():
        ) or current_platform.is_iluvatar() or current_platform.is_gcu():
            self.forward = self.forward_cuda
        else:
            raise NotImplementedError
@@ -20,6 +20,7 @@ from paddleformers.utils.log import logger

from fastdeploy import envs
from fastdeploy.model_executor.layers.utils import get_tensor
from fastdeploy.platforms import current_platform


class FusedMoE(nn.Layer):
@@ -95,8 +96,13 @@ class FusedMoE(nn.Layer):
            self.moe_quant_type = moe_quant_config.name()
        else:
            # now, no quant method(w_fp16 a_fp16) can't get from quant_config, we will optimize it in future
            from .fused_moe_cutlass_backend import CutlassMoEMethod
            self.quant_method = CutlassMoEMethod(None)
            if current_platform.is_cuda():
                from .fused_moe_cutlass_backend import CutlassMoEMethod
                self.quant_method = CutlassMoEMethod(None)
            elif current_platform.is_gcu():
                from fastdeploy.model_executor.layers.backends import \
                    GCUFusedMoeMethod
                self.quant_method = GCUFusedMoeMethod(None)

        if self.ep_size > 1:
            self.quant_method.init_ep(self)
@@ -19,9 +19,14 @@ from typing import Callable, Dict, Optional
import numpy as np
import paddle
from paddle import nn
from paddle.incubate.nn.functional import fused_layer_norm, fused_rms_norm

from fastdeploy.platforms import current_platform

if current_platform.is_gcu():
    from fastdeploy.model_executor.ops.gcu import fused_add_rms_norm, rms_norm
else:
    from paddle.incubate.nn.functional import fused_layer_norm, fused_rms_norm

from fastdeploy.config import FDConfig

from .utils import get_tensor
@@ -69,7 +74,10 @@ class RMSNorm(nn.Layer):
        self.weight_key: Optional[str] = f"{prefix}.weight"
        self.with_weight: bool = self.weight_key is not None
        self.eps: float = eps
        self.norm_func: Callable = fused_rms_norm
        if current_platform.is_gcu():
            self.norm_func: Callable = fused_add_rms_norm
        else:
            self.norm_func: Callable = fused_rms_norm
        self.linear_bias: Optional[paddle.Tensor] = linear_bias
        self.quant_scale: Optional[float] = quant_scale
        self._dtype: str = self._helper.get_default_dtype()
@@ -129,19 +137,26 @@ class RMSNorm(nn.Layer):
        The `residual_output` is the result of applying the normalization and possibly other
        operations (like linear transformation) on the `residual_input`.
        """
        norm_out = self.norm_func(
            x,
            norm_weight=self.ln_weight,
            norm_bias=None,
            epsilon=self.eps,
            begin_norm_axis=self.begin_norm_axis,
            bias=self.linear_bias,
            residual=residual_input,
            quant_scale=-1 if self.quant_scale is None else self.quant_scale,
            quant_round_type=self.quant_round_type,
            quant_max_bound=self.quant_max_bound,
            quant_min_bound=self.quant_min_bound,
        )
        if current_platform.is_gcu():
            if residual_input is None:
                return rms_norm(x, self.ln_weight, self.eps)
            norm_out = self.norm_func(
                x, residual_input, self.ln_weight, self.eps
            )
        else:
            norm_out = self.norm_func(
                x,
                norm_weight=self.ln_weight,
                norm_bias=None,
                epsilon=self.eps,
                begin_norm_axis=self.begin_norm_axis,
                bias=self.linear_bias,
                residual=residual_input,
                quant_scale=-1 if self.quant_scale is None else self.quant_scale,
                quant_round_type=self.quant_round_type,
                quant_max_bound=self.quant_max_bound,
                quant_min_bound=self.quant_min_bound,
            )
        if residual_input is not None:
            return norm_out[0], norm_out[1]
        else:
@@ -193,7 +208,10 @@ class LayerNorm(nn.Layer):
        self.with_bias: bool = with_bias
        self.eps: float = eps
        self.quant_scale: float = quant_scale
        self.norm_func: Callable = fused_layer_norm
        if current_platform.is_gcu():
            self.norm_func: Callable = paddle.nn.functional.layer_norm
        else:
            self.norm_func: Callable = fused_layer_norm
        self.linear_bias: Optional[paddle.Tensor] = linear_bias
        self._dtype: str = self._helper.get_default_dtype()
        self._norm_weight_dtype: str = "float32"
@@ -279,19 +297,40 @@ class LayerNorm(nn.Layer):
        else:
            raise NotImplementedError("Iluvatar does not support yet!")

        norm_out = self.norm_func(
            x,
            norm_weight=self.ln_weight,
            norm_bias=self.ln_bias,
            epsilon=self.eps,
            begin_norm_axis=1,
            bias=self.linear_bias,
            residual=residual_input,
            quant_scale=-1 if self.quant_scale is None else self.quant_scale,
            quant_round_type=self.quant_round_type,
            quant_max_bound=self.quant_max_bound,
            quant_min_bound=self.quant_min_bound,
        )
        if current_platform.is_gcu():
            if residual_input is not None:
                y = x + residual_input
                out = self.norm_func(
                    x=y,
                    normalized_shape=y.shape[1:],
                    weight=self.ln_weight,
                    bias=self.linear_bias,
                    epsilon=self.eps,
                )
                return out, y
            else:
                out = self.norm_func(
                    x=x,
                    normalized_shape=x.shape[1:],
                    weight=self.ln_weight,
                    bias=self.linear_bias,
                    epsilon=self.eps,
                )
                return out
        else:
            norm_out = self.norm_func(
                x,
                norm_weight=self.ln_weight,
                norm_bias=self.ln_bias,
                epsilon=self.eps,
                begin_norm_axis=1,
                bias=self.linear_bias,
                residual=residual_input,
                quant_scale=-1 if self.quant_scale is None else self.quant_scale,
                quant_round_type=self.quant_round_type,
                quant_max_bound=self.quant_max_bound,
                quant_min_bound=self.quant_min_bound,
            )
            if residual_input is not None:
                return norm_out[0], norm_out[1]
            else:
@@ -66,6 +66,13 @@ class WeightOnlyConfig(QuantConfigBase):
                return XPUWeightOnlyMoEMethod(self)
            else:
                return XPUWeightOnlyLinearMethod(self)
        elif current_platform.is_gcu():
            from fastdeploy.model_executor.layers.backends import (
                GCUWeightOnlyLinearMethod, GCUWeightOnlyMoEMethod)
            if isinstance(layer, FusedMoE):
                return GCUWeightOnlyMoEMethod(self)
            else:
                return GCUWeightOnlyLinearMethod(self)
        else:
            if isinstance(layer, FusedMoE):
                if layer.use_method == "cutlass":
@@ -55,6 +55,10 @@ class ErnieRotaryEmbedding:
                dtype="float32")
            emb = paddle.stack([freqs, freqs], axis=-1).reshape(
                (bsz, max_seq_len, self.rotary_dim))
        elif current_platform.is_gcu():
            # shape: [B, S, D]
            rot_emb = paddle.concat([freqs.cos(), freqs.sin()], axis=-1)
            return rot_emb
        else:
            # shape: [B, S, D/2]
            rot_emb = paddle.zeros(
@@ -95,6 +99,10 @@ class QwenRotaryEmbedding:
        # shape: [B, S, D/2]
        freqs = paddle.einsum("ij,k->ijk", position_ids.cast("float32"),
                              inv_freq)
        if current_platform.is_gcu():
            # shape: [B, S, D]
            rot_emb = paddle.concat([freqs.cos(), freqs.sin()], axis=-1)
            return rot_emb
        # shape: [B, S, 1, D]
        emb = paddle.concat([freqs, freqs], axis=-1).reshape(
            (bsz, max_seq_len, 1, self.rotary_dim))
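The GCU branches above change only the layout of the rotary table: instead of the `[B, S, 1, D]` duplicated-frequency layout used on GPU, they store `cos` and `sin` side by side as a `[B, S, D]` tensor. A small shape check, with illustrative sizes and the standard 10000 RoPE base as assumptions:

```python
import paddle

bsz, max_seq_len, rotary_dim = 1, 4, 64   # illustrative sizes, not from this diff
base = 10000.0                            # assumed RoPE base

position_ids = paddle.arange(max_seq_len, dtype="float32").reshape([1, -1])
inv_freq = 1.0 / (base ** (paddle.arange(0, rotary_dim, 2, dtype="float32") / rotary_dim))

freqs = paddle.einsum("ij,k->ijk", position_ids, inv_freq)        # [B, S, D/2]
rot_emb_gcu = paddle.concat([freqs.cos(), freqs.sin()], axis=-1)  # [B, S, D] on GCU
print(rot_emb_gcu.shape)                                          # [1, 4, 64]
```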
@@ -79,6 +79,21 @@ def apply_penalty_multi_scores(
            min_dec_lens,
            eos_token_ids,
        )
    elif current_platform.is_gcu():
        from fastdeploy.model_executor.ops.gcu import \
            get_token_penalty_multi_scores
        logits = get_token_penalty_multi_scores(
            pre_token_ids,
            logits,
            repetition_penalties,
            frequency_penalties,
            presence_penalties,
            temperature,
            bad_words_token_ids,
            step_idx,
            min_dec_lens,
            eos_token_ids,
        )
    else:
        raise NotImplementedError()

@@ -19,7 +19,11 @@ from typing import Literal, Optional
import paddle

from fastdeploy import envs
from fastdeploy.platforms import current_platform

if current_platform.is_gcu():
    from fastdeploy.model_executor.ops.gcu import \
        top_p_sampling as gcu_top_p_sampling

def top_p_sampling(
    x: paddle.Tensor,
@@ -46,13 +50,16 @@ def top_p_sampling(
        ids = rejection_top_p_sampling(x, ps, seed)
        _ = None
    else:
        _, ids = paddle.tensor.top_p_sampling(x,
                                              ps,
                                              threshold=threshold,
                                              topp_seed=topp_seed,
                                              seed=seed,
                                              k=k,
                                              mode=mode)
        if current_platform.is_gcu():
            _, ids = gcu_top_p_sampling(x, ps)
        else:
            _, ids = paddle.tensor.top_p_sampling(x,
                                                  ps,
                                                  threshold=threshold,
                                                  topp_seed=topp_seed,
                                                  seed=seed,
                                                  k=k,
                                                  mode=mode)
    return _, ids

@@ -171,7 +171,7 @@ class Sampler(nn.Layer):
        """
        super().__init__()
        if current_platform.is_cuda() or current_platform.is_xpu(
        ) or current_platform.is_iluvatar():
        ) or current_platform.is_iluvatar() or current_platform.is_gcu():
            self.forward = self.forward_cuda
        else:
            raise NotImplementedError()
@@ -17,5 +17,6 @@ from . import cpu
from . import xpu
from . import npu
from . import iluvatar
from . import gcu

__all__ = ["gpu", "cpu", "xpu", "npu", "iluvatar"]
__all__ = ["gpu", "cpu", "xpu", "npu", "iluvatar", "gcu"]
116 fastdeploy/model_executor/ops/gcu/__init__.py Normal file
@@ -0,0 +1,116 @@
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

""" fastdeploy gcu ops """
from fastdeploy.platforms import current_platform

from fastdeploy.import_ops import import_custom_ops, rename_imported_op

PACKAGE = "fastdeploy.model_executor.ops.gcu"

import_custom_ops(PACKAGE, ".fastdeploy_ops", globals())

if current_platform.is_gcu():
    from paddle_custom_device.gcu.ops import (invoke_fused_moe_kernel,  # noqa: F401,E402
                                              moe_align_block_size, top_p_sampling,  # noqa: F401
                                              topk_softmax,  # noqa: F401
                                              weight_quantize_custom_rtn,  # noqa: F401
                                              weight_quantize_rtn)  # noqa: F401

    # ###################### Ops from PaddleCustomDevice ####################
    rename_imported_op(
        old_name="fused_rotary_embedding_gcu",
        new_name="fused_rotary_embedding",
        global_ns=globals(),
    )

    rename_imported_op(
        old_name="reshape_and_cache_gcu",
        new_name="reshape_and_cache",
        global_ns=globals(),
    )

    rename_imported_op(
        old_name="paged_attention_gcu",
        new_name="paged_attention",
        global_ns=globals(),
    )

    rename_imported_op(
        old_name="mem_efficient_attention_gcu",
        new_name="mem_efficient_attention",
        global_ns=globals(),
    )

    rename_imported_op(
        old_name="flash_attn_var_len_gcu",
        new_name="flash_attn_var_len",
        global_ns=globals(),
    )

    rename_imported_op(
        old_name="rms_norm_gcu",
        new_name="rms_norm",
        global_ns=globals(),
    )

    rename_imported_op(
        old_name="fused_add_rms_norm_op",
        new_name="fused_add_rms_norm",
        global_ns=globals(),
    )

    rename_imported_op(
        old_name="linear_quant_gcu",
        new_name="linear_quant",
        global_ns=globals(),
    )


    # ###################### CPU OPS ####################
    rename_imported_op(
        old_name="get_padding_offset_gcu",
        new_name="get_padding_offset",
        global_ns=globals(),
    )

    rename_imported_op(
        old_name="update_inputs_gcu",
        new_name="update_inputs",
        global_ns=globals(),
    )

    rename_imported_op(
        old_name="rebuild_padding_gcu",
        new_name="rebuild_padding",
        global_ns=globals(),
    )

    rename_imported_op(
        old_name="get_token_penalty_multi_scores_gcu",
        new_name="get_token_penalty_multi_scores",
        global_ns=globals(),
    )

    rename_imported_op(
        old_name="set_stop_value_multi_ends_gcu",
        new_name="set_stop_value_multi_ends",
        global_ns=globals(),
    )

    rename_imported_op(
        old_name="set_value_by_flags_and_idx_gcu",
        new_name="set_value_by_flags_and_idx",
        global_ns=globals(),
    )
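The net effect of this module is that GCU custom kernels become importable under the same names the rest of FastDeploy already uses, so callers need no GCU-specific spellings. For example, the imports that later hunks in this commit rely on:

```python
# After fastdeploy.model_executor.ops.gcu has applied its rename table,
# the device-suffixed kernels are exposed under their generic names.
from fastdeploy.model_executor.ops.gcu import (
    rms_norm,             # originally rms_norm_gcu
    fused_add_rms_norm,   # originally fused_add_rms_norm_op
    rebuild_padding,      # originally rebuild_padding_gcu
    get_padding_offset,   # originally get_padding_offset_gcu
)
```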
@@ -24,6 +24,11 @@ if current_platform.is_iluvatar():
    from fastdeploy.model_executor.ops.iluvatar import (
        get_padding_offset, save_output, set_stop_value_multi_ends,
        step_paddle, update_inputs)
elif current_platform.is_gcu():
    from fastdeploy.model_executor.ops.gcu import (get_padding_offset,
                                                   save_output,
                                                   set_stop_value_multi_ends,
                                                   update_inputs)
else:
    from fastdeploy.model_executor.ops.gpu import (
        get_padding_offset, save_output, set_stop_value_multi_ends,
@@ -391,6 +396,17 @@ def rebuild_padding(tmp_out: paddle.Tensor,
            output_padding_offset,
            max_input_length,
        )
    elif current_platform.is_gcu():
        from fastdeploy.model_executor.ops.gcu import rebuild_padding
        hidden_states = rebuild_padding(
            tmp_out,
            cum_offsets,
            seq_len_this_time,
            seq_lens_decoder,
            seq_lens_encoder,
            output_padding_offset,
            max_input_length,
        )
    elif current_platform.is_cpu():
        from fastdeploy.model_executor.ops.cpu import rebuild_padding_cpu
        hidden_states = rebuild_padding_cpu(
@@ -124,6 +124,8 @@ class TokenProcessor(object):
            from fastdeploy.model_executor.ops.xpu import get_output
        elif current_platform.is_iluvatar():
            from fastdeploy.model_executor.ops.iluvatar import get_output
        elif current_platform.is_gcu():
            from fastdeploy.model_executor.ops.gcu import get_output
        else:
            from fastdeploy.model_executor.ops.gpu import (get_output,
                                                           get_output_ep,
@@ -22,6 +22,7 @@ from .xpu import XPUPlatform
from .npu import NPUPlatform
from .dcu import DCUPlatform
from .iluvatar import IluvatarPlatform
from .gcu import GCUPlatform
from .base import _Backend  # noqa: F401

_current_platform = None
@@ -42,6 +43,8 @@ def __getattr__(name: str):
            _current_platform = DCUPlatform()
        elif paddle.is_compiled_with_custom_device("iluvatar_gpu"):
            _current_platform = IluvatarPlatform()
        elif paddle.is_compiled_with_custom_device("gcu"):
            _current_platform = GCUPlatform()
        else:
            _current_platform = CPUPlatform()
    return _current_platform
@@ -69,6 +69,12 @@ class Platform:
        """
        return paddle.is_compiled_with_custom_device("iluvatar_gpu")

    def is_gcu(self) -> bool:
        """
        whether platform is gcu
        """
        return paddle.is_compiled_with_custom_device("gcu")

    @classmethod
    def get_attention_backend_cls(self, selected_backend):
        """Get the attention backend"""
61 fastdeploy/platforms/gcu.py Normal file
@@ -0,0 +1,61 @@
"""
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""

import paddle

from fastdeploy.utils import console_logger as logger

from .base import Platform, _Backend


class GCUPlatform(Platform):
    """
    gcu platform class
    """
    device_name = "gcu"

    @classmethod
    def available(self):
        """
        Check whether GCU is available.
        """
        try:
            assert paddle.base.core.get_custom_device_count('gcu') > 0
            return True
        except Exception as e:
            logger.warning(
                "You are using GCUPlatform, but there is no GCU "
                "detected on your machine. Maybe GCU devices is not set properly."
                f"\n Original Error is {e}"
            )
            return False

    @classmethod
    def get_attention_backend_cls(cls, selected_backend: _Backend):
        """
        get_attention_backend_cls
        """
        if selected_backend == _Backend.NATIVE_ATTN:
            logger.info("Using GCU mem_efficient ATTN backend.")
            return ("fastdeploy.model_executor.layers.backends.gcu.attention.mem_efficient_attn_backend.GCUMemEfficientAttnBackend")
        elif selected_backend == _Backend.APPEND_ATTN:
            logger.info("Using GCU ATTN backend.")
            return ("fastdeploy.model_executor.layers.backends.gcu.attention.flash_attn_backend.GCUFlashAttnBackend")
        else:
            raise ValueError(
                "Invalid attention backend you specified.\n"
                "Now only support [NATIVE_ATTN, APPEND_ATTN] in gcu place."
            )
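With `GCUPlatform` registered, the rest of the codebase selects GCU behaviour through the usual `current_platform` probe, which is the pattern the remaining hunks follow. A minimal sketch of that dispatch:

```python
from fastdeploy.platforms import current_platform

# On a Paddle build compiled with the gcu custom device, current_platform
# resolves to GCUPlatform, so device-specific branches pick the GCU kernels.
if current_platform.is_gcu():
    from fastdeploy.model_executor.ops.gcu import get_output, update_inputs
```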
1186 fastdeploy/worker/gcu_model_runner.py Normal file
File diff suppressed because it is too large
142 fastdeploy/worker/gcu_worker.py Normal file
@@ -0,0 +1,142 @@
"""
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
import gc
from typing import List, Optional

import paddle
import paddle.nn as nn

from fastdeploy.config import FDConfig
from fastdeploy.engine.request import Request
from fastdeploy.utils import get_logger
from fastdeploy.worker.gcu_model_runner import GCUModelRunner
from fastdeploy.worker.output import ModelRunnerOutput
from fastdeploy.worker.worker_base import WorkerBase

logger = get_logger("gcu_worker", "gcu_worker.log")


class GcuWorker(WorkerBase):
    """ """

    def __init__(
        self,
        fd_config: FDConfig,
        local_rank: int,
        rank: int,
    ):
        super().__init__(
            fd_config=fd_config,
            local_rank=local_rank,
            rank=rank,
        )
        pass

    def init_device(self):
        """ Initialize device and construct model runner
        """
        if paddle.is_compiled_with_custom_device("gcu"):
            # Set environment variables
            self.device_ids = self.parallel_config.device_ids.split(",")
            self.device = f"gcu:{self.local_rank}"
            paddle.device.set_device(self.device)
            paddle.set_default_dtype(self.parallel_config.dtype)
            logger.info(f"GcuWorker init_device:{self.device}, device_ids:{self.device_ids}")

            gc.collect()
        else:
            raise RuntimeError(
                f"Not support device type: {self.device_config.device}")

        # Construct model runner
        self.model_runner: GCUModelRunner = GCUModelRunner(
            fd_config=self.fd_config,
            device=self.device,
            device_id=self.device_ids[self.local_rank],
            rank=self.rank,
            local_rank=self.local_rank)

    def prefill_finished(self):
        """
        check whether prefill stage finished
        """
        return self.model_runner.prefill_finished()

    def determine_available_memory(self) -> int:
        """
        Profiles the peak memory usage of the model to determine how much
        memory can be used for KV cache without OOMs.

        The engine will first conduct a profiling of the existing memory usage.
        Then, it calculates the maximum possible number of GCU and CPU blocks
        that can be allocated with the remaining free memory.

        Tip:
            You may limit the usage of GCU memory
            by adjusting the `gcu_memory_utilization` parameter.
        """
        raise NotImplementedError

    def load_model(self) -> None:
        """ """
        self.model_runner.load_model()

    def get_model(self) -> nn.Layer:
        """ """
        return self.model_runner.get_model()

    def initialize_cache(self, num_gpu_blocks: int,
                         num_cpu_blocks: int) -> None:
        """ """
        pass

    def execute_model(
        self,
        model_forward_batch: Optional[List[Request]] = None,
    ) -> Optional[ModelRunnerOutput]:
        """ """
        output = self.model_runner.execute_model(model_forward_batch)
        return output

    def preprocess_new_task(self, req_dicts: List[Request]) -> None:
        """ Process new requests and then start the decode loop
        TODO(gongshaotian): The scheduler should schedule the handling of prefill,
        and workers and model runners should not perceive it.
        """
        self.model_runner.insert_prefill_inputs(req_dicts=req_dicts)

    def graph_optimize_and_warm_up_model(self) -> None:
        """
        Perform the warm-up and the graph optimization
        """
        # 1. Warm up model
        # NOTE(gongshaotian): warm_up may not be needed at this place

        # 2. Trigger cuda graph capture
        self.model_runner.capture_model()

    def check_health(self) -> bool:
        """ """
        return True

    def cal_theortical_kvcache(self) -> int:
        """ """
        return self.model_runner.cal_theortical_kvcache()

    def reinitialize_kv_cache(self, num_gpu_blocks: int) -> None:
        """ """
        self.model_runner.update_share_input_block_num(
            num_gpu_blocks=num_gpu_blocks)
@@ -53,6 +53,9 @@ def get_worker(fd_config: FDConfig, local_rank: int, rank: int) -> WorkerBase:
        return IluvatarWorker(fd_config=fd_config,
                              local_rank=local_rank,
                              rank=rank)
    if current_platform.is_gcu():
        from fastdeploy.worker.gcu_worker import GcuWorker
        return GcuWorker(fd_config=fd_config, local_rank=local_rank, rank=rank)


class PaddleDisWorkerProc():
4 setup.py
@@ -167,6 +167,8 @@ def get_device_type():
        return "npu"
    elif paddle.is_compiled_with_custom_device('iluvatar_gpu'):
        return "iluvatar-gpu"
    elif paddle.is_compiled_with_custom_device('gcu'):
        return "gcu"
    else:
        return "cpu"

@@ -199,7 +201,7 @@ setup(
            "model_executor/ops/xpu/libs/*", "model_executor/ops/npu/*",
            "model_executor/ops/base/*", "model_executor/ops/iluvatar/*",
            "model_executor/models/*", "model_executor/layers/*",
            "input/mm_processor/utils/*",
            "input/mm_processor/utils/*", "model_executor/ops/gcu/*",
            "version.txt"
        ]
    },