[GCU] Support gcu platform (#2702)

baseline: e7fa57ebae

Co-authored-by: yongqiangma <xing.wo@163.com>
This commit is contained in:
EnflameGCU
2025-07-08 13:00:52 +08:00
committed by GitHub
parent 26d5d737dd
commit d0f4d6ba3a
33 changed files with 2988 additions and 85 deletions

View File

@@ -113,6 +113,14 @@ function copy_ops(){
return
fi
is_gcu=`$python -c "import paddle; print(paddle.is_compiled_with_custom_device('gcu'))"`
if [ "$is_gcu" = "True" ]; then
DEVICE_TYPE="gcu"
cp -r ${OPS_TMP_DIR}/${WHEEL_NAME}/* ../fastdeploy/model_executor/ops/gcu
echo -e "gcu ops have been copy to fastdeploy"
return
fi
DEVICE_TYPE="cpu"
cp -r ./${OPS_TMP_DIR_BASE}/${WHEEL_BASE_NAME}/* ../fastdeploy/model_executor/ops/base
cd ../../../../

View File

@@ -501,6 +501,17 @@ elif paddle.is_compiled_with_custom_device("iluvatar_gpu"):
],
),
)
elif paddle.is_compiled_with_custom_device("gcu"):
setup(
name="fastdeploy_ops",
ext_modules=CppExtension(
sources=[
"gpu_ops/save_with_output_msg.cc",
"gpu_ops/get_output.cc",
"gpu_ops/get_output_msg_with_topk.cc",
]
),
)
else:
use_bf16 = envs.FD_CPU_USE_BF16 == "True"

View File

@@ -1,8 +1,8 @@
# Running ERNIE-4.5-21B-A3B with FastDeploy
# Running ERNIE 4.5 Series Models with FastDeploy
The Enflame S60 ([Learn about Enflame](https://www.enflame-tech.com/)) is a next-generation AI inference accelerator card designed for large-scale deployment in data centers. It meets the demands of large language models (LLMs), search/advertising/recommendation systems, and traditional models. Characterized by broad model coverage, user-friendliness, and high portability, it is widely applicable to mainstream inference scenarios such as image and text generation applications, search and recommendation systems, and text/image/speech recognition.
FastDeploy has deeply adapted and optimized the ernie-4_5-21b-a3b-bf16-paddle model for the Enflame S60, achieving a unified inference interface between GCU and GPU. This allows seamless migration of inference tasks without code modifications.
FastDeploy has deeply adapted and optimized the ERNIE 4.5 Series Models for the Enflame S60, achieving a unified inference interface between GCU and GPU. This allows seamless migration of inference tasks without code modifications.
## 🚀 Quick Start 🚀
@@ -27,15 +27,15 @@ lspci | grep S60
3b:00.0 Processing accelerators: Shanghai Enflame Technology Co. Ltd S60 [Enflame] (rev 01)
3c:00.0 Processing accelerators: Shanghai Enflame Technology Co. Ltd S60 [Enflame] (rev 01)
```
### 1. Environment Setup (Estimated time: 5~10 minutes)
### 1. Environment Setup (Estimated time: 5-10 minutes)
1. Pull the Docker image
```bash
# Note: This image only contains the Paddle development environment, not precompiled PaddlePaddle packages
docker pull ccr-2vdh3abv-pub.cnc.bj.baidubce.com/device/paddle-gcu:topsrider3.4.623-ubuntu20-x86_64-gcc84
docker pull ccr-2vdh3abv-pub.cnc.bj.baidubce.com/device/paddle-gcu:topsrider3.5.102-ubuntu20-x86_64-gcc84
```
2. Start the container
```bash
docker run --name paddle-gcu-llm -v /home:/home -v /work:/work --network=host --ipc=host -it --privileged ccr-2vdh3abv-pub.cnc.bj.baidubce.com/device/paddle-gcu:topsrider3.4.623-ubuntu20-x86_64-gcc84 /bin/bash
docker run --name paddle-gcu-llm -v /home:/home -v /work:/work --network=host --ipc=host -it --privileged ccr-2vdh3abv-pub.cnc.bj.baidubce.com/device/paddle-gcu:topsrider3.5.102-ubuntu20-x86_64-gcc84 /bin/bash
```
3. Obtain and install drivers<br/>
**Full software packages are preloaded in the Docker container. Copy them to an external directory, e.g., ```/home/workspace/deps/```**
@@ -67,25 +67,31 @@ python -m pip install paddle-custom-gcu==3.1.0 -i https://www.paddlepaddle.org.c
7. Install FastDeploy and dependencies
```bash
python -m pip install fastdeploy -i https://www.paddlepaddle.org.cn/packages/stable/gcu/ --extra-index-url https://mirrors.tuna.tsinghua.edu.cn/pypi/web/simple
apt install python3.10-distutils
# For source compilation, refer to the following steps
git clone https://github.com/PaddlePaddle/FastDeploy
cd FastDeploy
python -m pip install -r requirements.txt --extra-index-url https://mirrors.tuna.tsinghua.edu.cn/pypi/web/simple
bash build.sh 1
```
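You can optionally confirm that the installed Paddle build sees the GCU device; this is the same check FastDeploy's build script uses:
```bash
python -c "import paddle; print(paddle.is_compiled_with_custom_device('gcu'))"  # should print True
```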
### 2. Data Preparation (Estimated time: 2~5 minutes)
### 2. Data Preparation (Estimated time: 2-5 minutes)
Use a trained model for inference on the GSM8K dataset:
```bash
mkdir -p /home/workspace/benchmark/ && cd /home/workspace/benchmark/
wget https://raw.githubusercontent.com/openai/grade-school-math/master/grade_school_math/data/test.jsonl
```
Place model weights in a directory, e.g., ```/work/models/ernie-4_5-21b-a3b-bf16-paddle/```
### 3. Inference (Estimated time: 2~5 minutes)
Place model weights in a directory, e.g., ```/work/models/ERNIE-4.5-300B-A47B-Paddle/```
### 3. Inference (Estimated time: 2-5 minutes)
Start the inference service:
```bash
python -m fastdeploy.entrypoints.openai.api_server \
--model "/work/models/ernie-4_5-21b-a3b-bf16-paddle/" \
--model "/work/models/ERNIE-4.5-300B-A47B-Paddle/" \
--port 8188 \
--metrics-port 8200 \
--tensor-parallel-size 4 \
--max-model-len 8192 \
--num-gpu-blocks-override 1024
--tensor-parallel-size 8 \
--max-model-len 32768 \
--num-gpu-blocks-override 4096 \
--max-num-batched-tokens 32768 \
--quantization "wint4"
```
Query the model service:
```bash
@@ -93,13 +99,13 @@ curl -X POST "http://0.0.0.0:8188/v1/chat/completions" \
-H "Content-Type: application/json" \
-d '{
"messages": [
{"role": "user", "content": "The largest ocean is"}
{"role": "user", "content": "Where is Beijing?"}
]
}'
```
Successful execution returns inference results, e.g.:
```json
{"id":"chatcmpl-5cd96f3b-eff3-4dc0-8aa2-8b5d7b7b86f2","object":"chat.completion","created":1751167862,"model":"default","choices":[{"index":0,"message":{"role":"assistant","content":"3. **Pacific Ocean**: The Pacific Ocean is the largest and deepest of the world's oceans. It covers an area of approximately 181,344,000 square kilometers, which is more than 30% of the Earth's surface. It is located between the Americas to the west and east, and Asia and Australia to the north and south. The Pacific Ocean is known for its vastness, diverse marine life, and numerous islands.\n\nIn summary, the largest ocean in the world is the Pacific Ocean.","reasoning_content":null,"tool_calls":null},"finish_reason":"stop"}],"usage":{"prompt_tokens":11,"total_tokens":127,"completion_tokens":116,"prompt_tokens_details":{"cached_tokens":0}}}
{"id":"chatcmpl-20f1210d-6943-4110-ad2d-c76ba11604ad","object":"chat.completion","created":1751621261,"model":"default","choices":[{"index":0,"message":{"role":"assistant","content":"Beijing is the capital city of the People's Republic of China, located in the northern part of the country. It is situated in the North China Plain, bordered by the mountains to the west, north, and northeast. Beijing serves as China's political, cultural, and international exchange center, playing a crucial role in the nation's development and global interactions.","reasoning_content":null,"tool_calls":null},"finish_reason":"stop"}],"usage":{"prompt_tokens":11,"total_tokens":88,"completion_tokens":77,"prompt_tokens_details":{"cached_tokens":0}}}
```
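Because the service exposes an OpenAI-compatible API, it can also be queried from Python. A minimal sketch, assuming the `openai` client package is installed (the local service does not validate the API key):
```python
from openai import OpenAI

# Point the client at the local FastDeploy OpenAI-compatible endpoint
client = OpenAI(base_url="http://0.0.0.0:8188/v1", api_key="EMPTY")

response = client.chat.completions.create(
    model="default",
    messages=[{"role": "user", "content": "Where is Beijing?"}],
)
print(response.choices[0].message.content)
```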
### 4. Accuracy Testing (Estimated time: 60~180 minutes)
Place the accuracy script ```bench_gsm8k.py``` in ```/home/workspace/benchmark/``` and modify sampling parameters, e.g.:
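As a rough illustration only, the modified sampling section of ```bench_gsm8k.py``` might look like the sketch below; the actual field names and values come from the script itself and should be adjusted there:
```python
# Hypothetical sketch of the request payload built by bench_gsm8k.py (names and values are assumptions)
data = {
    "messages": [{"role": "user", "content": prompt}],  # `prompt` is assembled by the benchmark script
    "temperature": 0.6,   # example sampling parameters to tune for the accuracy run
    "top_p": 0.95,
    "max_tokens": 2048,
}
```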
@@ -120,10 +126,10 @@ data = {
Run accuracy tests:
```bash
cd /home/workspace/benchmark/
python -u bench_gsm8k.py --port 8188 --num-questions 1319 --num-shots 5 --parallel 2
python -u bench_gsm8k.py --port 8188 --num-questions 1319 --num-shots 5 --parallel 8
```
Upon completion, accuracy results are saved in ```result.jsonl```, e.g.:
```json
{"task": "gsm8k", "backend": "paddlepaddle", "num_gpus": 1, "latency": 365.548, "accuracy": 0.967, "num_requests": 30, "other": {"num_questions": 30, "parallel": 2}}
{"task": "gsm8k", "backend": "paddlepaddle", "num_gpus": 1, "latency": 13446.01, "accuracy": 0.956, "num_requests": 1319, "other": {"num_questions": 1319, "parallel": 8}}
```

View File

@@ -1,8 +1,8 @@
# Running the ERNIE-4.5-21B-A3B Model with FastDeploy on the Enflame S60
# Running ERNIE 4.5 Series Models with FastDeploy on the Enflame S60
The Enflame S60 ([Learn about Enflame](https://www.enflame-tech.com/)) is a next-generation AI inference accelerator card designed for large-scale data center deployment. It meets the needs of large language models, search/advertising/recommendation systems, and traditional models, features broad model coverage, strong ease of use, and easy migration and deployment, and is widely applicable to mainstream inference scenarios such as image and text generation, search and recommendation, and text, image, and speech recognition.
FastDeploy has deeply adapted and optimized the ernie-4_5-21b-a3b-bf16-paddle model for the Enflame S60, unifying the GCU and GPU inference entry points so that inference tasks can be migrated without code changes.
FastDeploy has deeply adapted and optimized the ERNIE 4.5 Series Models for the Enflame S60, unifying the GCU and GPU inference entry points so that inference tasks can be migrated without code changes.
## 🚀 Quick Start 🚀
@@ -30,11 +30,11 @@ lspci | grep S60
1. Pull the Docker image
```bash
# Note: this image only provides the Paddle development environment and does not include a precompiled PaddlePaddle package
docker pull ccr-2vdh3abv-pub.cnc.bj.baidubce.com/device/paddle-gcu:topsrider3.4.623-ubuntu20-x86_64-gcc84
docker pull ccr-2vdh3abv-pub.cnc.bj.baidubce.com/device/paddle-gcu:topsrider3.5.102-ubuntu20-x86_64-gcc84
```
2. Start the container with a command like the following
```bash
docker run --name paddle-gcu-llm -v /home:/home -v /work:/work --network=host --ipc=host -it --privileged ccr-2vdh3abv-pub.cnc.bj.baidubce.com/device/paddle-gcu:topsrider3.4.623-ubuntu20-x86_64-gcc84 /bin/bash
docker run --name paddle-gcu-llm -v /home:/home -v /work:/work --network=host --ipc=host -it --privileged ccr-2vdh3abv-pub.cnc.bj.baidubce.com/device/paddle-gcu:topsrider3.5.102-ubuntu20-x86_64-gcc84 /bin/bash
```
3. Obtain and install the drivers<br/>
**The full software packages are preloaded inside the Docker container. Copy them to a directory outside the container, e.g., ```/home/workspace/deps/```**
@@ -63,10 +63,14 @@ python -m pip install paddlepaddle==3.1.0a0 -i https://www.paddlepaddle.org.cn/p
python -m pip install paddle-custom-gcu==3.1.0 -i https://www.paddlepaddle.org.cn/packages/stable/gcu/
# For source compilation and installation, refer to: https://github.com/PaddlePaddle/PaddleCustomDevice/blob/develop/backends/gcu/README_cn.md
```
7. Install FastDeploy and dependencies<br/>
7. Install FastDeploy<br/>
```bash
python -m pip install fastdeploy -i https://www.paddlepaddle.org.cn/packages/stable/gcu/ --extra-index-url https://mirrors.tuna.tsinghua.edu.cn/pypi/web/simple
apt install python3.10-distutils
# For source compilation and installation, follow the steps below
git clone https://github.com/PaddlePaddle/FastDeploy
cd FastDeploy
python -m pip install -r requirements.txt --extra-index-url https://mirrors.tuna.tsinghua.edu.cn/pypi/web/simple
bash build.sh 1
```
### 2. Data Preparation (this will take about 2~5 minutes)
Use a trained model to run inference on the GSM8K dataset
@@ -74,17 +78,19 @@ apt install python3.10-distutils
mkdir -p /home/workspace/benchmark/ && cd /home/workspace/benchmark/
wget https://raw.githubusercontent.com/openai/grade-school-math/master/grade_school_math/data/test.jsonl
```
Prepare the model weights and place them in a local directory, e.g., ```/work/models/ernie-4_5-21b-a3b-bf16-paddle/```
Prepare the model weights and place them in a local directory, e.g., ```/work/models/ERNIE-4.5-300B-A47B-Paddle/```
### 3. Inference (this will take about 2~5 minutes)
Run the following command to start the inference service
```bash
python -m fastdeploy.entrypoints.openai.api_server \
--model "/work/models/ernie-4_5-21b-a3b-bf16-paddle/" \
--model "/work/models/ERNIE-4.5-300B-A47B-Paddle/" \
--port 8188 \
--metrics-port 8200 \
--tensor-parallel-size 4 \
--max-model-len 8192 \
--num-gpu-blocks-override 1024
--tensor-parallel-size 8 \
--max-model-len 32768 \
--num-gpu-blocks-override 4096 \
--max-num-batched-tokens 32768 \
--quantization "wint4"
```
Query the model service with the following command
```bash
@@ -92,13 +98,13 @@ curl -X POST "http://0.0.0.0:8188/v1/chat/completions" \
-H "Content-Type: application/json" \
-d '{
"messages": [
{"role": "user", "content": "The largest ocean is"}
{"role": "user", "content": "Where is Beijing?"}
]
}'
```
After a successful run, the generated inference results can be viewed; an example is shown below:
```json
{"id":"chatcmpl-5cd96f3b-eff3-4dc0-8aa2-8b5d7b7b86f2","object":"chat.completion","created":1751167862,"model":"default","choices":[{"index":0,"message":{"role":"assistant","content":"3. **Pacific Ocean**: The Pacific Ocean is the largest and deepest of the world's oceans. It covers an area of approximately 181,344,000 square kilometers, which is more than 30% of the Earth's surface. It is located between the Americas to the west and east, and Asia and Australia to the north and south. The Pacific Ocean is known for its vastness, diverse marine life, and numerous islands.\n\nIn summary, the largest ocean in the world is the Pacific Ocean.","reasoning_content":null,"tool_calls":null},"finish_reason":"stop"}],"usage":{"prompt_tokens":11,"total_tokens":127,"completion_tokens":116,"prompt_tokens_details":{"cached_tokens":0}}}
{"id":"chatcmpl-20f1210d-6943-4110-ad2d-c76ba11604ad","object":"chat.completion","created":1751621261,"model":"default","choices":[{"index":0,"message":{"role":"assistant","content":"Beijing is the capital city of the People's Republic of China, located in the northern part of the country. It is situated in the North China Plain, bordered by the mountains to the west, north, and northeast. Beijing serves as China's political, cultural, and international exchange center, playing a crucial role in the nation's development and global interactions.","reasoning_content":null,"tool_calls":null},"finish_reason":"stop"}],"usage":{"prompt_tokens":11,"total_tokens":88,"completion_tokens":77,"prompt_tokens_details":{"cached_tokens":0}}}
```
### 4. Accuracy Testing (this will take about 60~180 minutes)
Place the accuracy script ```bench_gsm8k.py``` in ```/home/workspace/benchmark/``` and modify the sampling parameters, e.g.:
@@ -119,10 +125,10 @@ data = {
Run the following command to start the accuracy test
```bash
cd /home/workspace/benchmark/
python -u bench_gsm8k.py --port 8188 --num-questions 1319 --num-shots 5 --parallel 2
python -u bench_gsm8k.py --port 8188 --num-questions 1319 --num-shots 5 --parallel 8
```
After a successful run, the accuracy results are written to ```result.jsonl``` in the current directory; an example is shown below (partial dataset, for illustration only)
After a successful run, the accuracy results are written to ```result.jsonl``` in the current directory; an example is shown below
```json
{"task": "gsm8k", "backend": "paddlepaddle", "num_gpus": 1, "latency": 365.548, "accuracy": 0.967, "num_requests": 30, "other": {"num_questions": 30, "parallel": 2}}
{"task": "gsm8k", "backend": "paddlepaddle", "num_gpus": 1, "latency": 13446.01, "accuracy": 0.956, "num_requests": 1319, "other": {"num_questions": 1319, "parallel": 8}}
```

View File

@@ -19,7 +19,7 @@ from typing import Optional
import paddle
from paddle import nn
from paddle.incubate.nn.functional import fused_bias_act
from paddle.incubate.nn.functional import fused_bias_act, swiglu
from fastdeploy.config import FDConfig
from fastdeploy.platforms import current_platform
@@ -66,6 +66,8 @@ class SiluAndMul(nn.Layer):
if current_platform.is_cuda() or current_platform.is_xpu(
) or current_platform.is_iluvatar():
self.forward = self.forward_cuda
elif current_platform.is_gcu():
self.forward = self.forward_gcu
else:
raise NotImplementedError
@@ -123,3 +125,18 @@ class SiluAndMul(nn.Layer):
quant_max_bound=self.quant_max_bound,
quant_min_bound=self.quant_min_bound,
)
def forward_gcu(self, x):
"""
Forward propagation of the custom activation layer.
Args:
x (Tensor): Input tensor to the activation layer.
Returns:
Tensor: Output tensor.
"""
out = swiglu(x)
if self.bias is not None:
out = out + self.bias
return out
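# Note: with a single input, paddle.incubate.nn.functional.swiglu(x) splits x in half along the
# last axis and computes silu(x1) * x2, so an input of shape [*, 2 * intermediate_size] yields an
# output of shape [*, intermediate_size], matching the fused_bias_act "swiglu" path used on CUDA.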

View File

@@ -16,14 +16,24 @@
All backend methods
"""
from .xpu import *
from .npu import *
from fastdeploy.platforms import current_platform
__all__ = []
from . import npu
if hasattr(npu, '__all__'):
__all__.extend(npu.__all__)
from . import xpu
if hasattr(xpu, '__all__'):
__all__.extend(xpu.__all__)
if current_platform.is_xpu():
from . import xpu
from .xpu import *
if hasattr(xpu, '__all__'):
__all__.extend(xpu.__all__)
if current_platform.is_npu():
from . import npu
from .npu import *
if hasattr(npu, '__all__'):
__all__.extend(npu.__all__)
if current_platform.is_gcu():
from . import gcu
from .gcu import *
if hasattr(gcu, '__all__'):
__all__.extend(gcu.__all__)

View File

@@ -0,0 +1,31 @@
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
gcu backend methods
"""
from .attention.flash_attn_backend import GCUFlashAttnBackend
from .attention.mem_efficient_attn_backend import GCUMemEfficientAttnBackend
from .moe.fused_moe_method_gcu_backend import (GCUFusedMoeMethod,
GCUWeightOnlyMoEMethod)
from .quantization.weight_only import GCUWeightOnlyLinearMethod
__all__ = [
'GCUFlashAttnBackend',
'GCUMemEfficientAttnBackend',
'GCUFusedMoeMethod',
'GCUWeightOnlyMoEMethod',
'GCUWeightOnlyLinearMethod',
]

View File

@@ -0,0 +1,21 @@
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .flash_attn_backend import GCUFlashAttnBackend
from .mem_efficient_attn_backend import GCUMemEfficientAttnBackend
__all__ = [
"GCUFlashAttnBackend",
"GCUMemEfficientAttnBackend",
]

View File

@@ -0,0 +1,287 @@
"""
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
from __future__ import annotations
import os
from dataclasses import dataclass, field
from typing import TYPE_CHECKING, List, Optional
import paddle
import numpy as np
if TYPE_CHECKING:
from paddle._typing.dtype_like import _DTypeLiteral
from fastdeploy.config import FDConfig
from fastdeploy.model_executor.layers.attention.attention import Attention
from fastdeploy.model_executor.layers.attention.base_attention_backend import (
AttentionBackend, AttentionMetadata)
from fastdeploy.worker.forward_meta import ForwardMeta, ForwardMode
from fastdeploy.model_executor.ops.gcu import (fused_rotary_embedding,
mem_efficient_attention,
flash_attn_var_len)
from paddleformers.utils.log import logger
@dataclass
class GCUFlashAttnMetadata(AttentionMetadata):
"""
GCUFlashAttnMetadata
"""
forward_mode: ForwardMode = ForwardMode.MIXED
_dtype: _DTypeLiteral = paddle.bfloat16
seq_lens_encoder: Optional[paddle.Tensor] = None
seq_lens_decoder: Optional[paddle.Tensor] = None
seq_lens_this_time: Optional[paddle.Tensor] = None
cum_offsets: Optional[paddle.Tensor] = None
padding_offset: Optional[paddle.Tensor] = None
cu_seqlens_q: Optional[paddle.Tensor] = None
cu_seqlens_k: Optional[paddle.Tensor] = None
caches: Optional[paddle.Tensor] = None
block_tables: Optional[paddle.Tensor] = None
rotary_embs: Optional[paddle.Tensor] = None
attn_mask: Optional[paddle.Tensor] = None
pre_caches_length: int = 0
class GCUFlashAttnBackend(AttentionBackend):
"""
GCUFlashAttnBackend backend implementation.
"""
def __init__(self, fd_config: FDConfig, kv_num_heads: int, num_heads: int,
head_dim: int):
"""
GCUFlashAttnBackend __init__
"""
super().__init__()
self.attention_metadata: GCUFlashAttnMetadata = None
self.block_size = fd_config.parallel_config.block_size
self.max_seq_len = fd_config.parallel_config.max_model_len
self.max_num_seqs = fd_config.parallel_config.max_num_seqs
self.causal = getattr(fd_config.model_config, "causal", True)
self.rank = fd_config.parallel_config.tensor_parallel_rank
self.kv_num_heads = kv_num_heads
self.num_heads = num_heads
self.head_dim = head_dim
self.scaling = 1.0 / (self.head_dim**0.5)
self.num_layers = fd_config.model_config.num_layers
self.position_ids_base = paddle.arange(self.max_seq_len)
# TODO(zhengjun): Need to adapt the allocation logic and
# temporarily allocate according to fixed size
self.all_block_tables: List[List[int]] = None
self.all_slot_mapping: List[List[int]] = None
self.rotary_embs = None
self.enable_monitor: bool = bool(os.getenv("FD_GCU_ATTN_MONITOR", False))
def init_attention_metadata(self, forward_meta: ForwardMeta):
"""Initialize attntion metadata hence all layers in the forward pass can reuse it."""
metadata = GCUFlashAttnMetadata()
metadata.forward_mode = forward_meta.forward_mode
metadata._dtype = paddle.get_default_dtype()
metadata.seq_lens_encoder = forward_meta.seq_lens_encoder
metadata.seq_lens_decoder = forward_meta.seq_lens_decoder
metadata.seq_lens_this_time = forward_meta.seq_lens_this_time
metadata.cum_offsets = forward_meta.cum_offsets
metadata.padding_offset = forward_meta.padding_offset
metadata.cu_seqlens_q = forward_meta.cu_seqlens_q
metadata.cu_seqlens_k = forward_meta.cu_seqlens_k
metadata.caches = forward_meta.caches
# metadata.block_tables = forward_meta.block_tables
metadata.rotary_embs = forward_meta.rotary_embs
metadata.attn_mask = forward_meta.attn_mask # not init
metadata.pre_caches_length = forward_meta.pre_caches_length # not inited
self.attention_metadata = metadata
if self.rotary_embs is None:
self.rotary_embs = metadata.rotary_embs.reshape((-1, self.head_dim))
# some info for attention
self.seq_lens_this_time_list = forward_meta.seq_lens_this_time.tolist() # List[int]
self.seq_lens_encoder_list = forward_meta.seq_lens_encoder.tolist() # List[List[int]]
self.seq_lens_decoder_list = forward_meta.seq_lens_decoder.tolist() # List[List[int]]
self.seq_lens_sum = np.sum(self.seq_lens_this_time_list)
self.max_seq_len_this_time = np.max(self.seq_lens_this_time_list)
num_seqs = forward_meta.seq_lens_this_time.shape[0]
self.is_decoder = all(x[0] == 0 for x in self.seq_lens_encoder_list)
self.is_all_prefill = all(x[0] == 0 for x in self.seq_lens_decoder_list)
# block_tables and slot_mapping
if self.all_slot_mapping is None:
max_num_blocks_per_seq = (self.max_seq_len + self.block_size - 1) // self.block_size
total_blocks = max_num_blocks_per_seq * self.max_num_seqs
self.all_block_tables = np.arange(0, total_blocks, dtype=np.int32).reshape((self.max_num_seqs, max_num_blocks_per_seq)).tolist()
self.all_slot_mapping = np.arange(0, total_blocks * self.block_size, dtype=np.int32).reshape((self.max_num_seqs, -1)).tolist()
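# With this fixed layout, sequence i owns the contiguous block range
# [i * max_num_blocks_per_seq, (i + 1) * max_num_blocks_per_seq), and token position p of
# sequence i maps to global cache slot i * max_num_blocks_per_seq * block_size + p.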
block_tables = []
slot_mapping = []
cache_slot_range = []
cache_lens = []
position_ids = []
for seq_idx in range(num_seqs):
cache_len = None
if self.seq_lens_encoder_list[seq_idx][0] != 0: # prefill
cache_len = 0
elif self.seq_lens_decoder_list[seq_idx][0] != 0: # decode
cache_len = self.seq_lens_decoder_list[seq_idx][0]
# else: this seq_idx has no active request
if cache_len is not None:
lens_this_time = self.seq_lens_this_time_list[seq_idx]
start = cache_len
end = start + lens_this_time
slot_mapping.extend(self.all_slot_mapping[seq_idx][start:end])
cache_slot_range.extend(self.all_slot_mapping[seq_idx][0:end])
cache_lens.append(end)
block_tables.append(self.all_block_tables[seq_idx])
position_ids.extend(self.position_ids_base[start:end])
self.block_tables = paddle.to_tensor(block_tables, dtype="int32")
self.slot_mapping = paddle.to_tensor(slot_mapping, dtype="int32")
self.cache_slot_range = paddle.to_tensor(cache_slot_range, dtype="int32")
self.position_ids = paddle.to_tensor(position_ids, dtype="int32")
self.position_ids = self.position_ids.reshape_((1, -1))
if self.enable_monitor:
logger.info(f"[FD_DEBUG] init_attention_metadata, position_ids:\n{self.position_ids}")
cu_query_lens_data = [0]
for seq_idx in range(num_seqs):
if self.seq_lens_this_time_list[seq_idx] != 0:
cu_query_lens_data.append(self.seq_lens_this_time_list[seq_idx])
cu_query_lens = np.array(cu_query_lens_data, dtype=np.int32).cumsum(axis=0)
self.cu_query_lens = paddle.to_tensor(cu_query_lens, dtype="int32")
self.seqused_k = paddle.to_tensor(cache_lens, dtype="int32")
self.max_seqlen_q = self.max_seq_len_this_time
self.max_seqlen_k = np.max(cache_lens)
def get_attntion_meta(self):
"""get_attntion_meta"""
return self.attention_metadata
def get_kv_cache_shape(
self,
max_num_blocks: int,
):
"""
Calculate kv cache shape
"""
# [total_tokens, kv_num_heads, head_dim]
return (max_num_blocks * self.block_size,
self.kv_num_heads,
self.head_dim)
@paddle.no_grad()
def forward_mixed(
self,
q: paddle.Tensor,
k: paddle.Tensor,
v: paddle.Tensor,
qkv: paddle.Tensor,
compressed_kv: paddle.Tensor,
k_pe: paddle.Tensor,
layer: Attention,
forward_meta: ForwardMeta,
) -> paddle.Tensor:
"""Run a forward for mixed."""
token_num = qkv.shape[0]
q_size = self.num_heads * self.head_dim
kv_size = self.kv_num_heads * self.head_dim
num_or_sections = [q_size, kv_size, kv_size]
query, key, value = paddle.split(qkv, num_or_sections=num_or_sections, axis=-1)
query = query.reshape_((1, -1, self.num_heads, self.head_dim))
key = key.reshape_((1, -1, self.kv_num_heads, self.head_dim))
# 1. Rope
if self.rotary_embs.dtype != query.dtype:
self.rotary_embs = paddle.cast(self.rotary_embs, query.dtype)
query, key = fused_rotary_embedding(
query,
key,
self.rotary_embs,
self.position_ids,
layer.use_neox_rotary_style
)
# 2. Save kv cache
# shape: [total_tokens, kv_num_heads, head_dim]
key = key.reshape_((-1, self.kv_num_heads, self.head_dim))
value = value.reshape_((-1, self.kv_num_heads, self.head_dim))
key_caches = forward_meta.caches[2 * layer.layer_id]
value_caches = forward_meta.caches[2 * layer.layer_id + 1]
key_caches[self.slot_mapping, :, :] = key
value_caches[self.slot_mapping, :, :] = value
# 3. calc attn
query = query.reshape_((-1, self.num_heads, self.head_dim))
key_caches = key_caches.reshape((-1, self.block_size, self.kv_num_heads, self.head_dim))
value_caches = value_caches.reshape((-1, self.block_size, self.kv_num_heads, self.head_dim))
res = flash_attn_var_len(
query=query,
key=key_caches,
value=value_caches,
cu_seqlens_q=self.cu_query_lens,
cu_seqlens_k=None,
seqused_k=self.seqused_k,
leftpad_k=None,
block_table=self.block_tables,
alibi_slopes=None,
max_seqlen_q=self.max_seqlen_q,
max_seqlen_k=self.max_seqlen_k,
p_dropout=0.0,
softmax_scale=self.scaling,
zero_tensors=False,
is_causal=self.causal,
window_size_left=-1,
window_size_right=-1,
softcap=0.0,
return_softmax=False,
)
res = res.reshape_((token_num, -1))
return res

View File

@@ -0,0 +1,357 @@
"""
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
from __future__ import annotations
import os
from dataclasses import dataclass, field
from typing import TYPE_CHECKING, List, Optional
import paddle
import numpy as np
import math
if TYPE_CHECKING:
from paddle._typing.dtype_like import _DTypeLiteral
from fastdeploy.config import FDConfig
from fastdeploy.model_executor.layers.attention.attention import Attention
from fastdeploy.model_executor.layers.attention.base_attention_backend import (
AttentionBackend, AttentionMetadata)
from fastdeploy.worker.forward_meta import ForwardMeta, ForwardMode
from fastdeploy.model_executor.ops.gcu import (fused_rotary_embedding,
mem_efficient_attention,
flash_attn_var_len)
from paddleformers.utils.log import logger
@dataclass
class GCUMemEfficientAttnMetadata(AttentionMetadata):
"""
GCUMemEfficientAttnMetadata
"""
forward_mode: ForwardMode = ForwardMode.MIXED
_dtype: _DTypeLiteral = paddle.bfloat16
seq_lens_encoder: Optional[paddle.Tensor] = None
seq_lens_decoder: Optional[paddle.Tensor] = None
seq_lens_this_time: Optional[paddle.Tensor] = None
cum_offsets: Optional[paddle.Tensor] = None
padding_offset: Optional[paddle.Tensor] = None
cu_seqlens_q: Optional[paddle.Tensor] = None
cu_seqlens_k: Optional[paddle.Tensor] = None
caches: Optional[paddle.Tensor] = None
block_tables: Optional[paddle.Tensor] = None
rotary_embs: Optional[paddle.Tensor] = None
attn_mask: Optional[paddle.Tensor] = None
pre_caches_length: int = 0
class GCUMemEfficientAttnBackend(AttentionBackend):
"""
GCUMemEfficientAttnBackend backend implementation.
"""
def __init__(self, fd_config: FDConfig, kv_num_heads: int, num_heads: int,
head_dim: int):
"""
GCUMemEfficientAttnBackend __init__
"""
super().__init__()
self.attention_metadata: GCUMemEfficientAttnMetadata = None
self.block_size = fd_config.parallel_config.block_size
self.max_seq_len = fd_config.parallel_config.max_model_len
self.max_num_seqs = fd_config.parallel_config.max_num_seqs
self.causal = getattr(fd_config.model_config, "causal", True)
self.rank = fd_config.parallel_config.tensor_parallel_rank
self.kv_num_heads = kv_num_heads
self.num_heads = num_heads
self.head_dim = head_dim
self.scaling = 1.0 / (self.head_dim**0.5)
self.num_layers = fd_config.model_config.num_layers
self.position_ids_base = paddle.arange(self.max_seq_len)
# TODO(zhengjun): Need to adapt the allocation logic and
# temporarily allocate according to fixed size
self.all_block_tables: List[List[int]] = None
self.all_slot_mapping: List[List[int]] = None
self.rotary_embs = None
self.use_paddle_native_sdpa = False
def init_attention_metadata(self, forward_meta: ForwardMeta):
"""Initialize attntion metadata hence all layers in the forward pass can reuse it."""
metadata = GCUMemEfficientAttnMetadata()
metadata.forward_mode = forward_meta.forward_mode
metadata._dtype = paddle.get_default_dtype()
metadata.seq_lens_encoder = forward_meta.seq_lens_encoder
metadata.seq_lens_decoder = forward_meta.seq_lens_decoder
metadata.seq_lens_this_time = forward_meta.seq_lens_this_time
metadata.cum_offsets = forward_meta.cum_offsets
metadata.padding_offset = forward_meta.padding_offset
metadata.cu_seqlens_q = forward_meta.cu_seqlens_q
metadata.cu_seqlens_k = forward_meta.cu_seqlens_k
metadata.caches = forward_meta.caches
# metadata.block_tables = forward_meta.block_tables
metadata.rotary_embs = forward_meta.rotary_embs
metadata.attn_mask = forward_meta.attn_mask # not init
metadata.pre_caches_length = forward_meta.pre_caches_length # not inited
self.attention_metadata = metadata
if self.rotary_embs is None:
self.rotary_embs = metadata.rotary_embs.reshape((-1, self.head_dim))
# some info for attention
self.seq_lens_this_time_list = forward_meta.seq_lens_this_time.tolist() # List[int]
self.seq_lens_encoder_list = forward_meta.seq_lens_encoder.tolist() # List[List[int]]
self.seq_lens_decoder_list = forward_meta.seq_lens_decoder.tolist() # List[List[int]]
self.seq_lens_sum = np.sum(self.seq_lens_this_time_list)
self.max_seq_len_this_time = np.max(self.seq_lens_this_time_list)
num_seqs = forward_meta.seq_lens_this_time.shape[0]
self.is_decoder = all(x[0] == 0 for x in self.seq_lens_encoder_list)
self.is_all_prefill = all(x[0] == 0 for x in self.seq_lens_decoder_list)
# block_tables and slot_mapping
if self.all_slot_mapping is None:
max_num_blocks_per_seq = (self.max_seq_len + self.block_size - 1) // self.block_size
total_blocks = max_num_blocks_per_seq * self.max_num_seqs
self.all_block_tables = np.arange(0, total_blocks, dtype=np.int32).reshape((self.max_num_seqs, max_num_blocks_per_seq)).tolist()
self.all_slot_mapping = np.arange(0, total_blocks * self.block_size, dtype=np.int32).reshape((self.max_num_seqs, -1)).tolist()
block_tables = []
slot_mapping = []
cache_slot_range = []
cache_lens = []
query_lens = []
cached_kv_lens = []
cached_kv_slot_range = []
position_ids = []
for seq_idx in range(num_seqs):
cache_len = None
if self.seq_lens_encoder_list[seq_idx][0] != 0: # prefill
cache_len = 0
elif self.seq_lens_decoder_list[seq_idx][0] != 0: # decode
cache_len = self.seq_lens_decoder_list[seq_idx][0]
# else: this seq_idx has no active request
if cache_len is not None:
lens_this_time = self.seq_lens_this_time_list[seq_idx]
start = cache_len
end = start + lens_this_time
slot_mapping.extend(self.all_slot_mapping[seq_idx][start:end])
cache_slot_range.extend(self.all_slot_mapping[seq_idx][0:end])
cache_lens.append(end)
block_tables.append(self.all_block_tables[seq_idx])
position_ids.extend(self.position_ids_base[start:end])
query_lens.append(lens_this_time)
cached_kv_lens.append(end)
cached_kv_slot_range.append([self.all_slot_mapping[seq_idx][0], self.all_slot_mapping[seq_idx][end]])
self.block_tables = paddle.to_tensor(block_tables, dtype="int32")
self.slot_mapping = paddle.to_tensor(slot_mapping, dtype="int32")
self.cache_slot_range = paddle.to_tensor(cache_slot_range, dtype="int32")
self.position_ids = paddle.to_tensor(position_ids, dtype="int32")
self.position_ids = self.position_ids.reshape_((1, -1))
logger.info(f"[FD_DEBUG] init_attention_metadata, self.position_ids:\n{self.position_ids}")
cu_query_lens_data = [0]
for seq_idx in range(num_seqs):
if self.seq_lens_this_time_list[seq_idx] != 0:
cu_query_lens_data.append(self.seq_lens_this_time_list[seq_idx])
cu_query_lens = np.array(cu_query_lens_data, dtype=np.int32).cumsum(axis=0)
self.cu_query_lens = paddle.to_tensor(cu_query_lens, dtype="int32")
self.seqused_k = paddle.to_tensor(cache_lens, dtype="int32")
self.max_seqlen_q = self.max_seq_len_this_time
self.max_seqlen_k = np.max(cache_lens)
self.query_lens = query_lens
self.cached_kv_lens = cached_kv_lens
self.cached_kv_slot_range = cached_kv_slot_range
def get_attntion_meta(self):
"""get_attntion_meta"""
return self.attention_metadata
def get_kv_cache_shape(
self,
max_num_blocks: int,
):
"""
Calculate kv cache shape
"""
# [total_tokens, kv_num_heads, head_dim]
return (max_num_blocks * self.block_size,
self.kv_num_heads,
self.head_dim)
@paddle.no_grad()
def forward_mixed(
self,
q: paddle.Tensor,
k: paddle.Tensor,
v: paddle.Tensor,
qkv: paddle.Tensor,
compressed_kv: paddle.Tensor,
k_pe: paddle.Tensor,
layer: Attention,
forward_meta: ForwardMeta,
) -> paddle.Tensor:
"""Run a forward for mixed."""
token_num = qkv.shape[0]
q_size = self.num_heads * self.head_dim
kv_size = self.kv_num_heads * self.head_dim
num_or_sections = [q_size, kv_size, kv_size]
query, key, value = paddle.split(qkv, num_or_sections=num_or_sections, axis=-1)
query = query.reshape_((1, -1, self.num_heads, self.head_dim))
key = key.reshape_((1, -1, self.kv_num_heads, self.head_dim))
# 1. Rope
if self.rotary_embs.dtype != query.dtype:
self.rotary_embs = paddle.cast(self.rotary_embs, query.dtype)
query, key = fused_rotary_embedding(
query,
key,
self.rotary_embs,
self.position_ids,
layer.use_neox_rotary_style
)
# 2. Save kv cache
# shape: [total_tokens, kv_num_heads, head_dim]
key = key.reshape_((-1, self.kv_num_heads, self.head_dim))
value = value.reshape_((-1, self.kv_num_heads, self.head_dim))
key_caches = forward_meta.caches[2 * layer.layer_id]
value_caches = forward_meta.caches[2 * layer.layer_id + 1]
key_caches[self.slot_mapping, :, :] = key
value_caches[self.slot_mapping, :, :] = value
# 3. calc attn
query = query.reshape_((-1, self.num_heads, self.head_dim))
q_start = 0
result = paddle.empty_like(query)
for idx in range(len(self.query_lens)):
q_end = q_start + self.query_lens[idx]
kv_start = self.cached_kv_slot_range[idx][0]
kv_end = self.cached_kv_slot_range[idx][1]
q_ = query[q_start:q_end, :, :]
k_ = key_caches[kv_start:kv_end, :, :]
v_ = value_caches[kv_start:kv_end, :, :]
if self.use_paddle_native_sdpa:
res = self.native_sdpa_impl(
q_, k_, v_
)
else:
res = mem_efficient_attention(
query=q_.unsqueeze(0),
key=k_.unsqueeze(0),
value=v_.unsqueeze(0),
attn_mask=None,
dropout=0.0,
softmax_scale=self.scaling,
mask_mode=1,
seqlens=[0],
causal=self.causal,
)
result[q_start:q_end, :, :] = res
q_start = q_end
result = result.reshape_((token_num, -1))
return result
def get_triangle_upper_mask(self, shape, dtype):
# [batch_size, 1, q_seq_len, kv_seq_len]
shape[1] = 1
q_seq_len = shape[2]
kv_seq_len = shape[3]
paddle_dtype = dtype # paddle.base.data_feeder.convert_dtype(dtype)
mask = paddle.full(shape, paddle.finfo(paddle_dtype).min, dtype=paddle_dtype)
mask = paddle.triu(mask, diagonal=kv_seq_len - q_seq_len + 1)
return mask
def native_sdpa_impl(self, query, key, value):
# input shape: [num_tokens, num_heads, head_dim] -> [1, num_tokens, num_heads, head_dim]
q = query.unsqueeze(0)
k = key.unsqueeze(0)
v = value.unsqueeze(0)
batch, q_seq_len, heads, head_dim = q.shape
kv_seq_len = k.shape[1]
# [batch_size, seq_len, num_heads, head_dim] -> [batch_size, num_heads, seq_len, head_dim]
q = paddle.transpose(q, [0, 2, 1, 3])
k = paddle.transpose(k, [0, 2, 1, 3])
v = paddle.transpose(v, [0, 2, 1, 3])
# GQA
if q.shape[1] != k.shape[1]:
kv_head = k.shape[1]
k = k.reshape([batch, kv_head, 1, kv_seq_len, head_dim])
k = paddle.tile(k, [1, 1, heads // kv_head, 1, 1])
k = k.reshape([batch, heads, kv_seq_len, head_dim])
v = v.reshape([batch, kv_head, 1, kv_seq_len, head_dim])
v = paddle.tile(v, [1, 1, heads // kv_head, 1, 1])
v = v.reshape([batch, heads, kv_seq_len, head_dim])
# matmul and devide by sqrt(head_dim)
attn_weights = paddle.matmul(q / math.sqrt(head_dim), k.transpose([0, 1, 3, 2]))
attention_mask = self.get_triangle_upper_mask(
[batch, 1, q_seq_len, kv_seq_len], q.dtype
)
attn_weights = attn_weights + attention_mask
attn_weights = paddle.nn.functional.softmax(
attn_weights, axis=-1, dtype="float32"
).astype(q.dtype)
attn_output = paddle.matmul(attn_weights, v)
attn_output = attn_output.transpose([0, 2, 1, 3])
return attn_output.squeeze(0)

View File

@@ -0,0 +1,16 @@
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""""
gcu moe
"""

View File

@@ -0,0 +1,402 @@
"""
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
import multiprocessing
import os
import numpy as np
import paddle
from paddle import nn
from paddleformers.utils.log import logger
from fastdeploy.model_executor.layers.moe.fused_moe_backend_base import \
MoEMethodBase
from fastdeploy.model_executor.layers.utils import (CpuGuard,
create_and_set_parameter,
get_tensor)
from fastdeploy.model_executor.ops.gcu import (invoke_fused_moe_kernel,
moe_align_block_size,
topk_softmax,
weight_quantize_custom_rtn,
weight_quantize_rtn)
class GCUFusedMoeMethod(MoEMethodBase):
"""
Use GCU to compute Fused MoE.
"""
def __init__(self, quant_config):
super().__init__(quant_config)
self.group_size = -1
def create_weights(self, layer: nn.Layer, state_dict):
"""
Paddle gcu create weight process.
"""
# bf16
ffn1_weights, ffn2_weights = layer.extract_moe_ffn_weights(state_dict)
stacked_ffn1_weights = paddle.stack(ffn1_weights, axis=0)
stacked_ffn2_weights = paddle.stack(ffn2_weights, axis=0)
for idx, weight_tensor in enumerate(
[stacked_ffn1_weights, stacked_ffn2_weights]):
# shape [E, K, N] -> [E, N, K]
weight_tensor = paddle.transpose(weight_tensor, [0, 2, 1])
weight_name = self.added_weight_attrs[idx]
setattr(
layer, weight_name,
layer.create_parameter(
shape=weight_tensor.shape,
dtype=weight_tensor.dtype,
default_initializer=paddle.nn.initializer.Constant(0),
))
getattr(layer, weight_name).set_value(weight_tensor)
@paddle.no_grad()
def compute_ffn(
self,
layer: nn.Layer,
x: paddle.Tensor,
gate_out: paddle.Tensor,
enable_quant = False
) -> paddle.Tensor:
"""
Paddle gcu compute Fused MoE.
"""
token_num, hidden_size = x.shape
top_k = layer.top_k
moe_intermediate_size = layer.moe_intermediate_size
num_experts = layer.num_local_experts
topk_weights = paddle.empty([token_num, top_k], dtype=gate_out.dtype)
topk_indices = paddle.empty([token_num, top_k], dtype="int32")
token_expert_indices = paddle.empty([token_num, top_k], dtype="int32",)
topk_softmax(topk_weights, topk_indices, token_expert_indices, gate_out, norm_topk_prob=True)
config = {
"BLOCK_SIZE_M": 32,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
}
block_size = config["BLOCK_SIZE_M"]
max_num_tokens_padded = np.prod(topk_indices.shape) + num_experts * (block_size - 1)
max_num_m_blocks = max_num_tokens_padded // block_size
sorted_token_ids = paddle.empty([max_num_tokens_padded], dtype="int32")
expert_ids = paddle.zeros(shape=[max_num_m_blocks], dtype="int32")
num_tokens_post_pad = paddle.empty([1], dtype="int32")
sorted_token_ids, expert_ids, num_tokens_post_pad = moe_align_block_size(
sorted_token_ids,
expert_ids,
num_tokens_post_pad,
topk_indices,
num_experts,
block_size,
)
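# moe_align_block_size (following the common contract of this op) reorders the flattened
# (token, expert) assignments so that tokens routed to the same expert are contiguous, pads each
# expert's token count up to a multiple of block_size, and records the expert of every padded
# block in expert_ids.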
intermediate_cache1 = paddle.empty(
[token_num, top_k, moe_intermediate_size * 2],
dtype=x.dtype,
)
ffn1_B_scale = layer.moe_ffn1_weight_scale if enable_quant else None
ffn1_B_zeros = layer.moe_ffn1_weight_zeros if enable_quant else None
invoke_fused_moe_kernel(
x, # input
layer.moe_ffn1_weight, # weight
intermediate_cache1, # output
None, # A_scale
ffn1_B_scale, # B_scale
ffn1_B_zeros, # B_zp
topk_weights,
topk_indices,
sorted_token_ids,
expert_ids,
num_tokens_post_pad,
False, # mul_routed_weight
top_k,
config,
enable_quant, # use_int4_w4a16
[0, self.group_size], # block_shape
)
intermediate_cache2 = paddle.empty(
(token_num, top_k, moe_intermediate_size),
dtype=x.dtype,
)
intermediate_cache2 = paddle.incubate.nn.functional.swiglu(
intermediate_cache1)
intermediate_cache2 = intermediate_cache2.reshape([-1, moe_intermediate_size])
intermediate_cache3 = paddle.empty(
(token_num, top_k, hidden_size),
dtype=x.dtype,
)
ffn2_B_scale = layer.moe_ffn2_weight_scale if enable_quant else None
ffn2_B_zeros = layer.moe_ffn2_weight_zeros if enable_quant else None
invoke_fused_moe_kernel(
intermediate_cache2, # input
layer.moe_ffn2_weight, # weight
intermediate_cache3, # output
None, # A_scale
ffn2_B_scale, # B_scale
ffn2_B_zeros, # B_zp
topk_weights,
topk_indices,
sorted_token_ids,
expert_ids,
num_tokens_post_pad,
True, # mul_routed_weight
1,
config,
enable_quant, # use_int4_w4a16
[0, self.group_size], # block_shape
)
intermediate_cache3.reshape_([token_num, top_k, hidden_size])
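# The routing weights were already applied inside the second GEMM (mul_routed_weight=True),
# so summing over the top_k axis combines the per-expert outputs into the final MoE result.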
fused_moe_out = intermediate_cache3.sum(axis=1)
fused_moe_out = fused_moe_out.reshape_([token_num, hidden_size])
if layer.tp_size > 1:
from fastdeploy.distributed.communication_op import \
tensor_model_parallel_all_reduce
tensor_model_parallel_all_reduce(fused_moe_out)
return fused_moe_out
def apply(
self,
layer: nn.Layer,
x: paddle.Tensor,
gate_out: paddle.Tensor,
) -> paddle.Tensor:
"""
Paddle gcu compute Fused MoE.
"""
return self.compute_ffn(layer, x, gate_out, enable_quant=False)
def apply_ep_prefill(
self,
layer: nn.Layer,
x: paddle.Tensor,
gate_out: paddle.Tensor,
) -> paddle.Tensor:
"""
Apply the EP prefill method.
"""
raise NotImplementedError
def apply_ep_decode(
self,
layer: nn.Layer,
x: paddle.Tensor,
gate_out: paddle.Tensor,
) -> paddle.Tensor:
"""
Apply the EP decoder method.
"""
raise NotImplementedError
def apply_tp(
self,
layer: nn.Layer,
x: paddle.Tensor,
gate_out: paddle.Tensor,
) -> paddle.Tensor:
"""
Paddle Cutlass compute Fused MoE.
"""
raise NotImplementedError
class GCUWeightOnlyMoEMethod(GCUFusedMoeMethod):
"""
weight only for moe
"""
def __init__(self, quant_config):
super().__init__(quant_config)
self.quant_config = quant_config
self.moe_quant_type = self.quant_config.algo
self.pack_num = 1
assert self.quant_config.algo == "weight_only_int4", \
f"GCUWeightOnlyMoEMethod only supports weight_only_int4, but got: {self.quant_config.algo}"
self.added_qzeros_attrs = [
"moe_ffn1_weight_zeros", "moe_ffn2_weight_zeros"
]
self.group_size = 64
self.quant_multi_process_group_size = int(
os.getenv("FD_MOE_QUANT_MULTI_PROCESS_GROUP_SIZE", 8)
)
logger.info(f"GCUWeightOnlyMoEMethod quant_multi_process_group_size: {self.quant_multi_process_group_size}")
def process_prequanted_weights(self, layer: nn.Layer, state_dict):
"""
Paddle gcu process prequanted weights.
"""
ffn1_expert_weight_key = layer.weight_key_map.get(
"ffn1_expert_weight_key", None)
ffn2_expert_weight_key = layer.weight_key_map.get(
"ffn2_expert_weight_key", None)
ffn1_expert_weight_scale_key = layer.weight_key_map.get(
"ffn1_expert_weight_scale_key", None)
ffn2_expert_weight_scale_key = layer.weight_key_map.get(
"ffn2_expert_weight_scale_key", None)
ffn1_weights, ffn2_weights = layer.load_experts_weight(
state_dict, ffn1_expert_weight_key, ffn2_expert_weight_key)
# self.check(layer, ffn1_weights, ffn2_weights)
ffn1_weight_scale = []
ffn2_weight_scale = []
for i in range(layer.num_experts):
expert_idx = layer.expert_id_offset + i
ffn1_weight_scale.append(
get_tensor(
state_dict.pop(
ffn1_expert_weight_scale_key.format(expert_idx))))
ffn2_weight_scale.append(
get_tensor(
state_dict.pop(
ffn2_expert_weight_scale_key.format(expert_idx))))
ffn1_weight = paddle.stack(ffn1_weights, axis=0)
ffn2_weight = paddle.stack(ffn2_weights, axis=0)
ffn1_weight_scale = paddle.stack(ffn1_weight_scale, axis=0)
ffn2_weight_scale = paddle.stack(ffn2_weight_scale, axis=0)
name_tensor_map = {
"moe_ffn1_weight": ffn1_weight,
"moe_ffn2_weight": ffn2_weight,
"moe_ffn1_weight_scale": ffn1_weight_scale,
"moe_ffn2_weight_scale": ffn2_weight_scale
}
for name, tensor in name_tensor_map.items():
create_and_set_parameter(layer, name, tensor)
@paddle.no_grad()
def create_weights(self, layer: nn.Layer, state_dict):
"""
Paddle gcu create weight process.
"""
ffn1_weights, ffn2_weights = layer.extract_moe_ffn_weights(state_dict)
self.check(layer, ffn1_weights, ffn2_weights)
def quant_worker(p_group_idx, shared_dict, weights, moe_quant_type, group_size):
with CpuGuard():
p_group_size = len(weights)
for group_j in range(p_group_size):
# weight shape [K, N] -> [N/2, K] -> [N, K/2]
quant_weight, scale = weight_quantize_custom_rtn(
weights[group_j],
moe_quant_type,
group_size # group_size
)
shared_dict[p_group_size * p_group_idx + group_j] = (
quant_weight, scale
)
for idx, weight_tensor in enumerate([ffn1_weights, ffn2_weights]):
weight_name = self.added_weight_attrs[idx]
scale_name = self.added_scale_attrs[idx]
zeros_name = self.added_qzeros_attrs[idx]
if self.quant_multi_process_group_size > 0:
process_group_size = self.quant_multi_process_group_size
process_group_num = layer.num_local_experts // process_group_size
grouped_weights_num = process_group_num * process_group_size
remain_weights_start_idx = grouped_weights_num
weight_list = [None] * grouped_weights_num
weight_scale_list = [None] * grouped_weights_num
with multiprocessing.Manager() as manager:
shared_dict = manager.dict({})
processes = []
for i in range(process_group_num):
w = []
for j in range(process_group_size):
w.append(weight_tensor[process_group_size * i + j].to("cpu"))
p = multiprocessing.Process(
target=quant_worker,
args=(i, shared_dict, w, self.moe_quant_type, self.group_size)
)
p.start()
processes.append(p)
for p in processes:
p.join()
dict_ = dict(shared_dict)
for k, v in dict_.items():
weight_list[k] = v[0].to(ffn1_weights[0].place)
weight_scale_list[k] = v[1].to(ffn1_weights[0].place)
else:
remain_weights_start_idx = 0
# Without multi-process quantization, quantize all experts in the loop below
weight_list = []
weight_scale_list = []
if remain_weights_start_idx < layer.num_local_experts:
for i in range(remain_weights_start_idx, layer.num_local_experts):
# weight shape [K, N] -> [N/2, K] -> [N, K/2]
quant_weight, scale = weight_quantize_rtn(
weight_tensor[i],
self.moe_quant_type,
self.group_size # group_size
)
weight_list.append(quant_weight)
weight_scale_list.append(scale)
quanted_weight = paddle.stack(weight_list, axis=0)
create_and_set_parameter(layer, weight_name, quanted_weight)
quanted_weight_scale = paddle.stack(weight_scale_list, axis=0)
create_and_set_parameter(layer, scale_name, quanted_weight_scale)
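# For unsigned int4 RTN quantization the zero point sits at the midpoint of the range (8),
# so the per-group dequantization zero offset is scale * 8.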
quanted_weight_zeros = quanted_weight_scale * 8
create_and_set_parameter(layer, zeros_name, quanted_weight_zeros)
def apply(
self,
layer: nn.Layer,
x: paddle.Tensor,
gate_out: paddle.Tensor,
) -> paddle.Tensor:
"""
Paddle gcu compute Fused MoE.
"""
return self.compute_ffn(layer, x, gate_out, enable_quant=True)

View File

@@ -0,0 +1,21 @@
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""""
gcu quantization
"""
from .weight_only import GCUWeightOnlyLinearMethod
__all__ = [
"GCUWeightOnlyLinearMethod",
]

View File

@@ -0,0 +1,90 @@
"""
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
import paddle
from fastdeploy.model_executor.layers.quantization.weight_only import (
WeightOnlyConfig, WeightOnlyLinearMethod)
from fastdeploy.model_executor.layers.utils import get_tensor
from fastdeploy.model_executor.ops.gcu import linear_quant, weight_quantize_rtn
class GCUWeightOnlyLinearMethod(WeightOnlyLinearMethod):
"""
Weight only quantization method for linear layer on GCU
"""
def __init__(
self,
quant_config: WeightOnlyConfig,
) -> None:
super().__init__(quant_config)
self.quant_config = quant_config
self.group_size = -1
def create_weights(self, layer):
# The scale shape should be equal to the output dim of weight using Per-Channel Quantization.
linear_weight_scale_shape = [layer.linear_weight_shape[1]]
layer.linear_weight_shape.reverse()
if self.quant_config.name() == "wint4":
layer.linear_weight_shape[0] //= 2
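# The quantized weight is stored transposed relative to the fp weight ([in, out] -> [out, in]);
# for wint4, two 4-bit values are packed into each int8 element, halving the leading dimension.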
layer.weight_dtype = "int8"
layer.linear_weight_scale = layer.create_parameter(
shape=linear_weight_scale_shape,
dtype=layer._dtype,
is_bias=False,
)
def process_prequanted_weights(self, layer, state_dict) -> None:
"""
Process pre-quantized weights before applying them to the model
Args:
layer: The layer that owns the weights
quant_weight: The quantized weights
weight_scale: The scale of the quantized weights
"""
quant_weight = get_tensor(state_dict.pop(layer.weight_key))
weight_scale = get_tensor(state_dict.pop(layer.weight_scale_key))
layer.linear_weight.set_value(quant_weight)
layer.linear_weight_scale.set_value(
weight_scale.astype(paddle.get_default_dtype()))
def process_loaded_weights(self, layer, weight) -> None:
quanted_weight_tensor, weight_scale_tensor = weight_quantize_rtn(
weight,
self.quant_config.algo,
self.group_size, # group_size
)
layer.linear_weight.set_value(quanted_weight_tensor)
layer.linear_weight_scale.set_value(
weight_scale_tensor.astype(paddle.get_default_dtype()))
@paddle.no_grad()
def apply(self, layer, x):
linear_out = linear_quant(
lhs=x,
rhs=layer.linear_weight,
scale=layer.linear_weight_scale,
bias=None,
group_size=self.group_size,
)
return linear_out

View File

@@ -58,7 +58,7 @@ class LinearBase(nn.Layer):
"""
super().__init__()
if current_platform.is_cuda() or current_platform.is_xpu(
) or current_platform.is_iluvatar():
) or current_platform.is_iluvatar() or current_platform.is_gcu():
self.forward = self.forward_cuda
else:
raise NotImplementedError

View File

@@ -20,6 +20,7 @@ from paddleformers.utils.log import logger
from fastdeploy import envs
from fastdeploy.model_executor.layers.utils import get_tensor
from fastdeploy.platforms import current_platform
class FusedMoE(nn.Layer):
@@ -95,8 +96,13 @@ class FusedMoE(nn.Layer):
self.moe_quant_type = moe_quant_config.name()
else:
# Currently the unquantized (w_fp16, a_fp16) method cannot be obtained from quant_config; this will be improved in the future
from .fused_moe_cutlass_backend import CutlassMoEMethod
self.quant_method = CutlassMoEMethod(None)
if current_platform.is_cuda():
from .fused_moe_cutlass_backend import CutlassMoEMethod
self.quant_method = CutlassMoEMethod(None)
elif current_platform.is_gcu():
from fastdeploy.model_executor.layers.backends import \
GCUFusedMoeMethod
self.quant_method = GCUFusedMoeMethod(None)
if self.ep_size > 1:
self.quant_method.init_ep(self)

View File

@@ -19,9 +19,14 @@ from typing import Callable, Dict, Optional
import numpy as np
import paddle
from paddle import nn
from paddle.incubate.nn.functional import fused_layer_norm, fused_rms_norm
from fastdeploy.platforms import current_platform
if current_platform.is_gcu():
from fastdeploy.model_executor.ops.gcu import fused_add_rms_norm, rms_norm
else:
from paddle.incubate.nn.functional import fused_layer_norm, fused_rms_norm
from fastdeploy.config import FDConfig
from .utils import get_tensor
@@ -69,7 +74,10 @@ class RMSNorm(nn.Layer):
self.weight_key: Optional[str] = f"{prefix}.weight"
self.with_weight: bool = self.weight_key is not None
self.eps: float = eps
self.norm_func: Callable = fused_rms_norm
if current_platform.is_gcu():
self.norm_func: Callable = fused_add_rms_norm
else:
self.norm_func: Callable = fused_rms_norm
self.linear_bias: Optional[paddle.Tensor] = linear_bias
self.quant_scale: Optional[float] = quant_scale
self._dtype: str = self._helper.get_default_dtype()
@@ -129,19 +137,26 @@ class RMSNorm(nn.Layer):
The `residual_output` is the result of applying the normalization and possibly other
operations (like linear transformation) on the `residual_input`.
"""
norm_out = self.norm_func(
x,
norm_weight=self.ln_weight,
norm_bias=None,
epsilon=self.eps,
begin_norm_axis=self.begin_norm_axis,
bias=self.linear_bias,
residual=residual_input,
quant_scale=-1 if self.quant_scale is None else self.quant_scale,
quant_round_type=self.quant_round_type,
quant_max_bound=self.quant_max_bound,
quant_min_bound=self.quant_min_bound,
)
if current_platform.is_gcu():
if residual_input is None:
return rms_norm(x, self.ln_weight, self.eps)
norm_out = self.norm_func(
x, residual_input, self.ln_weight, self.eps
)
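# fused_add_rms_norm is expected to return (normalized_output, x + residual_input),
# mirroring the (out, residual_sum) contract of the fused_rms_norm path below.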
else:
norm_out = self.norm_func(
x,
norm_weight=self.ln_weight,
norm_bias=None,
epsilon=self.eps,
begin_norm_axis=self.begin_norm_axis,
bias=self.linear_bias,
residual=residual_input,
quant_scale=-1 if self.quant_scale is None else self.quant_scale,
quant_round_type=self.quant_round_type,
quant_max_bound=self.quant_max_bound,
quant_min_bound=self.quant_min_bound,
)
if residual_input is not None:
return norm_out[0], norm_out[1]
else:
@@ -193,7 +208,10 @@ class LayerNorm(nn.Layer):
self.with_bias: bool = with_bias
self.eps: float = eps
self.quant_scale: float = quant_scale
self.norm_func: Callable = fused_layer_norm
if current_platform.is_gcu():
self.norm_func: Callable = paddle.nn.functional.layer_norm
else:
self.norm_func: Callable = fused_layer_norm
self.linear_bias: Optional[paddle.Tensor] = linear_bias
self._dtype: str = self._helper.get_default_dtype()
self._norm_weight_dtype: str = "float32"
@@ -279,19 +297,40 @@ class LayerNorm(nn.Layer):
else:
raise NotImplementedError("Iluvatar does not support yet!")
norm_out = self.norm_func(
x,
norm_weight=self.ln_weight,
norm_bias=self.ln_bias,
epsilon=self.eps,
begin_norm_axis=1,
bias=self.linear_bias,
residual=residual_input,
quant_scale=-1 if self.quant_scale is None else self.quant_scale,
quant_round_type=self.quant_round_type,
quant_max_bound=self.quant_max_bound,
quant_min_bound=self.quant_min_bound,
)
if current_platform.is_gcu():
if residual_input is not None:
y = x + residual_input
out = self.norm_func(
x=y,
normalized_shape=y.shape[1:],
weight=self.ln_weight,
bias=self.linear_bias,
epsilon=self.eps,
)
return out, y
else:
out = self.norm_func(
x=x,
normalized_shape=x.shape[1:],
weight=self.ln_weight,
bias=self.linear_bias,
epsilon=self.eps,
)
return out
else:
norm_out = self.norm_func(
x,
norm_weight=self.ln_weight,
norm_bias=self.ln_bias,
epsilon=self.eps,
begin_norm_axis=1,
bias=self.linear_bias,
residual=residual_input,
quant_scale=-1 if self.quant_scale is None else self.quant_scale,
quant_round_type=self.quant_round_type,
quant_max_bound=self.quant_max_bound,
quant_min_bound=self.quant_min_bound,
)
if residual_input is not None:
return norm_out[0], norm_out[1]
else:
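On GCU, the fused normalization kernels from `paddle.incubate.nn.functional` are replaced by the custom `rms_norm` / `fused_add_rms_norm` ops (RMSNorm) and by plain `paddle.nn.functional.layer_norm` applied to a pre-added residual (LayerNorm). The snippet below is a minimal pure-Paddle sketch of the semantics the RMSNorm forward pass above assumes for those two GCU ops; it is reference-only, not the actual kernels.

```python
import paddle

def rms_norm_ref(x: paddle.Tensor, weight: paddle.Tensor, eps: float) -> paddle.Tensor:
    # Normalize over the last axis and scale by the learned weight.
    variance = x.astype("float32").pow(2).mean(axis=-1, keepdim=True)
    return (x * paddle.rsqrt(variance + eps).astype(x.dtype)) * weight

def fused_add_rms_norm_ref(x, residual, weight, eps):
    # Add the residual first, then normalize; return both the normalized output
    # and the updated residual, matching the two values unpacked above when
    # residual_input is not None.
    y = x + residual
    return rms_norm_ref(y, weight, eps), y
```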

View File

@@ -66,6 +66,13 @@ class WeightOnlyConfig(QuantConfigBase):
return XPUWeightOnlyMoEMethod(self)
else:
return XPUWeightOnlyLinearMethod(self)
elif current_platform.is_gcu():
from fastdeploy.model_executor.layers.backends import (
GCUWeightOnlyLinearMethod, GCUWeightOnlyMoEMethod)
if isinstance(layer, FusedMoE):
return GCUWeightOnlyMoEMethod(self)
else:
return GCUWeightOnlyLinearMethod(self)
else:
if isinstance(layer, FusedMoE):
if layer.use_method == "cutlass":

View File

@@ -55,6 +55,10 @@ class ErnieRotaryEmbedding:
dtype="float32")
emb = paddle.stack([freqs, freqs], axis=-1).reshape(
(bsz, max_seq_len, self.rotary_dim))
elif current_platform.is_gcu():
# shape: [B, S, D]
rot_emb = paddle.concat([freqs.cos(), freqs.sin()], axis=-1)
return rot_emb
else:
# shape: [B, S, D/2]
rot_emb = paddle.zeros(
@@ -95,6 +99,10 @@ class QwenRotaryEmbedding:
# shape: [B, S, D/2]
freqs = paddle.einsum("ij,k->ijk", position_ids.cast("float32"),
inv_freq)
if current_platform.is_gcu():
# shape: [B, S, D]
rot_emb = paddle.concat([freqs.cos(), freqs.sin()], axis=-1)
return rot_emb
# shape: [B, S, 1, D]
emb = paddle.concat([freqs, freqs], axis=-1).reshape(
(bsz, max_seq_len, 1, self.rotary_dim))
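Note the layout change on GCU: instead of the `[B, S, 1, D]` table built from duplicated frequencies, both rotary classes return a flat `[B, S, D]` table with the cosines in the first half and the sines in the second half. A self-contained sketch of that layout (sizes here are illustrative only):

```python
import paddle

bsz, max_seq_len, rotary_dim = 1, 8, 64
position_ids = paddle.arange(max_seq_len, dtype="float32").reshape([bsz, max_seq_len])
inv_freq = 1.0 / (10000.0 ** (paddle.arange(0, rotary_dim, 2, dtype="float32") / rotary_dim))

freqs = paddle.einsum("ij,k->ijk", position_ids, inv_freq)    # [B, S, D/2]
rot_emb = paddle.concat([freqs.cos(), freqs.sin()], axis=-1)  # [B, S, D], GCU layout
cos, sin = rot_emb[..., : rotary_dim // 2], rot_emb[..., rotary_dim // 2:]
assert cos.shape == [bsz, max_seq_len, rotary_dim // 2]
```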

View File

@@ -79,6 +79,21 @@ def apply_penalty_multi_scores(
min_dec_lens,
eos_token_ids,
)
elif current_platform.is_gcu():
from fastdeploy.model_executor.ops.gcu import \
get_token_penalty_multi_scores
logits = get_token_penalty_multi_scores(
pre_token_ids,
logits,
repetition_penalties,
frequency_penalties,
presence_penalties,
temperature,
bad_words_token_ids,
step_idx,
min_dec_lens,
eos_token_ids,
)
else:
raise NotImplementedError()

View File

@@ -19,7 +19,11 @@ from typing import Literal, Optional
import paddle
from fastdeploy import envs
from fastdeploy.platforms import current_platform
if current_platform.is_gcu():
from fastdeploy.model_executor.ops.gcu import \
top_p_sampling as gcu_top_p_sampling
def top_p_sampling(
x: paddle.Tensor,
@@ -46,13 +50,16 @@ def top_p_sampling(
ids = rejection_top_p_sampling(x, ps, seed)
_ = None
else:
_, ids = paddle.tensor.top_p_sampling(x,
ps,
threshold=threshold,
topp_seed=topp_seed,
seed=seed,
k=k,
mode=mode)
if current_platform.is_gcu():
_, ids = gcu_top_p_sampling(x, ps)
else:
_, ids = paddle.tensor.top_p_sampling(x,
ps,
threshold=threshold,
topp_seed=topp_seed,
seed=seed,
k=k,
mode=mode)
return _, ids
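Both branches implement nucleus (top-p) sampling: keep the smallest set of highest-probability tokens whose cumulative mass reaches `ps`, renormalize, and sample from it. A minimal pure-Paddle reference of that semantics (for illustration only, not the GCU kernel or the fused Paddle op):

```python
import paddle

def top_p_sample_ref(probs: paddle.Tensor, p: float) -> paddle.Tensor:
    # probs: [batch, vocab] row-wise probabilities; returns sampled ids of shape [batch, 1].
    sorted_probs = paddle.sort(probs, axis=-1, descending=True)
    sorted_idx = paddle.argsort(probs, axis=-1, descending=True)
    cum = paddle.cumsum(sorted_probs, axis=-1)
    # Keep tokens whose preceding cumulative mass is below p (the top-1 token always survives).
    keep = (cum - sorted_probs) < p
    filtered = paddle.where(keep, sorted_probs, paddle.zeros_like(sorted_probs))
    filtered = filtered / filtered.sum(axis=-1, keepdim=True)
    choice = paddle.multinomial(filtered, num_samples=1)        # index in sorted order
    return paddle.take_along_axis(sorted_idx, choice, axis=-1)  # map back to vocab ids
```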

View File

@@ -171,7 +171,7 @@ class Sampler(nn.Layer):
"""
super().__init__()
if current_platform.is_cuda() or current_platform.is_xpu(
) or current_platform.is_iluvatar():
) or current_platform.is_iluvatar() or current_platform.is_gcu():
self.forward = self.forward_cuda
else:
raise NotImplementedError()

View File

@@ -17,5 +17,6 @@ from . import cpu
from . import xpu
from . import npu
from . import iluvatar
from . import gcu
__all__ = ["gpu", "cpu", "xpu", "npu", "iluvatar"]
__all__ = ["gpu", "cpu", "xpu", "npu", "iluvatar", "gcu"]

View File

@@ -0,0 +1,116 @@
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" fastdeploy gcu ops """
from fastdeploy.platforms import current_platform
from fastdeploy.import_ops import import_custom_ops, rename_imported_op
PACKAGE = "fastdeploy.model_executor.ops.gcu"
import_custom_ops(PACKAGE, ".fastdeploy_ops", globals())
if current_platform.is_gcu():
from paddle_custom_device.gcu.ops import (invoke_fused_moe_kernel, # noqa: F401,E402
moe_align_block_size, top_p_sampling, # noqa: F401
topk_softmax, # noqa: F401
weight_quantize_custom_rtn, # noqa: F401
weight_quantize_rtn) # noqa: F401
# ###################### Ops from PaddleCustomDevice ####################
rename_imported_op(
old_name="fused_rotary_embedding_gcu",
new_name="fused_rotary_embedding",
global_ns=globals(),
)
rename_imported_op(
old_name="reshape_and_cache_gcu",
new_name="reshape_and_cache",
global_ns=globals(),
)
rename_imported_op(
old_name="paged_attention_gcu",
new_name="paged_attention",
global_ns=globals(),
)
rename_imported_op(
old_name="mem_efficient_attention_gcu",
new_name="mem_efficient_attention",
global_ns=globals(),
)
rename_imported_op(
old_name="flash_attn_var_len_gcu",
new_name="flash_attn_var_len",
global_ns=globals(),
)
rename_imported_op(
old_name="rms_norm_gcu",
new_name="rms_norm",
global_ns=globals(),
)
rename_imported_op(
old_name="fused_add_rms_norm_op",
new_name="fused_add_rms_norm",
global_ns=globals(),
)
rename_imported_op(
old_name="linear_quant_gcu",
new_name="linear_quant",
global_ns=globals(),
)
# ###################### CPU OPS ####################
rename_imported_op(
old_name="get_padding_offset_gcu",
new_name="get_padding_offset",
global_ns=globals(),
)
rename_imported_op(
old_name="update_inputs_gcu",
new_name="update_inputs",
global_ns=globals(),
)
rename_imported_op(
old_name="rebuild_padding_gcu",
new_name="rebuild_padding",
global_ns=globals(),
)
rename_imported_op(
old_name="get_token_penalty_multi_scores_gcu",
new_name="get_token_penalty_multi_scores",
global_ns=globals(),
)
rename_imported_op(
old_name="set_stop_value_multi_ends_gcu",
new_name="set_stop_value_multi_ends",
global_ns=globals(),
)
rename_imported_op(
old_name="set_value_by_flags_and_idx_gcu",
new_name="set_value_by_flags_and_idx",
global_ns=globals(),
)
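The renames above alias the `*_gcu` ops exported by PaddleCustomDevice to the device-agnostic names the rest of FastDeploy imports (`rms_norm`, `rebuild_padding`, `get_token_penalty_multi_scores`, ...). A minimal sketch of the aliasing this relies on; the actual helper lives in `fastdeploy.import_ops` and may perform additional checks:

```python
def rename_imported_op_sketch(old_name: str, new_name: str, global_ns: dict) -> None:
    # Re-export an op imported under a device-specific name (e.g. rms_norm_gcu)
    # under its generic name (rms_norm) in the given module namespace.
    if old_name in global_ns:
        global_ns[new_name] = global_ns[old_name]
```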

View File

@@ -24,6 +24,11 @@ if current_platform.is_iluvatar():
from fastdeploy.model_executor.ops.iluvatar import (
get_padding_offset, save_output, set_stop_value_multi_ends,
step_paddle, update_inputs)
elif current_platform.is_gcu():
from fastdeploy.model_executor.ops.gcu import (get_padding_offset,
save_output,
set_stop_value_multi_ends,
update_inputs)
else:
from fastdeploy.model_executor.ops.gpu import (
get_padding_offset, save_output, set_stop_value_multi_ends,
@@ -391,6 +396,17 @@ def rebuild_padding(tmp_out: paddle.Tensor,
output_padding_offset,
max_input_length,
)
elif current_platform.is_gcu():
from fastdeploy.model_executor.ops.gcu import rebuild_padding
hidden_states = rebuild_padding(
tmp_out,
cum_offsets,
seq_len_this_time,
seq_lens_decoder,
seq_lens_encoder,
output_padding_offset,
max_input_length,
)
elif current_platform.is_cpu():
from fastdeploy.model_executor.ops.cpu import rebuild_padding_cpu
hidden_states = rebuild_padding_cpu(

View File

@@ -124,6 +124,8 @@ class TokenProcessor(object):
from fastdeploy.model_executor.ops.xpu import get_output
elif current_platform.is_iluvatar():
from fastdeploy.model_executor.ops.iluvatar import get_output
elif current_platform.is_gcu():
from fastdeploy.model_executor.ops.gcu import get_output
else:
from fastdeploy.model_executor.ops.gpu import (get_output,
get_output_ep,

View File

@@ -22,6 +22,7 @@ from .xpu import XPUPlatform
from .npu import NPUPlatform
from .dcu import DCUPlatform
from .iluvatar import IluvatarPlatform
from .gcu import GCUPlatform
from .base import _Backend # noqa: F401
_current_platform = None
@@ -42,6 +43,8 @@ def __getattr__(name: str):
_current_platform = DCUPlatform()
elif paddle.is_compiled_with_custom_device("iluvatar_gpu"):
_current_platform = IluvatarPlatform()
elif paddle.is_compiled_with_custom_device("gcu"):
_current_platform = GCUPlatform()
else:
_current_platform = CPUPlatform()
return _current_platform

View File

@@ -69,6 +69,12 @@ class Platform:
"""
return paddle.is_compiled_with_custom_device("iluvatar_gpu")
def is_gcu(self) -> bool:
"""
whether platform is gcu
"""
return paddle.is_compiled_with_custom_device("gcu")
@classmethod
def get_attention_backend_cls(self, selected_backend):
"""Get the attention backend"""

View File

@@ -0,0 +1,61 @@
"""
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
import paddle
from fastdeploy.utils import console_logger as logger
from .base import Platform, _Backend
class GCUPlatform(Platform):
"""
gcu platform class
"""
device_name = "gcu"
@classmethod
def available(self):
"""
Check whether GCU is available.
"""
try:
assert paddle.base.core.get_custom_device_count('gcu') > 0
return True
except Exception as e:
logger.warning(
"You are using GCUPlatform, but there is no GCU "
"detected on your machine. Maybe GCU devices is not set properly."
f"\n Original Error is {e}"
)
return False
@classmethod
def get_attention_backend_cls(cls, selected_backend: _Backend):
"""
get_attention_backend_cls
"""
if selected_backend == _Backend.NATIVE_ATTN:
logger.info("Using GCU mem_efficient ATTN backend.")
return ("fastdeploy.model_executor.layers.backends.gcu.attention.mem_efficient_attn_backend.GCUMemEfficientAttnBackend")
elif selected_backend == _Backend.APPEND_ATTN:
logger.info("Using GCU ATTN backend.")
return ("fastdeploy.model_executor.layers.backends.gcu.attention.flash_attn_backend.GCUFlashAttnBackend")
else:
raise ValueError(
"Invalid attention backend you specified.\n"
"Now only support [NATIVE_ATTN, APPEND_ATTN] in gcu place."
)
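`get_attention_backend_cls` returns the backend as a dotted "module.ClassName" string. A sketch of how such a string is typically resolved into the class object (the loader FastDeploy actually uses is not shown in this diff):

```python
import importlib

def resolve_backend(path: str):
    # Split "pkg.module.ClassName" into module path and class name, then import.
    module_name, class_name = path.rsplit(".", 1)
    return getattr(importlib.import_module(module_name), class_name)

# e.g. backend_cls = resolve_backend(GCUPlatform.get_attention_backend_cls(_Backend.APPEND_ATTN))
```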

File diff suppressed because it is too large

View File

@@ -0,0 +1,142 @@
"""
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
import gc
from typing import List, Optional
import paddle
import paddle.nn as nn
from fastdeploy.config import FDConfig
from fastdeploy.engine.request import Request
from fastdeploy.utils import get_logger
from fastdeploy.worker.gcu_model_runner import GCUModelRunner
from fastdeploy.worker.output import ModelRunnerOutput
from fastdeploy.worker.worker_base import WorkerBase
logger = get_logger("gcu_worker", "gcu_worker.log")
class GcuWorker(WorkerBase):
""" """
def __init__(
self,
fd_config: FDConfig,
local_rank: int,
rank: int,
):
super().__init__(
fd_config=fd_config,
local_rank=local_rank,
rank=rank,
)
pass
def init_device(self):
""" Initialize device and Construct model runner
"""
if paddle.is_compiled_with_custom_device("gcu"):
# Set environment variables
self.device_ids = self.parallel_config.device_ids.split(",")
self.device = f"gcu:{self.local_rank}"
paddle.device.set_device(self.device)
paddle.set_default_dtype(self.parallel_config.dtype)
logger.info(f"GcuWorker init_device:{self.device}, device_ids:{self.device_ids}")
gc.collect()
else:
raise RuntimeError(
f"Not support device type: {self.device_config.device}")
# Construct model runner
self.model_runner: GCUModelRunner = GCUModelRunner(
fd_config=self.fd_config,
device=self.device,
device_id=self.device_ids[self.local_rank],
rank=self.rank,
local_rank=self.local_rank)
def prefill_finished(self):
"""
Check whether the prefill stage has finished.
"""
return self.model_runner.prefill_finished()
def determine_available_memory(self) -> int:
"""
Profiles the peak memory usage of the model to determine how much
memory can be used for KV cache without OOMs.
The engine will first conduct a profiling of the existing memory usage.
Then, it calculates the maximum possible number of GCU and CPU blocks
that can be allocated with the remaining free memory.
Tip:
You may limit the usage of GCU memory
by adjusting the `gcu_memory_utilization` parameter.
"""
raise NotImplementedError
def load_model(self) -> None:
""" """
self.model_runner.load_model()
def get_model(self) -> nn.Layer:
""" """
return self.model_runner.get_model()
def initialize_cache(self, num_gpu_blocks: int,
num_cpu_blocks: int) -> None:
""" """
pass
def execute_model(
self,
model_forward_batch: Optional[List[Request]] = None,
) -> Optional[ModelRunnerOutput]:
""" """
output = self.model_runner.execute_model(model_forward_batch)
return output
def preprocess_new_task(self, req_dicts: List[Request]) -> None:
""" Process new requests and then start the decode loop
TODO(gongshaotian): The scheduler should schedule the handling of prefill,
and workers and model runners should not perceive it.
"""
self.model_runner.insert_prefill_inputs(req_dicts=req_dicts)
def graph_optimize_and_warm_up_model(self) -> None:
"""
Perform the warm-up and the graph optimization
"""
# 1. Warm up model
# NOTE(gongshaotian): warm-up may not be needed at this place
# 2. Trigger CUDA graph capture
self.model_runner.capture_model()
def check_health(self) -> bool:
""" """
return True
def cal_theortical_kvcache(self) -> int:
""" """
return self.model_runner.cal_theortical_kvcache()
def reinitialize_kv_cache(self, num_gpu_blocks: int) -> None:
""" """
self.model_runner.update_share_input_block_num(
num_gpu_blocks=num_gpu_blocks)
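For orientation, the calls a host process would make against this worker, matching the methods defined above (construction of `fd_config` and the request batch is elided; `get_worker` in the next file selects `GcuWorker` on GCU builds):

```python
# Sketch only: fd_config and requests are assumed to come from the engine.
# worker = get_worker(fd_config=fd_config, local_rank=0, rank=0)   # -> GcuWorker on GCU
# worker.init_device()            # sets the paddle device to "gcu:<local_rank>"
# worker.load_model()             # delegates to GCUModelRunner.load_model()
# worker.preprocess_new_task(req_dicts=requests)
# output = worker.execute_model(requests)
```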

View File

@@ -53,6 +53,9 @@ def get_worker(fd_config: FDConfig, local_rank: int, rank: int) -> WorkerBase:
return IluvatarWorker(fd_config=fd_config,
local_rank=local_rank,
rank=rank)
if current_platform.is_gcu():
from fastdeploy.worker.gcu_worker import GcuWorker
return GcuWorker(fd_config=fd_config, local_rank=local_rank, rank=rank)
class PaddleDisWorkerProc():

View File

@@ -167,6 +167,8 @@ def get_device_type():
return "npu"
elif paddle.is_compiled_with_custom_device('iluvatar_gpu'):
return "iluvatar-gpu"
elif paddle.is_compiled_with_custom_device('gcu'):
return "gcu"
else:
return "cpu"
@@ -199,7 +201,7 @@ setup(
"model_executor/ops/xpu/libs/*", "model_executor/ops/npu/*",
"model_executor/ops/base/*", "model_executor/ops/iluvatar/*",
"model_executor/models/*", "model_executor/layers/*",
"input/mm_processor/utils/*",
"input/mm_processor/utils/*", "model_executor/ops/gcu/*",
"version.txt"
]
},