diff --git a/build.sh b/build.sh index 0ddc2588b..e4431eb60 100644 --- a/build.sh +++ b/build.sh @@ -113,6 +113,14 @@ function copy_ops(){ return fi + is_gcu=`$python -c "import paddle; print(paddle.is_compiled_with_custom_device('gcu'))"` + if [ "$is_gcu" = "True" ]; then + DEVICE_TYPE="gcu" + cp -r ${OPS_TMP_DIR}/${WHEEL_NAME}/* ../fastdeploy/model_executor/ops/gcu + echo -e "gcu ops have been copy to fastdeploy" + return + fi + DEVICE_TYPE="cpu" cp -r ./${OPS_TMP_DIR_BASE}/${WHEEL_BASE_NAME}/* ../fastdeploy/model_executor/ops/base cd ../../../../ diff --git a/custom_ops/setup_ops.py b/custom_ops/setup_ops.py index 75a9f4621..bb165fc88 100644 --- a/custom_ops/setup_ops.py +++ b/custom_ops/setup_ops.py @@ -501,6 +501,17 @@ elif paddle.is_compiled_with_custom_device("iluvatar_gpu"): ], ), ) +elif paddle.is_compiled_with_custom_device("gcu"): + setup( + name="fastdeploy_ops", + ext_modules=CppExtension( + sources=[ + "gpu_ops/save_with_output_msg.cc", + "gpu_ops/get_output.cc", + "gpu_ops/get_output_msg_with_topk.cc", + ] + ), + ) else: use_bf16 = envs.FD_CPU_USE_BF16 == "True" diff --git a/docs/get_started/installation/Enflame_gcu.md b/docs/get_started/installation/Enflame_gcu.md index edda97474..844c38626 100644 --- a/docs/get_started/installation/Enflame_gcu.md +++ b/docs/get_started/installation/Enflame_gcu.md @@ -1,8 +1,8 @@ -# Running ERNIE-4.5-21B-A3B with FastDeploy +# Running ERNIE 4.5 Series Models with FastDeploy The Enflame S60 ([Learn about Enflame](https://www.enflame-tech.com/)) is a next-generation AI inference accelerator card designed for large-scale deployment in data centers. It meets the demands of large language models (LLMs), search/advertising/recommendation systems, and traditional models. Characterized by broad model coverage, user-friendliness, and high portability, it is widely applicable to mainstream inference scenarios such as image and text generation applications, search and recommendation systems, and text/image/speech recognition. -FastDeploy has deeply adapted and optimized the ernie-4_5-21b-a3b-bf16-paddle model for the Enflame S60, achieving a unified inference interface between GCU and GPU. This allows seamless migration of inference tasks without code modifications. +FastDeploy has deeply adapted and optimized the ERNIE 4.5 Series Models for the Enflame S60, achieving a unified inference interface between GCU and GPU. This allows seamless migration of inference tasks without code modifications. ## 🚀 Quick Start 🚀 @@ -27,15 +27,15 @@ lspci | grep S60 3b:00.0 Processing accelerators: Shanghai Enflame Technology Co. Ltd S60 [Enflame] (rev 01) 3c:00.0 Processing accelerators: Shanghai Enflame Technology Co. Ltd S60 [Enflame] (rev 01) ``` -### 1. Environment Setup (Estimated time: 5–10 minutes) +### 1. Environment Setup (Estimated time: 5-10 minutes) 1. Pull the Docker image ```bash # Note: This image only contains the Paddle development environment, not precompiled PaddlePaddle packages -docker pull ccr-2vdh3abv-pub.cnc.bj.baidubce.com/device/paddle-gcu:topsrider3.4.623-ubuntu20-x86_64-gcc84 +docker pull ccr-2vdh3abv-pub.cnc.bj.baidubce.com/device/paddle-gcu:topsrider3.5.102-ubuntu20-x86_64-gcc84 ``` 2. 
Start the container ```bash -docker run --name paddle-gcu-llm -v /home:/home -v /work:/work --network=host --ipc=host -it --privileged ccr-2vdh3abv-pub.cnc.bj.baidubce.com/device/paddle-gcu:topsrider3.4.623-ubuntu20-x86_64-gcc84 /bin/bash +docker run --name paddle-gcu-llm -v /home:/home -v /work:/work --network=host --ipc=host -it --privileged ccr-2vdh3abv-pub.cnc.bj.baidubce.com/device/paddle-gcu:topsrider3.5.102-ubuntu20-x86_64-gcc84 /bin/bash ``` 3. Obtain and install drivers
**Full software packages are preloaded in the Docker container. Copy them to an external directory, e.g., ```/home/workspace/deps/```** @@ -67,25 +67,31 @@ python -m pip install paddle-custom-gcu==3.1.0 -i https://www.paddlepaddle.org.c 7. Install FastDeploy and dependencies ```bash python -m pip install fastdeploy -i https://www.paddlepaddle.org.cn/packages/stable/gcu/ --extra-index-url https://mirrors.tuna.tsinghua.edu.cn/pypi/web/simplels -apt install python3.10-distutils +# For source compilation, refer to the following steps +git clone https://github.com/PaddlePaddle/FastDeploy +cd FastDeploy +python -m pip install -r requirements.txt --extra-index-url https://mirrors.tuna.tsinghua.edu.cn/pypi/web/simplels +bash build.sh 1 ``` -### 2. Data Preparation (Estimated time: 2–5 minutes) +### 2. Data Preparation (Estimated time: 2-5 minutes) Use a trained model for inference on GSM8K dataset: ```bash mkdir -p /home/workspace/benchmark/ && cd /home/workspace/benchmark/ wget https://raw.githubusercontent.com/openai/grade-school-math/master/grade_school_math/data/test.jsonl ``` -Place model weights in a directory, e.g., ```/work/models/ernie-4_5-21b-a3b-bf16-paddle/``` -### 3. Inference (Estimated time: 2–5 minutes) +Place model weights in a directory, e.g., ```/work/models/ERNIE-4.5-300B-A47B-Paddle/``` +### 3. Inference (Estimated time: 2-5 minutes) Start the inference service: ```bash python -m fastdeploy.entrypoints.openai.api_server \ - --model "/work/models/ernie-4_5-21b-a3b-bf16-paddle/" \ + --model "/work/models/ERNIE-4.5-300B-A47B-Paddle/" \ --port 8188 \ --metrics-port 8200 \ - --tensor-parallel-size 4 \ - --max-model-len 8192 \ - --num-gpu-blocks-override 1024 + --tensor-parallel-size 8 \ + --max-model-len 32768 \ + --num-gpu-blocks-override 4096 \ + --max-num-batched-tokens 32768 \ + --quantization "wint4" ``` Query the model service: ```bash @@ -93,13 +99,13 @@ curl -X POST "http://0.0.0.0:8188/v1/chat/completions" \ -H "Content-Type: application/json" \ -d '{ "messages": [ - {"role": "user", "content": "The largest ocean is"} + {"role": "user", "content": "Where is Beijing?"} ] }' ``` Successful execution returns inference results, e.g.: ```json -{"id":"chatcmpl-5cd96f3b-eff3-4dc0-8aa2-8b5d7b7b86f2","object":"chat.completion","created":1751167862,"model":"default","choices":[{"index":0,"message":{"role":"assistant","content":"3. **Pacific Ocean**: The Pacific Ocean is the largest and deepest of the world's oceans. It covers an area of approximately 181,344,000 square kilometers, which is more than 30% of the Earth's surface. It is located between the Americas to the west and east, and Asia and Australia to the north and south. The Pacific Ocean is known for its vastness, diverse marine life, and numerous islands.\n\nIn summary, the largest ocean in the world is the Pacific Ocean.","reasoning_content":null,"tool_calls":null},"finish_reason":"stop"}],"usage":{"prompt_tokens":11,"total_tokens":127,"completion_tokens":116,"prompt_tokens_details":{"cached_tokens":0}}} +{"id":"chatcmpl-20f1210d-6943-4110-ad2d-c76ba11604ad","object":"chat.completion","created":1751621261,"model":"default","choices":[{"index":0,"message":{"role":"assistant","content":"Beijing is the capital city of the People's Republic of China, located in the northern part of the country. It is situated in the North China Plain, bordered by the mountains to the west, north, and northeast. 
Beijing serves as China's political, cultural, and international exchange center, playing a crucial role in the nation's development and global interactions.","reasoning_content":null,"tool_calls":null},"finish_reason":"stop"}],"usage":{"prompt_tokens":11,"total_tokens":88,"completion_tokens":77,"prompt_tokens_details":{"cached_tokens":0}}} ``` ### 4. Accuracy Testing (Estimated time: 60–180 minutes) Place the accuracy script ```bench_gsm8k.py``` in ```/home/workspace/benchmark/``` and modify sampling parameters, e.g.: @@ -120,10 +126,10 @@ data = { Run accuracy tests: ```bash cd /home/workspace/benchmark/ -python -u bench_gsm8k.py --port 8188 --num-questions 1319 --num-shots 5 --parallel 2 +python -u bench_gsm8k.py --port 8188 --num-questions 1319 --num-shots 5 --parallel 8 ``` Upon completion, accuracy results are saved in ```result.jsonl```, e.g.: ```json -{"task": "gsm8k", "backend": "paddlepaddle", "num_gpus": 1, "latency": 365.548, "accuracy": 0.967, "num_requests": 30, "other": {"num_questions": 30, "parallel": 2}} +{"task": "gsm8k", "backend": "paddlepaddle", "num_gpus": 1, "latency": 13446.01, "accuracy": 0.956, "num_requests": 1319, "other": {"num_questions": 1319, "parallel": 8}} ``` diff --git a/docs/zh/get_started/installation/Enflame_gcu.md b/docs/zh/get_started/installation/Enflame_gcu.md index c5ca47009..f47212dc6 100644 --- a/docs/zh/get_started/installation/Enflame_gcu.md +++ b/docs/zh/get_started/installation/Enflame_gcu.md @@ -1,8 +1,8 @@ -# 使用 FastDeploy 在燧原 S60 上运行 ERNIE-4.5-21B-A3B模型 +# 使用 FastDeploy 在燧原 S60 上运行 ERNIE 4.5 系列模型 燧原 S60([了解燧原](https://www.enflame-tech.com/))是面向数据中心大规模部署的新一代人工智能推理加速卡,满足大语言模型、搜广推及传统模型的需求,具有模型覆盖面广、易用性强、易迁移易部署等特点,可广泛应用于图像及文本生成等应用、搜索与推荐、文本、图像及语音识别等主流推理场景。 -FastDeploy 在燧原 S60 上对 ernie-4_5-21b-a3b-bf16-paddle 模型进行了深度适配和优化,实现了 GCU 推理入口和 GPU 的统一,无需修改即可完成推理任务的迁移。 +FastDeploy 在燧原 S60 上对 ERNIE 4.5 系列模型进行了深度适配和优化,实现了 GCU 推理入口和 GPU 的统一,无需修改即可完成推理任务的迁移。 ## 🚀 快速开始 🚀 @@ -30,11 +30,11 @@ lspci | grep S60 1. 拉取镜像 ```bash # 注意此镜像仅为paddle开发环境,镜像中不包含预编译的飞桨安装包 -docker pull ccr-2vdh3abv-pub.cnc.bj.baidubce.com/device/paddle-gcu:topsrider3.4.623-ubuntu20-x86_64-gcc84 +docker pull ccr-2vdh3abv-pub.cnc.bj.baidubce.com/device/paddle-gcu:topsrider3.5.102-ubuntu20-x86_64-gcc84 ``` 2. 参考如下命令启动容器 ```bash -docker run --name paddle-gcu-llm -v /home:/home -v /work:/work --network=host --ipc=host -it --privileged ccr-2vdh3abv-pub.cnc.bj.baidubce.com/device/paddle-gcu:topsrider3.4.623-ubuntu20-x86_64-gcc84 /bin/bash +docker run --name paddle-gcu-llm -v /home:/home -v /work:/work --network=host --ipc=host -it --privileged ccr-2vdh3abv-pub.cnc.bj.baidubce.com/device/paddle-gcu:topsrider3.5.102-ubuntu20-x86_64-gcc84 /bin/bash ``` 3. 获取并安装驱动
**docker 内提前放置了全量软件包,需拷贝至 docker 外目录,如:```/home/workspace/deps/```** @@ -63,10 +63,14 @@ python -m pip install paddlepaddle==3.1.0a0 -i https://www.paddlepaddle.org.cn/p python -m pip install paddle-custom-gcu==3.1.0 -i https://www.paddlepaddle.org.cn/packages/stable/gcu/ # 如想源码编译安装,请参考https://github.com/PaddlePaddle/PaddleCustomDevice/blob/develop/backends/gcu/README_cn.md ``` -7. 安装 FastDeploy 和 依赖
+7. 安装 FastDeploy
```bash python -m pip install fastdeploy -i https://www.paddlepaddle.org.cn/packages/stable/gcu/ --extra-index-url https://mirrors.tuna.tsinghua.edu.cn/pypi/web/simplels -apt install python3.10-distutils +# 如想源码编译安装,请参考如下步骤 +git clone https://github.com/PaddlePaddle/FastDeploy +cd FastDeploy +python -m pip install -r requirements.txt --extra-index-url https://mirrors.tuna.tsinghua.edu.cn/pypi/web/simplels +bash build.sh 1 ``` ### 2. 数据准备:(这将花费您 2~5min 时间) 使用训练好的模型,在 GSM8K 上推理 @@ -74,17 +78,19 @@ apt install python3.10-distutils mkdir -p /home/workspace/benchmark/ && cd /home/workspace/benchmark/ wget https://raw.githubusercontent.com/openai/grade-school-math/master/grade_school_math/data/test.jsonl ``` -准备模型和权重,置于环境目录,如:```/work/models/ernie-4_5-21b-a3b-bf16-paddle/``` +准备模型和权重,置于环境目录,如:```/work/models/ERNIE-4.5-300B-A47B-Paddle/``` ### 3. 推理:(这将花费您 2~5min 时间) 执行如下命令启动推理服务 ```bash python -m fastdeploy.entrypoints.openai.api_server \ - --model "/work/models/ernie-4_5-21b-a3b-bf16-paddle/" \ + --model "/work/models/ERNIE-4.5-300B-A47B-Paddle/" \ --port 8188 \ --metrics-port 8200 \ - --tensor-parallel-size 4 \ - --max-model-len 8192 \ - --num-gpu-blocks-override 1024 + --tensor-parallel-size 8 \ + --max-model-len 32768 \ + --num-gpu-blocks-override 4096 \ + --max-num-batched-tokens 32768 \ + --quantization "wint4" ``` 使用如下命令请求模型服务 ```bash @@ -92,13 +98,13 @@ curl -X POST "http://0.0.0.0:8188/v1/chat/completions" \ -H "Content-Type: application/json" \ -d '{ "messages": [ - {"role": "user", "content": "The largest ocean is"} + {"role": "user", "content": "Where is Beijing?"} ] }' ``` 成功运行后,可以查看到推理结果的生成,样例如下 ```json -{"id":"chatcmpl-5cd96f3b-eff3-4dc0-8aa2-8b5d7b7b86f2","object":"chat.completion","created":1751167862,"model":"default","choices":[{"index":0,"message":{"role":"assistant","content":"3. **Pacific Ocean**: The Pacific Ocean is the largest and deepest of the world's oceans. It covers an area of approximately 181,344,000 square kilometers, which is more than 30% of the Earth's surface. It is located between the Americas to the west and east, and Asia and Australia to the north and south. The Pacific Ocean is known for its vastness, diverse marine life, and numerous islands.\n\nIn summary, the largest ocean in the world is the Pacific Ocean.","reasoning_content":null,"tool_calls":null},"finish_reason":"stop"}],"usage":{"prompt_tokens":11,"total_tokens":127,"completion_tokens":116,"prompt_tokens_details":{"cached_tokens":0}}} +{"id":"chatcmpl-20f1210d-6943-4110-ad2d-c76ba11604ad","object":"chat.completion","created":1751621261,"model":"default","choices":[{"index":0,"message":{"role":"assistant","content":"Beijing is the capital city of the People's Republic of China, located in the northern part of the country. It is situated in the North China Plain, bordered by the mountains to the west, north, and northeast. Beijing serves as China's political, cultural, and international exchange center, playing a crucial role in the nation's development and global interactions.","reasoning_content":null,"tool_calls":null},"finish_reason":"stop"}],"usage":{"prompt_tokens":11,"total_tokens":88,"completion_tokens":77,"prompt_tokens_details":{"cached_tokens":0}}} ``` ### 4. 
精度测试:(这将花费您 60~180min 时间) 准备精度脚本 ```bench_gsm8k.py``` 置于 ```/home/workspace/benchmark/``` ,并修改采样参数,如: @@ -119,10 +125,10 @@ data = { 执行以下命令启动精度测试 ```bash cd /home/workspace/benchmark/ -python -u bench_gsm8k.py --port 8188 --num-questions 1319 --num-shots 5 --parallel 2 +python -u bench_gsm8k.py --port 8188 --num-questions 1319 --num-shots 5 --parallel 8 ``` -执行成功运行后,当前目录可以查看到精度结果的生成,文件为 ```result.jsonl```,样例如下(部分数据集,仅示例) +执行成功运行后,当前目录可以查看到精度结果的生成,文件为 ```result.jsonl```,样例如下 ```json -{"task": "gsm8k", "backend": "paddlepaddle", "num_gpus": 1, "latency": 365.548, "accuracy": 0.967, "num_requests": 30, "other": {"num_questions": 30, "parallel": 2}} +{"task": "gsm8k", "backend": "paddlepaddle", "num_gpus": 1, "latency": 13446.01, "accuracy": 0.956, "num_requests": 1319, "other": {"num_questions": 1319, "parallel": 8}} ``` diff --git a/fastdeploy/model_executor/layers/activation.py b/fastdeploy/model_executor/layers/activation.py index 09126aa6c..5f7a568ff 100644 --- a/fastdeploy/model_executor/layers/activation.py +++ b/fastdeploy/model_executor/layers/activation.py @@ -19,7 +19,7 @@ from typing import Optional import paddle from paddle import nn -from paddle.incubate.nn.functional import fused_bias_act +from paddle.incubate.nn.functional import fused_bias_act, swiglu from fastdeploy.config import FDConfig from fastdeploy.platforms import current_platform @@ -66,6 +66,8 @@ class SiluAndMul(nn.Layer): if current_platform.is_cuda() or current_platform.is_xpu( ) or current_platform.is_iluvatar(): self.forward = self.forward_cuda + elif current_platform.is_gcu(): + self.forward = self.forward_gcu else: raise NotImplementedError @@ -123,3 +125,18 @@ class SiluAndMul(nn.Layer): quant_max_bound=self.quant_max_bound, quant_min_bound=self.quant_min_bound, ) + + def forward_gcu(self, x): + """ + Forward propagation of the custom activation layer. + + Args: + x (Tensor): Input tensor to the activation layer. + + Returns: + Tensor: Output tensor. + """ + out = swiglu(x) + if self.bias is not None: + out = out + self.bias + return out diff --git a/fastdeploy/model_executor/layers/backends/__init__.py b/fastdeploy/model_executor/layers/backends/__init__.py index d3ccd6a0d..fbb12bd79 100644 --- a/fastdeploy/model_executor/layers/backends/__init__.py +++ b/fastdeploy/model_executor/layers/backends/__init__.py @@ -16,14 +16,24 @@ all backends methods """ -from .xpu import * -from .npu import * +from fastdeploy.platforms import current_platform __all__ = [] -from . import npu -if hasattr(npu, '__all__'): - __all__.extend(npu.__all__) - -from . import xpu -if hasattr(xpu, '__all__'): - __all__.extend(xpu.__all__) \ No newline at end of file + +if current_platform.is_xpu(): + from . import xpu + from .xpu import * + if hasattr(xpu, '__all__'): + __all__.extend(xpu.__all__) + +if current_platform.is_npu(): + from . import npu + from .npu import * + if hasattr(npu, '__all__'): + __all__.extend(npu.__all__) + +if current_platform.is_gcu(): + from . import gcu + from .gcu import * + if hasattr(gcu, '__all__'): + __all__.extend(gcu.__all__) diff --git a/fastdeploy/model_executor/layers/backends/gcu/__init__.py b/fastdeploy/model_executor/layers/backends/gcu/__init__.py new file mode 100644 index 000000000..8de8fe8d8 --- /dev/null +++ b/fastdeploy/model_executor/layers/backends/gcu/__init__.py @@ -0,0 +1,31 @@ +# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +gcu backend methods +""" + +from .attention.flash_attn_backend import GCUFlashAttnBackend +from .attention.mem_efficient_attn_backend import GCUMemEfficientAttnBackend +from .moe.fused_moe_method_gcu_backend import (GCUFusedMoeMethod, + GCUWeightOnlyMoEMethod) +from .quantization.weight_only import GCUWeightOnlyLinearMethod + +__all__ = [ + 'GCUFlashAttnBackend', + 'GCUMemEfficientAttnBackend', + 'GCUFusedMoeMethod', + 'GCUWeightOnlyMoEMethod', + 'GCUWeightOnlyLinearMethod', +] diff --git a/fastdeploy/model_executor/layers/backends/gcu/attention/__init__.py b/fastdeploy/model_executor/layers/backends/gcu/attention/__init__.py new file mode 100644 index 000000000..59e299f61 --- /dev/null +++ b/fastdeploy/model_executor/layers/backends/gcu/attention/__init__.py @@ -0,0 +1,21 @@ +# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from .flash_attn_backend import GCUFlashAttnBackend +from .mem_efficient_attn_backend import GCUMemEfficientAttnBackend + +__all__ = [ + "GCUFlashAttnBackend", + "GCUMemEfficientAttnBackend", +] diff --git a/fastdeploy/model_executor/layers/backends/gcu/attention/flash_attn_backend.py b/fastdeploy/model_executor/layers/backends/gcu/attention/flash_attn_backend.py new file mode 100644 index 000000000..56870de82 --- /dev/null +++ b/fastdeploy/model_executor/layers/backends/gcu/attention/flash_attn_backend.py @@ -0,0 +1,287 @@ +""" +# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+""" + +from __future__ import annotations + +import os +from dataclasses import dataclass, field +from typing import TYPE_CHECKING, List, Optional + +import paddle + +import numpy as np + + +if TYPE_CHECKING: + from paddle._typing.dtype_like import _DTypeLiteral + +from fastdeploy.config import FDConfig +from fastdeploy.model_executor.layers.attention.attention import Attention +from fastdeploy.model_executor.layers.attention.base_attention_backend import ( + AttentionBackend, AttentionMetadata) +from fastdeploy.worker.forward_meta import ForwardMeta, ForwardMode + +from fastdeploy.model_executor.ops.gcu import (fused_rotary_embedding, + mem_efficient_attention, + flash_attn_var_len) +from paddleformers.utils.log import logger + + +@dataclass +class GCUFlashAttnMetadata(AttentionMetadata): + """ + GCUFlashAttnMetadata + """ + forward_mode: ForwardMode = ForwardMode.MIXED + + _dtype: _DTypeLiteral = paddle.bfloat16 + + seq_lens_encoder: Optional[paddle.Tensor] = None + seq_lens_decoder: Optional[paddle.Tensor] = None + seq_lens_this_time: Optional[paddle.Tensor] = None + cum_offsets: Optional[paddle.Tensor] = None + padding_offset: Optional[paddle.Tensor] = None + + cu_seqlens_q: Optional[paddle.Tensor] = None + cu_seqlens_k: Optional[paddle.Tensor] = None + caches: Optional[paddle.Tensor] = None + + block_tables: Optional[paddle.Tensor] = None + rotary_embs: Optional[paddle.Tensor] = None + attn_mask: Optional[paddle.Tensor] = None + + pre_caches_length: int = 0 + + + + +class GCUFlashAttnBackend(AttentionBackend): + """ + GCUFlashAttnBackend backend implementation. + """ + + def __init__(self, fd_config: FDConfig, kv_num_heads: int, num_heads: int, + head_dim: int): + """ + GCUFlashAttnBackend __init__ + """ + super().__init__() + self.attention_metadata: GCUFlashAttnMetadata = None + self.block_size = fd_config.parallel_config.block_size + self.max_seq_len = fd_config.parallel_config.max_model_len + self.max_num_seqs = fd_config.parallel_config.max_num_seqs + + self.causal = getattr(fd_config.model_config, "causal", True) + + self.rank = fd_config.parallel_config.tensor_parallel_rank + self.kv_num_heads = kv_num_heads + self.num_heads = num_heads + self.head_dim = head_dim + self.scaling = 1.0 / (self.head_dim**0.5) + self.num_layers = fd_config.model_config.num_layers + self.position_ids_base = paddle.arange(self.max_seq_len) + + # TODO(zhengjun): Need to adapt the allocation logic and + # temporarily allocate according to fixed size + self.all_block_tables: List[List[int]] = None + self.all_slot_mapping: List[List[int]] = None + + self.rotary_embs = None + self.enable_monitor: bool = bool(os.getenv("FD_GCU_ATTN_MONITOR", False)) + + + + def init_attention_metadata(self, forward_meta: ForwardMeta): + """Initialize attntion metadata hence all layers in the forward pass can reuse it.""" + metadata = GCUFlashAttnMetadata() + + metadata.forward_mode = forward_meta.forward_mode + + metadata._dtype = paddle.get_default_dtype() + + metadata.seq_lens_encoder = forward_meta.seq_lens_encoder + metadata.seq_lens_decoder = forward_meta.seq_lens_decoder + metadata.seq_lens_this_time = forward_meta.seq_lens_this_time + metadata.cum_offsets = forward_meta.cum_offsets + metadata.padding_offset = forward_meta.padding_offset + + metadata.cu_seqlens_q = forward_meta.cu_seqlens_q + metadata.cu_seqlens_k = forward_meta.cu_seqlens_k + metadata.caches = forward_meta.caches + + # metadata.block_tables = forward_meta.block_tables + metadata.rotary_embs = forward_meta.rotary_embs + metadata.attn_mask = 
forward_meta.attn_mask # not init + + metadata.pre_caches_length = forward_meta.pre_caches_length # not inited + + self.attention_metadata = metadata + + if self.rotary_embs is None: + self.rotary_embs = metadata.rotary_embs.reshape((-1, self.head_dim)) + + # some info for attention + self.seq_lens_this_time_list = forward_meta.seq_lens_this_time.tolist() # List[int] + self.seq_lens_encoder_list = forward_meta.seq_lens_encoder.tolist() # List[List[int]] + self.seq_lens_decoder_list = forward_meta.seq_lens_decoder.tolist() # List[List[int]] + self.seq_lens_sum = np.sum(self.seq_lens_this_time_list) + self.max_seq_len_this_time = np.max(self.seq_lens_this_time_list) + + num_seqs = forward_meta.seq_lens_this_time.shape[0] + + + self.is_decoder = all(x[0] == 0 for x in self.seq_lens_encoder_list) + self.is_all_prefill = all(x[0] == 0 for x in self.seq_lens_decoder_list) + + # block_tables and slot_mapping + if self.all_slot_mapping is None: + max_num_blocks_per_seq = (self.max_seq_len + self.block_size - 1) // self.block_size + total_blocks = max_num_blocks_per_seq * self.max_num_seqs + self.all_block_tables = np.arange(0, total_blocks, dtype=np.int32).reshape((self.max_num_seqs, max_num_blocks_per_seq)).tolist() + self.all_slot_mapping = np.arange(0, total_blocks * self.block_size, dtype=np.int32).reshape((self.max_num_seqs, -1)).tolist() + + block_tables = [] + slot_mapping = [] + cache_slot_range = [] + cache_lens = [] + position_ids = [] + for seq_idx in range(num_seqs): + cache_len = None + if self.seq_lens_encoder_list[seq_idx][0] != 0: # prefill + cache_len = 0 + elif self.seq_lens_decoder_list[seq_idx][0] != 0: # decode + cache_len = self.seq_lens_decoder_list[seq_idx][0] + # else: doesnot have req in this seq_idx + + if cache_len is not None: + lens_this_time = self.seq_lens_this_time_list[seq_idx] + start = cache_len + end = start + lens_this_time + slot_mapping.extend(self.all_slot_mapping[seq_idx][start:end]) + cache_slot_range.extend(self.all_slot_mapping[seq_idx][0:end]) + cache_lens.append(end) + block_tables.append(self.all_block_tables[seq_idx]) + position_ids.extend(self.position_ids_base[start:end]) + + self.block_tables = paddle.to_tensor(block_tables, dtype="int32") + self.slot_mapping = paddle.to_tensor(slot_mapping, dtype="int32") + self.cache_slot_range = paddle.to_tensor(cache_slot_range, dtype="int32") + self.position_ids = paddle.to_tensor(position_ids, dtype="int32") + self.position_ids = self.position_ids.reshape_((1, -1)) + + if self.enable_monitor: + logger.info(f"[FD_DEBUG] init_attention_metadata, position_ids:\n{self.position_ids}") + + cu_query_lens_data = [0] + for seq_idx in range(num_seqs): + if self.seq_lens_this_time_list[seq_idx] != 0: + cu_query_lens_data.append(self.seq_lens_this_time_list[seq_idx]) + cu_query_lens = np.array(cu_query_lens_data, dtype=np.int32).cumsum(axis=0) + + self.cu_query_lens = paddle.to_tensor(cu_query_lens, dtype="int32") + self.seqused_k = paddle.to_tensor(cache_lens, dtype="int32") + self.max_seqlen_q = self.max_seq_len_this_time + self.max_seqlen_k = np.max(cache_lens) + + + def get_attntion_meta(self): + """get_attntion_meta""" + return self.attention_metadata + + def get_kv_cache_shape( + self, + max_num_blocks: int, + ): + """ + Caculate kv cache shape + """ + # [total_tokens, kv_num_heads, head_dim] + return (max_num_blocks * self.block_size, + self.kv_num_heads, + self.head_dim) + + @paddle.no_grad() + def forward_mixed( + self, + q: paddle.Tensor, + k: paddle.Tensor, + v: paddle.Tensor, + qkv: paddle.Tensor, + 
compressed_kv: paddle.Tensor, + k_pe: paddle.Tensor, + layer: Attention, + forward_meta: ForwardMeta, + ) -> paddle.Tensor: + """Run a forward for mixed.""" + token_num = qkv.shape[0] + q_size = self.num_heads * self.head_dim + kv_size = self.kv_num_heads * self.head_dim + num_or_sections = [q_size, kv_size, kv_size] + query, key, value = paddle.split(qkv, num_or_sections=num_or_sections, axis=-1) + + query = query.reshape_((1, -1, self.num_heads, self.head_dim)) + key = key.reshape_((1, -1, self.kv_num_heads, self.head_dim)) + + + # 1. Rope + if self.rotary_embs.dtype != query.dtype: + self.rotary_embs = paddle.cast(self.rotary_embs, query.dtype) + + query, key = fused_rotary_embedding( + query, + key, + self.rotary_embs, + self.position_ids, + layer.use_neox_rotary_style + ) + + # 2. Save kv cache + # shape: [total_tokens, kv_num_heads, head_dim] + key = key.reshape_((-1, self.kv_num_heads, self.head_dim)) + value = value.reshape_((-1, self.kv_num_heads, self.head_dim)) + key_caches = forward_meta.caches[2 * layer.layer_id] + value_caches = forward_meta.caches[2 * layer.layer_id + 1] + key_caches[self.slot_mapping, :, :] = key + value_caches[self.slot_mapping, :, :] = value + + # 3. calc attn + query = query.reshape_((-1, self.num_heads, self.head_dim)) + key_caches = key_caches.reshape((-1, self.block_size, self.kv_num_heads, self.head_dim)) + value_caches = value_caches.reshape((-1, self.block_size, self.kv_num_heads, self.head_dim)) + res = flash_attn_var_len( + query=query, + key=key_caches, + value=value_caches, + cu_seqlens_q=self.cu_query_lens, + cu_seqlens_k=None, + seqused_k=self.seqused_k, + leftpad_k=None, + block_table=self.block_tables, + alibi_slopes=None, + max_seqlen_q=self.max_seqlen_q, + max_seqlen_k=self.max_seqlen_k, + p_dropout=0.0, + softmax_scale=self.scaling, + zero_tensors=False, + is_causal=self.causal, + window_size_left=-1, + window_size_right=-1, + softcap=0.0, + return_softmax=False, + ) + res = res.reshape_((token_num, -1)) + return res + diff --git a/fastdeploy/model_executor/layers/backends/gcu/attention/mem_efficient_attn_backend.py b/fastdeploy/model_executor/layers/backends/gcu/attention/mem_efficient_attn_backend.py new file mode 100644 index 000000000..bc5d8f151 --- /dev/null +++ b/fastdeploy/model_executor/layers/backends/gcu/attention/mem_efficient_attn_backend.py @@ -0,0 +1,357 @@ +""" +# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+""" + +from __future__ import annotations + +import os +from dataclasses import dataclass, field +from typing import TYPE_CHECKING, List, Optional + +import paddle + +import numpy as np +import math + + +if TYPE_CHECKING: + from paddle._typing.dtype_like import _DTypeLiteral + +from fastdeploy.config import FDConfig +from fastdeploy.model_executor.layers.attention.attention import Attention +from fastdeploy.model_executor.layers.attention.base_attention_backend import ( + AttentionBackend, AttentionMetadata) +from fastdeploy.worker.forward_meta import ForwardMeta, ForwardMode + +from fastdeploy.model_executor.ops.gcu import (fused_rotary_embedding, + mem_efficient_attention, + flash_attn_var_len) +from paddleformers.utils.log import logger + +@dataclass +class GCUMemEfficientAttnMetadata(AttentionMetadata): + """ + GCUMemEfficientAttnMetadata + """ + forward_mode: ForwardMode = ForwardMode.MIXED + _dtype: _DTypeLiteral = paddle.bfloat16 + + seq_lens_encoder: Optional[paddle.Tensor] = None + seq_lens_decoder: Optional[paddle.Tensor] = None + seq_lens_this_time: Optional[paddle.Tensor] = None + cum_offsets: Optional[paddle.Tensor] = None + padding_offset: Optional[paddle.Tensor] = None + + cu_seqlens_q: Optional[paddle.Tensor] = None + cu_seqlens_k: Optional[paddle.Tensor] = None + caches: Optional[paddle.Tensor] = None + + block_tables: Optional[paddle.Tensor] = None + rotary_embs: Optional[paddle.Tensor] = None + attn_mask: Optional[paddle.Tensor] = None + + pre_caches_length: int = 0 + + + + +class GCUMemEfficientAttnBackend(AttentionBackend): + """ + GCUMemEfficientAttnBackend backend implementation. + """ + + def __init__(self, fd_config: FDConfig, kv_num_heads: int, num_heads: int, + head_dim: int): + """ + GCUMemEfficientAttnBackend __init__ + """ + super().__init__() + self.attention_metadata: GCUMemEfficientAttnMetadata = None + self.block_size = fd_config.parallel_config.block_size + self.max_seq_len = fd_config.parallel_config.max_model_len + self.max_num_seqs = fd_config.parallel_config.max_num_seqs + + self.causal = getattr(fd_config.model_config, "causal", True) + + self.rank = fd_config.parallel_config.tensor_parallel_rank + self.kv_num_heads = kv_num_heads + self.num_heads = num_heads + self.head_dim = head_dim + self.scaling = 1.0 / (self.head_dim**0.5) + self.num_layers = fd_config.model_config.num_layers + self.position_ids_base = paddle.arange(self.max_seq_len) + + # TODO(zhengjun): Need to adapt the allocation logic and + # temporarily allocate according to fixed size + self.all_block_tables: List[List[int]] = None + self.all_slot_mapping: List[List[int]] = None + + self.rotary_embs = None + self.use_paddle_native_sdpa = False + + + + def init_attention_metadata(self, forward_meta: ForwardMeta): + """Initialize attntion metadata hence all layers in the forward pass can reuse it.""" + metadata = GCUMemEfficientAttnMetadata() + + metadata.forward_mode = forward_meta.forward_mode + + metadata._dtype = paddle.get_default_dtype() + + metadata.seq_lens_encoder = forward_meta.seq_lens_encoder + metadata.seq_lens_decoder = forward_meta.seq_lens_decoder + metadata.seq_lens_this_time = forward_meta.seq_lens_this_time + metadata.cum_offsets = forward_meta.cum_offsets + metadata.padding_offset = forward_meta.padding_offset + + metadata.cu_seqlens_q = forward_meta.cu_seqlens_q + metadata.cu_seqlens_k = forward_meta.cu_seqlens_k + metadata.caches = forward_meta.caches + + # metadata.block_tables = forward_meta.block_tables + metadata.rotary_embs = forward_meta.rotary_embs + 
metadata.attn_mask = forward_meta.attn_mask # not init + + metadata.pre_caches_length = forward_meta.pre_caches_length # not inited + + + self.attention_metadata = metadata + + if self.rotary_embs is None: + self.rotary_embs = metadata.rotary_embs.reshape((-1, self.head_dim)) + + # some info for attention + self.seq_lens_this_time_list = forward_meta.seq_lens_this_time.tolist() # List[int] + self.seq_lens_encoder_list = forward_meta.seq_lens_encoder.tolist() # List[List[int]] + self.seq_lens_decoder_list = forward_meta.seq_lens_decoder.tolist() # List[List[int]] + self.seq_lens_sum = np.sum(self.seq_lens_this_time_list) + self.max_seq_len_this_time = np.max(self.seq_lens_this_time_list) + + num_seqs = forward_meta.seq_lens_this_time.shape[0] + + + self.is_decoder = all(x[0] == 0 for x in self.seq_lens_encoder_list) + self.is_all_prefill = all(x[0] == 0 for x in self.seq_lens_decoder_list) + + + # block_tables and slot_mapping + if self.all_slot_mapping is None: + max_num_blocks_per_seq = (self.max_seq_len + self.block_size - 1) // self.block_size + total_blocks = max_num_blocks_per_seq * self.max_num_seqs + self.all_block_tables = np.arange(0, total_blocks, dtype=np.int32).reshape((self.max_num_seqs, max_num_blocks_per_seq)).tolist() + self.all_slot_mapping = np.arange(0, total_blocks * self.block_size, dtype=np.int32).reshape((self.max_num_seqs, -1)).tolist() + + block_tables = [] + slot_mapping = [] + cache_slot_range = [] + cache_lens = [] + query_lens = [] + cached_kv_lens = [] + cached_kv_slot_range = [] + position_ids = [] + for seq_idx in range(num_seqs): + cache_len = None + if self.seq_lens_encoder_list[seq_idx][0] != 0: # prefill + cache_len = 0 + elif self.seq_lens_decoder_list[seq_idx][0] != 0: # decode + cache_len = self.seq_lens_decoder_list[seq_idx][0] + # else: doesnot have req in this seq_idx + + if cache_len is not None: + lens_this_time = self.seq_lens_this_time_list[seq_idx] + start = cache_len + end = start + lens_this_time + slot_mapping.extend(self.all_slot_mapping[seq_idx][start:end]) + cache_slot_range.extend(self.all_slot_mapping[seq_idx][0:end]) + cache_lens.append(end) + block_tables.append(self.all_block_tables[seq_idx]) + position_ids.extend(self.position_ids_base[start:end]) + query_lens.append(lens_this_time) + cached_kv_lens.append(end) + cached_kv_slot_range.append([self.all_slot_mapping[seq_idx][0], self.all_slot_mapping[seq_idx][end]]) + + + + self.block_tables = paddle.to_tensor(block_tables, dtype="int32") + self.slot_mapping = paddle.to_tensor(slot_mapping, dtype="int32") + self.cache_slot_range = paddle.to_tensor(cache_slot_range, dtype="int32") + self.position_ids = paddle.to_tensor(position_ids, dtype="int32") + self.position_ids = self.position_ids.reshape_((1, -1)) + + logger.info(f"[FD_DEBUG] init_attention_metadata, self.position_ids:\n{self.position_ids}") + + cu_query_lens_data = [0] + for seq_idx in range(num_seqs): + if self.seq_lens_this_time_list[seq_idx] != 0: + cu_query_lens_data.append(self.seq_lens_this_time_list[seq_idx]) + cu_query_lens = np.array(cu_query_lens_data, dtype=np.int32).cumsum(axis=0) + + self.cu_query_lens = paddle.to_tensor(cu_query_lens, dtype="int32") + self.seqused_k = paddle.to_tensor(cache_lens, dtype="int32") + self.max_seqlen_q = self.max_seq_len_this_time + self.max_seqlen_k = np.max(cache_lens) + + self.query_lens = query_lens + self.cached_kv_lens = cached_kv_lens + self.cached_kv_slot_range = cached_kv_slot_range + + + def get_attntion_meta(self): + """get_attntion_meta""" + return self.attention_metadata + 
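+    # NOTE: the GCU attention backends keep each per-layer key/value cache as one flat
+    # [max_num_blocks * block_size, kv_num_heads, head_dim] pool and scatter new tokens
+    # into it via slot_mapping in forward_mixed (e.g. with max_num_blocks=1024 and
+    # block_size=64, each cache tensor is [65536, kv_num_heads, head_dim]).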
+ def get_kv_cache_shape( + self, + max_num_blocks: int, + ): + """ + Caculate kv cache shape + """ + # [total_tokens, kv_num_heads, head_dim] + return (max_num_blocks * self.block_size, + self.kv_num_heads, + self.head_dim) + + @paddle.no_grad() + def forward_mixed( + self, + q: paddle.Tensor, + k: paddle.Tensor, + v: paddle.Tensor, + qkv: paddle.Tensor, + compressed_kv: paddle.Tensor, + k_pe: paddle.Tensor, + layer: Attention, + forward_meta: ForwardMeta, + ) -> paddle.Tensor: + """Run a forward for mixed.""" + token_num = qkv.shape[0] + q_size = self.num_heads * self.head_dim + kv_size = self.kv_num_heads * self.head_dim + num_or_sections = [q_size, kv_size, kv_size] + query, key, value = paddle.split(qkv, num_or_sections=num_or_sections, axis=-1) + + query = query.reshape_((1, -1, self.num_heads, self.head_dim)) + key = key.reshape_((1, -1, self.kv_num_heads, self.head_dim)) + + + # 1. Rope + if self.rotary_embs.dtype != query.dtype: + self.rotary_embs = paddle.cast(self.rotary_embs, query.dtype) + + query, key = fused_rotary_embedding( + query, + key, + self.rotary_embs, + self.position_ids, + layer.use_neox_rotary_style + ) + + # 2. Save kv cache + # shape: [total_tokens, kv_num_heads, head_dim] + key = key.reshape_((-1, self.kv_num_heads, self.head_dim)) + value = value.reshape_((-1, self.kv_num_heads, self.head_dim)) + key_caches = forward_meta.caches[2 * layer.layer_id] + value_caches = forward_meta.caches[2 * layer.layer_id + 1] + key_caches[self.slot_mapping, :, :] = key + value_caches[self.slot_mapping, :, :] = value + + # 3. calc attn + query = query.reshape_((-1, self.num_heads, self.head_dim)) + + q_start = 0 + result = paddle.empty_like(query) + for idx in range(len(self.query_lens)): + q_end = q_start + self.query_lens[idx] + kv_start = self.cached_kv_slot_range[idx][0] + kv_end = self.cached_kv_slot_range[idx][1] + + q_ = query[q_start:q_end, :, :] + k_ = key_caches[kv_start:kv_end, :, :] + v_ = value_caches[kv_start:kv_end, :, :] + + if self.use_paddle_native_sdpa: + res = self.native_sdpa_impl( + q_, k_, v_ + ) + else: + res = mem_efficient_attention( + query=q_.unsqueeze(0), + key=k_.unsqueeze(0), + value=v_.unsqueeze(0), + attn_mask=None, + dropout=0.0, + softmax_scale=self.scaling, + mask_mode=1, + seqlens=[0], + causal=self.causal, + ) + result[q_start:q_end, :, :] = res + q_start = q_end + result = result.reshape_((token_num, -1)) + return result + + + def get_triangle_upper_mask(self, shape, dtype): + # [batch_size, 1, q_seq_len, kv_seq_len] + shape[1] = 1 + q_seq_len = shape[2] + kv_seq_len = shape[3] + paddle_dtype = dtype # paddle.base.data_feeder.convert_dtype(dtype) + mask = paddle.full(shape, paddle.finfo(paddle_dtype).min, dtype=paddle_dtype) + mask = paddle.triu(mask, diagonal=kv_seq_len - q_seq_len + 1) + return mask + + + def native_sdpa_impl(self, query, key, value): + # input shape: [num_tokens, num_heads, head_dim] -> [1, num_tokens, num_heads, head_dim] + q = query.unsqueeze(0) + k = key.unsqueeze(0) + v = value.unsqueeze(0) + batch, q_seq_len, heads, head_dim = q.shape + kv_seq_len = k.shape[1] + + # [batch_size, seq_len, num_heads, head_dim] -> [batch_size, num_heads, seq_len, head_dim] + q = paddle.transpose(q, [0, 2, 1, 3]) + k = paddle.transpose(k, [0, 2, 1, 3]) + v = paddle.transpose(v, [0, 2, 1, 3]) + + # GQA + if q.shape[1] != k.shape[1]: + kv_head = k.shape[1] + + k = k.reshape([batch, kv_head, 1, kv_seq_len, head_dim]) + k = paddle.tile(k, [1, 1, heads // kv_head, 1, 1]) + k = k.reshape([batch, heads, kv_seq_len, head_dim]) + + v = 
v.reshape([batch, kv_head, 1, kv_seq_len, head_dim]) + v = paddle.tile(v, [1, 1, heads // kv_head, 1, 1]) + v = v.reshape([batch, heads, kv_seq_len, head_dim]) + + # matmul and devide by sqrt(head_dim) + attn_weights = paddle.matmul(q / math.sqrt(head_dim), k.transpose([0, 1, 3, 2])) + + attention_mask = self.get_triangle_upper_mask( + [batch, 1, q_seq_len, kv_seq_len], q.dtype + ) + attn_weights = attn_weights + attention_mask + attn_weights = paddle.nn.functional.softmax( + attn_weights, axis=-1, dtype="float32" + ).astype(q.dtype) + + attn_output = paddle.matmul(attn_weights, v) + attn_output = attn_output.transpose([0, 2, 1, 3]) + return attn_output.squeeze(0) diff --git a/fastdeploy/model_executor/layers/backends/gcu/moe/__init__.py b/fastdeploy/model_executor/layers/backends/gcu/moe/__init__.py new file mode 100644 index 000000000..c61a9d89f --- /dev/null +++ b/fastdeploy/model_executor/layers/backends/gcu/moe/__init__.py @@ -0,0 +1,16 @@ +# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""" +gcu moe +""" diff --git a/fastdeploy/model_executor/layers/backends/gcu/moe/fused_moe_method_gcu_backend.py b/fastdeploy/model_executor/layers/backends/gcu/moe/fused_moe_method_gcu_backend.py new file mode 100644 index 000000000..0e37430e7 --- /dev/null +++ b/fastdeploy/model_executor/layers/backends/gcu/moe/fused_moe_method_gcu_backend.py @@ -0,0 +1,402 @@ + +""" +# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" + + +import multiprocessing +import os + +import numpy as np +import paddle +from paddle import nn +from paddleformers.utils.log import logger + +from fastdeploy.model_executor.layers.moe.fused_moe_backend_base import \ + MoEMethodBase +from fastdeploy.model_executor.layers.utils import (CpuGuard, + create_and_set_parameter, + get_tensor) +from fastdeploy.model_executor.ops.gcu import (invoke_fused_moe_kernel, + moe_align_block_size, + topk_softmax, + weight_quantize_custom_rtn, + weight_quantize_rtn) + + +class GCUFusedMoeMethod(MoEMethodBase): + """ + Use GCU to compute Fused MoE. + """ + def __init__(self, quant_config): + super().__init__(quant_config) + self.group_size = -1 + + + def create_weights(self, layer: nn.Layer, state_dict): + """ + Paddle gcu create weight process. 
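+        Stacks the per-expert FFN weights extracted from the state dict into
+        [num_experts, ...] tensors, transposes them from [E, K, N] to [E, N, K],
+        and registers them on the layer as unquantized (bf16) parameters under
+        the names in self.added_weight_attrs.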
+ """ + # bf16 + ffn1_weights, ffn2_weights = layer.extract_moe_ffn_weights(state_dict) + stacked_ffn1_weights = paddle.stack(ffn1_weights, axis=0) + stacked_ffn2_weights = paddle.stack(ffn2_weights, axis=0) + for idx, weight_tensor in enumerate( + [stacked_ffn1_weights, stacked_ffn2_weights]): + # shape [E, K, N] -> [E, N, K] + weight_tensor = paddle.transpose(weight_tensor, [0, 2, 1]) + weight_name = self.added_weight_attrs[idx] + setattr( + layer, weight_name, + layer.create_parameter( + shape=weight_tensor.shape, + dtype=weight_tensor.dtype, + default_initializer=paddle.nn.initializer.Constant(0), + )) + getattr(layer, weight_name).set_value(weight_tensor) + + + @paddle.no_grad() + def compute_ffn( + self, + layer: nn.Layer, + x: paddle.Tensor, + gate_out: paddle.Tensor, + enable_quant = False + ) -> paddle.Tensor: + """ + Paddle gcu compute Fused MoE. + """ + token_num, hidden_size = x.shape + top_k = layer.top_k + moe_intermediate_size = layer.moe_intermediate_size + num_experts = layer.num_local_experts + + topk_weights = paddle.empty([token_num, top_k], dtype=gate_out.dtype) + topk_indices = paddle.empty([token_num, top_k], dtype="int32") + token_expert_indices = paddle.empty([token_num, top_k], dtype="int32",) + topk_softmax(topk_weights, topk_indices, token_expert_indices, gate_out, norm_topk_prob=True) + + config = { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + } + + block_size = config["BLOCK_SIZE_M"] + max_num_tokens_padded = np.prod(topk_indices.shape) + num_experts * (block_size - 1) + max_num_m_blocks = max_num_tokens_padded // block_size + sorted_token_ids = paddle.empty([max_num_tokens_padded], dtype="int32") + expert_ids = paddle.zeros(shape=[max_num_m_blocks], dtype="int32") + num_tokens_post_pad = paddle.empty([1], dtype="int32") + + sorted_token_ids, expert_ids, num_tokens_post_pad = moe_align_block_size( + sorted_token_ids, + expert_ids, + num_tokens_post_pad, + topk_indices, + num_experts, + block_size, + ) + + intermediate_cache1 = paddle.empty( + [token_num, top_k, moe_intermediate_size * 2], + dtype=x.dtype, + ) + + ffn1_B_scale = layer.moe_ffn1_weight_scale if enable_quant else None + ffn1_B_zeros = layer.moe_ffn1_weight_zeros if enable_quant else None + + invoke_fused_moe_kernel( + x, # input + layer.moe_ffn1_weight, # weight + intermediate_cache1, # output + None, # A_scale + ffn1_B_scale, # B_scale + ffn1_B_zeros, # B_zp + topk_weights, + topk_indices, + sorted_token_ids, + expert_ids, + num_tokens_post_pad, + False, # mul_routed_weight + top_k, + config, + enable_quant, # use_int4_w4a16 + [0, self.group_size], # block_shape + ) + + intermediate_cache2 = paddle.empty( + (token_num, top_k, moe_intermediate_size), + dtype=x.dtype, + ) + + intermediate_cache2 = paddle.incubate.nn.functional.swiglu( + intermediate_cache1) + + intermediate_cache2 = intermediate_cache2.reshape([-1, moe_intermediate_size]) + + intermediate_cache3 = paddle.empty( + (token_num, top_k, hidden_size), + dtype=x.dtype, + ) + + ffn2_B_scale = layer.moe_ffn2_weight_scale if enable_quant else None + ffn2_B_zeros = layer.moe_ffn2_weight_zeros if enable_quant else None + + invoke_fused_moe_kernel( + intermediate_cache2, # input + layer.moe_ffn2_weight, # weight + intermediate_cache3, # output + None, # A_scale + ffn2_B_scale, # B_scale + ffn2_B_zeros, # B_zp + topk_weights, + topk_indices, + sorted_token_ids, + expert_ids, + num_tokens_post_pad, + True, # mul_routed_weight + 1, + config, + enable_quant, # use_int4_w4a16 + [0, self.group_size], # 
block_shape + ) + + intermediate_cache3.reshape_([token_num, top_k, hidden_size]) + fused_moe_out = intermediate_cache3.sum(axis=1) + fused_moe_out = fused_moe_out.reshape_([token_num, hidden_size]) + + if layer.tp_size > 1: + from fastdeploy.distributed.communication_op import \ + tensor_model_parallel_all_reduce + tensor_model_parallel_all_reduce(fused_moe_out) + + return fused_moe_out + + + def apply( + self, + layer: nn.Layer, + x: paddle.Tensor, + gate_out: paddle.Tensor, + ) -> paddle.Tensor: + """ + Paddle gcu compute Fused MoE. + """ + return self.compute_ffn(layer, x, gate_out, enable_quant=False) + + + def apply_ep_prefill( + self, + layer: nn.Layer, + x: paddle.Tensor, + gate_out: paddle.Tensor, + ) -> paddle.Tensor: + """ + Apply the EP prefill method. + """ + raise NotImplementedError + + + def apply_ep_decode( + self, + layer: nn.Layer, + x: paddle.Tensor, + gate_out: paddle.Tensor, + ) -> paddle.Tensor: + """ + Apply the EP decoder method. + """ + raise NotImplementedError + + + def apply_tp( + self, + layer: nn.Layer, + x: paddle.Tensor, + gate_out: paddle.Tensor, + ) -> paddle.Tensor: + """ + Paddle Cutlass compute Fused MoE. + """ + raise NotImplementedError + + +class GCUWeightOnlyMoEMethod(GCUFusedMoeMethod): + """ + weight only for moe + """ + + def __init__(self, quant_config): + super().__init__(quant_config) + self.quant_config = quant_config + self.moe_quant_type = self.quant_config.algo + self.pack_num = 1 + + assert self.quant_config.algo == "weight_only_int4", \ + "GCUWeightOnlyMoEMethod only support weight_only_int4, but got:{self.quant_config.algo}" + + self.added_qzeros_attrs = [ + "moe_ffn1_weight_zeros", "moe_ffn2_weight_zeros" + ] + self.group_size = 64 + + self.quant_multi_process_group_size = int( + os.getenv("FD_MOE_QUANT_MULTI_PROCESS_GROUP_SIZE", 8) + ) + logger.info(f"GCUWeightOnlyMoEMethod quant_multi_process_group_size: {self.quant_multi_process_group_size}") + + + def process_prequanted_weights(self, layer: nn.Layer, state_dict): + """ + Paddle gcu process prequanted weights. 
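+        Loads the already-quantized per-expert FFN1/FFN2 weights and their scales
+        from the state dict, stacks each of them along the expert axis, and
+        registers the stacked tensors as layer parameters.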
+ """ + ffn1_expert_weight_key = layer.weight_key_map.get( + "ffn1_expert_weight_key", None) + ffn2_expert_weight_key = layer.weight_key_map.get( + "ffn2_expert_weight_key", None) + ffn1_expert_weight_scale_key = layer.weight_key_map.get( + "ffn1_expert_weight_scale_key", None) + ffn2_expert_weight_scale_key = layer.weight_key_map.get( + "ffn2_expert_weight_scale_key", None) + + ffn1_weights, ffn2_weights = layer.load_experts_weight( + state_dict, ffn1_expert_weight_key, ffn2_expert_weight_key) + # self.check(layer, ffn1_weights, ffn2_weights) + ffn1_weight_scale = [] + ffn2_weight_scale = [] + for i in range(layer.num_experts): + expert_idx = layer.expert_id_offset + i + ffn1_weight_scale.append( + get_tensor( + state_dict.pop( + ffn1_expert_weight_scale_key.format(expert_idx)))) + ffn2_weight_scale.append( + get_tensor( + state_dict.pop( + ffn2_expert_weight_scale_key.format(expert_idx)))) + + ffn1_weight = paddle.stack(ffn1_weights, axis=0) + ffn2_weight = paddle.stack(ffn2_weights, axis=0) + ffn1_weight_scale = paddle.stack(ffn1_weight_scale, axis=0) + ffn2_weight_scale = paddle.stack(ffn2_weight_scale, axis=0) + + name_tensor_map = { + "moe_ffn1_weight": ffn1_weight, + "moe_ffn2_weight": ffn2_weight, + "moe_ffn1_weight_scale": ffn1_weight_scale, + "moe_ffn2_weight_scale": ffn2_weight_scale + } + for name, tensor in name_tensor_map.items(): + create_and_set_parameter(layer, name, tensor) + + + @paddle.no_grad() + def create_weights(self, layer: nn.Layer, state_dict): + """ + Paddle cutlass create weight process. + """ + ffn1_weights, ffn2_weights = layer.extract_moe_ffn_weights(state_dict) + self.check(layer, ffn1_weights, ffn2_weights) + + + def quant_worker(p_group_idx, shared_dict, weights, moe_quant_type, group_size): + with CpuGuard(): + p_group_size = len(weights) + for group_j in range(p_group_size): + # weight shape [K, N] -> [N/2, K] -> [N, K/2] + quant_weight, scale = weight_quantize_custom_rtn( + weights[group_j], + moe_quant_type, + group_size # group_size + ) + shared_dict[p_group_size * p_group_idx + group_j] = ( + quant_weight, scale + ) + + + for idx, weight_tensor in enumerate([ffn1_weights, ffn2_weights]): + weight_name = self.added_weight_attrs[idx] + scale_name = self.added_scale_attrs[idx] + zeros_name = self.added_qzeros_attrs[idx] + + if self.quant_multi_process_group_size > 0: + process_group_size = self.quant_multi_process_group_size + process_group_num = layer.num_local_experts // process_group_size + grouped_weights_num = process_group_num * process_group_size + remain_weights_start_idx = grouped_weights_num + + weight_list = [None] * grouped_weights_num + weight_scale_list = [None] * grouped_weights_num + + with multiprocessing.Manager() as manager: + shared_dict = manager.dict({}) + processes = [] + + for i in range(process_group_num): + w = [] + for j in range(process_group_size): + w.append(weight_tensor[process_group_size * i + j].to("cpu")) + + p = multiprocessing.Process( + target=quant_worker, + args=(i, shared_dict, w, self.moe_quant_type, self.group_size) + ) + p.start() + processes.append(p) + + for p in processes: + p.join() + + dict_ = dict(shared_dict) + + for k, v in dict_.items(): + weight_list[k] = v[0].to(ffn1_weights[0].place) + weight_scale_list[k] = v[1].to(ffn1_weights[0].place) + else: + remain_weights_start_idx = 0 + + if remain_weights_start_idx < layer.num_local_experts: + for i in range(remain_weights_start_idx, layer.num_local_experts): + # weight shape [K, N] -> [N/2, K] -> [N, K/2] + quant_weight, scale = weight_quantize_rtn( + 
weight_tensor[i], + self.moe_quant_type, + self.group_size # group_size + ) + weight_list.append(quant_weight) + weight_scale_list.append(scale) + quanted_weight = paddle.stack(weight_list, axis=0) + create_and_set_parameter(layer, weight_name, quanted_weight) + + quanted_weight_scale = paddle.stack(weight_scale_list, axis=0) + create_and_set_parameter(layer, scale_name, quanted_weight_scale) + + quanted_weight_zeros = quanted_weight_scale * 8 + create_and_set_parameter(layer, zeros_name, quanted_weight_zeros) + + + def apply( + self, + layer: nn.Layer, + x: paddle.Tensor, + gate_out: paddle.Tensor, + ) -> paddle.Tensor: + """ + Paddle gcu compute Fused MoE. + """ + return self.compute_ffn(layer, x, gate_out, enable_quant=True) diff --git a/fastdeploy/model_executor/layers/backends/gcu/quantization/__init__.py b/fastdeploy/model_executor/layers/backends/gcu/quantization/__init__.py new file mode 100644 index 000000000..b5870b4dc --- /dev/null +++ b/fastdeploy/model_executor/layers/backends/gcu/quantization/__init__.py @@ -0,0 +1,21 @@ +# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""" +gcu quantization +""" +from .weight_only import GCUWeightOnlyLinearMethod + +__all__ = [ + "GCUWeightOnlyLinearMethod", +] diff --git a/fastdeploy/model_executor/layers/backends/gcu/quantization/weight_only.py b/fastdeploy/model_executor/layers/backends/gcu/quantization/weight_only.py new file mode 100644 index 000000000..bddfa93f5 --- /dev/null +++ b/fastdeploy/model_executor/layers/backends/gcu/quantization/weight_only.py @@ -0,0 +1,90 @@ +""" +# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" + +import paddle + +from fastdeploy.model_executor.layers.quantization.weight_only import ( + WeightOnlyConfig, WeightOnlyLinearMethod) +from fastdeploy.model_executor.layers.utils import get_tensor +from fastdeploy.model_executor.ops.gcu import linear_quant, weight_quantize_rtn + + +class GCUWeightOnlyLinearMethod(WeightOnlyLinearMethod): + """ + Weight only quantization method for linear layer on GCU + """ + + def __init__( + self, + quant_config: WeightOnlyConfig, + ) -> None: + super().__init__(quant_config) + self.quant_config = quant_config + self.group_size = -1 + + + def create_weights(self, layer): + # The scale shape should be equal to the output dim of weight using Per-Channel Quantization. 
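+        # The quantized weight itself is stored transposed as int8; for wint4 its first
+        # dim is additionally halved, since two int4 values are packed into one int8 byte.
+        # Illustrative (hypothetical) dims: a bf16 [4096, 12288] weight is kept as an int8
+        # [12288, 4096] tensor for wint8 or [6144, 4096] for wint4, with a
+        # per-output-channel scale of shape [12288].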
+ linear_weight_scale_shape = [layer.linear_weight_shape[1]] + + layer.linear_weight_shape.reverse() + if self.quant_config.name() == "wint4": + layer.linear_weight_shape[0] //= 2 + layer.weight_dtype = "int8" + layer.linear_weight_scale = layer.create_parameter( + shape=linear_weight_scale_shape, + dtype=layer._dtype, + is_bias=False, + ) + + + def process_prequanted_weights(self, layer, state_dict) -> None: + """ + Process pre-quantized weights before applying them to the model + Args: + layer: The layer that owns the weights + quant_weight: The quantized weights + weight_scale: The scale of the quantized weights + """ + quant_weight = get_tensor(state_dict.pop(layer.weight_key)) + weight_scale = get_tensor(state_dict.pop(layer.weight_scale_key)) + layer.linear_weight.set_value(quant_weight) + layer.linear_weight_scale.set_value( + weight_scale.astype(paddle.get_default_dtype())) + + + def process_loaded_weights(self, layer, weight) -> None: + quanted_weight_tensor, weight_scale_tensor = weight_quantize_rtn( + weight, + self.quant_config.algo, + self.group_size, # group_size + ) + + layer.linear_weight.set_value(quanted_weight_tensor) + layer.linear_weight_scale.set_value( + weight_scale_tensor.astype(paddle.get_default_dtype())) + + + @paddle.no_grad() + def apply(self, layer, x): + linear_out = linear_quant( + lhs=x, + rhs=layer.linear_weight, + scale=layer.linear_weight_scale, + bias=None, + group_size=self.group_size, + ) + return linear_out diff --git a/fastdeploy/model_executor/layers/linear.py b/fastdeploy/model_executor/layers/linear.py index 8e7086261..208e076d5 100644 --- a/fastdeploy/model_executor/layers/linear.py +++ b/fastdeploy/model_executor/layers/linear.py @@ -58,7 +58,7 @@ class LinearBase(nn.Layer): """ super().__init__() if current_platform.is_cuda() or current_platform.is_xpu( - ) or current_platform.is_iluvatar(): + ) or current_platform.is_iluvatar() or current_platform.is_gcu(): self.forward = self.forward_cuda else: raise NotImplementedError diff --git a/fastdeploy/model_executor/layers/moe/moe.py b/fastdeploy/model_executor/layers/moe/moe.py index 3219afc21..3e0fe1660 100644 --- a/fastdeploy/model_executor/layers/moe/moe.py +++ b/fastdeploy/model_executor/layers/moe/moe.py @@ -20,6 +20,7 @@ from paddleformers.utils.log import logger from fastdeploy import envs from fastdeploy.model_executor.layers.utils import get_tensor +from fastdeploy.platforms import current_platform class FusedMoE(nn.Layer): @@ -95,8 +96,13 @@ class FusedMoE(nn.Layer): self.moe_quant_type = moe_quant_config.name() else: # now, no quant method(w_fp16 a_fp16) can't get from quant_config, we will optimize it in future - from .fused_moe_cutlass_backend import CutlassMoEMethod - self.quant_method = CutlassMoEMethod(None) + if current_platform.is_cuda(): + from .fused_moe_cutlass_backend import CutlassMoEMethod + self.quant_method = CutlassMoEMethod(None) + elif current_platform.is_gcu(): + from fastdeploy.model_executor.layers.backends import \ + GCUFusedMoeMethod + self.quant_method = GCUFusedMoeMethod(None) if self.ep_size > 1: self.quant_method.init_ep(self) diff --git a/fastdeploy/model_executor/layers/normalization.py b/fastdeploy/model_executor/layers/normalization.py index e4f78a05d..9b16830b6 100644 --- a/fastdeploy/model_executor/layers/normalization.py +++ b/fastdeploy/model_executor/layers/normalization.py @@ -19,9 +19,14 @@ from typing import Callable, Dict, Optional import numpy as np import paddle from paddle import nn -from paddle.incubate.nn.functional import fused_layer_norm, 
fused_rms_norm + from fastdeploy.platforms import current_platform +if current_platform.is_gcu(): + from fastdeploy.model_executor.ops.gcu import fused_add_rms_norm, rms_norm +else: + from paddle.incubate.nn.functional import fused_layer_norm, fused_rms_norm + from fastdeploy.config import FDConfig from .utils import get_tensor @@ -69,7 +74,10 @@ class RMSNorm(nn.Layer): self.weight_key: Optional[str] = f"{prefix}.weight" self.with_weight: bool = self.weight_key is not None self.eps: float = eps - self.norm_func: Callable = fused_rms_norm + if current_platform.is_gcu(): + self.norm_func: Callable = fused_add_rms_norm + else: + self.norm_func: Callable = fused_rms_norm self.linear_bias: Optional[paddle.Tensor] = linear_bias self.quant_scale: Optional[float] = quant_scale self._dtype: str = self._helper.get_default_dtype() @@ -129,19 +137,26 @@ class RMSNorm(nn.Layer): The `residual_output` is the result of applying the normalization and possibly other operations (like linear transformation) on the `residual_input`. """ - norm_out = self.norm_func( - x, - norm_weight=self.ln_weight, - norm_bias=None, - epsilon=self.eps, - begin_norm_axis=self.begin_norm_axis, - bias=self.linear_bias, - residual=residual_input, - quant_scale=-1 if self.quant_scale is None else self.quant_scale, - quant_round_type=self.quant_round_type, - quant_max_bound=self.quant_max_bound, - quant_min_bound=self.quant_min_bound, - ) + if current_platform.is_gcu(): + if residual_input is None: + return rms_norm(x, self.ln_weight, self.eps) + norm_out = self.norm_func( + x, residual_input, self.ln_weight, self.eps + ) + else: + norm_out = self.norm_func( + x, + norm_weight=self.ln_weight, + norm_bias=None, + epsilon=self.eps, + begin_norm_axis=self.begin_norm_axis, + bias=self.linear_bias, + residual=residual_input, + quant_scale=-1 if self.quant_scale is None else self.quant_scale, + quant_round_type=self.quant_round_type, + quant_max_bound=self.quant_max_bound, + quant_min_bound=self.quant_min_bound, + ) if residual_input is not None: return norm_out[0], norm_out[1] else: @@ -193,7 +208,10 @@ class LayerNorm(nn.Layer): self.with_bias: bool = with_bias self.eps: float = eps self.quant_scale: float = quant_scale - self.norm_func: Callable = fused_layer_norm + if current_platform.is_gcu(): + self.norm_func: Callable = paddle.nn.functional.layer_norm + else: + self.norm_func: Callable = fused_layer_norm self.linear_bias: Optional[paddle.Tensor] = linear_bias self._dtype: str = self._helper.get_default_dtype() self._norm_weight_dtype: str = "float32" @@ -279,19 +297,40 @@ class LayerNorm(nn.Layer): else: raise NotImplementedError("Iluvatar does not support yet!") - norm_out = self.norm_func( - x, - norm_weight=self.ln_weight, - norm_bias=self.ln_bias, - epsilon=self.eps, - begin_norm_axis=1, - bias=self.linear_bias, - residual=residual_input, - quant_scale=-1 if self.quant_scale is None else self.quant_scale, - quant_round_type=self.quant_round_type, - quant_max_bound=self.quant_max_bound, - quant_min_bound=self.quant_min_bound, - ) + if current_platform.is_gcu(): + if residual_input is not None: + y = x + residual_input + out = self.norm_func( + x=y, + normalized_shape=y.shape[1:], + weight=self.ln_weight, + bias=self.linear_bias, + epsilon=self.eps, + ) + return out, y + else: + out = self.norm_func( + x=x, + normalized_shape=x.shape[1:], + weight=self.ln_weight, + bias=self.linear_bias, + epsilon=self.eps, + ) + return out + else: + norm_out = self.norm_func( + x, + norm_weight=self.ln_weight, + norm_bias=self.ln_bias, + 
epsilon=self.eps, + begin_norm_axis=1, + bias=self.linear_bias, + residual=residual_input, + quant_scale=-1 if self.quant_scale is None else self.quant_scale, + quant_round_type=self.quant_round_type, + quant_max_bound=self.quant_max_bound, + quant_min_bound=self.quant_min_bound, + ) if residual_input is not None: return norm_out[0], norm_out[1] else: diff --git a/fastdeploy/model_executor/layers/quantization/weight_only.py b/fastdeploy/model_executor/layers/quantization/weight_only.py index 9e890853b..720061f85 100644 --- a/fastdeploy/model_executor/layers/quantization/weight_only.py +++ b/fastdeploy/model_executor/layers/quantization/weight_only.py @@ -66,6 +66,13 @@ class WeightOnlyConfig(QuantConfigBase): return XPUWeightOnlyMoEMethod(self) else: return XPUWeightOnlyLinearMethod(self) + elif current_platform.is_gcu(): + from fastdeploy.model_executor.layers.backends import ( + GCUWeightOnlyLinearMethod, GCUWeightOnlyMoEMethod) + if isinstance(layer, FusedMoE): + return GCUWeightOnlyMoEMethod(self) + else: + return GCUWeightOnlyLinearMethod(self) else: if isinstance(layer, FusedMoE): if layer.use_method == "cutlass": diff --git a/fastdeploy/model_executor/layers/rotary_embedding.py b/fastdeploy/model_executor/layers/rotary_embedding.py index 17c7dffc1..3266d1097 100644 --- a/fastdeploy/model_executor/layers/rotary_embedding.py +++ b/fastdeploy/model_executor/layers/rotary_embedding.py @@ -55,6 +55,10 @@ class ErnieRotaryEmbedding: dtype="float32") emb = paddle.stack([freqs, freqs], axis=-1).reshape( (bsz, max_seq_len, self.rotary_dim)) + elif current_platform.is_gcu(): + # shape: [B, S, D] + rot_emb = paddle.concat([freqs.cos(), freqs.sin()], axis=-1) + return rot_emb else: # shape: [B, S, D/2] rot_emb = paddle.zeros( @@ -95,6 +99,10 @@ class QwenRotaryEmbedding: # shape: [B, S, D/2] freqs = paddle.einsum("ij,k->ijk", position_ids.cast("float32"), inv_freq) + if current_platform.is_gcu(): + # shape: [B, S, D] + rot_emb = paddle.concat([freqs.cos(), freqs.sin()], axis=-1) + return rot_emb # shape: [B, S, 1, D] emb = paddle.concat([freqs, freqs], axis=-1).reshape( (bsz, max_seq_len, 1, self.rotary_dim)) diff --git a/fastdeploy/model_executor/layers/sample/ops/apply_penalty_multi_scores.py b/fastdeploy/model_executor/layers/sample/ops/apply_penalty_multi_scores.py index 2e21a85bc..f6b512e0c 100644 --- a/fastdeploy/model_executor/layers/sample/ops/apply_penalty_multi_scores.py +++ b/fastdeploy/model_executor/layers/sample/ops/apply_penalty_multi_scores.py @@ -79,6 +79,21 @@ def apply_penalty_multi_scores( min_dec_lens, eos_token_ids, ) + elif current_platform.is_gcu(): + from fastdeploy.model_executor.ops.gcu import \ + get_token_penalty_multi_scores + logits = get_token_penalty_multi_scores( + pre_token_ids, + logits, + repetition_penalties, + frequency_penalties, + presence_penalties, + temperature, + bad_words_token_ids, + step_idx, + min_dec_lens, + eos_token_ids, + ) else: raise NotImplementedError() diff --git a/fastdeploy/model_executor/layers/sample/ops/top_p_sampling.py b/fastdeploy/model_executor/layers/sample/ops/top_p_sampling.py index e8b9a894e..eeebb610b 100644 --- a/fastdeploy/model_executor/layers/sample/ops/top_p_sampling.py +++ b/fastdeploy/model_executor/layers/sample/ops/top_p_sampling.py @@ -19,7 +19,11 @@ from typing import Literal, Optional import paddle from fastdeploy import envs +from fastdeploy.platforms import current_platform +if current_platform.is_gcu(): + from fastdeploy.model_executor.ops.gcu import \ + top_p_sampling as gcu_top_p_sampling def 
top_p_sampling( x: paddle.Tensor, @@ -46,13 +50,16 @@ def top_p_sampling( ids = rejection_top_p_sampling(x, ps, seed) _ = None else: - _, ids = paddle.tensor.top_p_sampling(x, - ps, - threshold=threshold, - topp_seed=topp_seed, - seed=seed, - k=k, - mode=mode) + if current_platform.is_gcu(): + _, ids = gcu_top_p_sampling(x, ps) + else: + _, ids = paddle.tensor.top_p_sampling(x, + ps, + threshold=threshold, + topp_seed=topp_seed, + seed=seed, + k=k, + mode=mode) return _, ids diff --git a/fastdeploy/model_executor/layers/sample/sampler.py b/fastdeploy/model_executor/layers/sample/sampler.py index 217776861..3d2553446 100644 --- a/fastdeploy/model_executor/layers/sample/sampler.py +++ b/fastdeploy/model_executor/layers/sample/sampler.py @@ -171,7 +171,7 @@ class Sampler(nn.Layer): """ super().__init__() if current_platform.is_cuda() or current_platform.is_xpu( - ) or current_platform.is_iluvatar(): + ) or current_platform.is_iluvatar() or current_platform.is_gcu(): self.forward = self.forward_cuda else: raise NotImplementedError() diff --git a/fastdeploy/model_executor/ops/__init__.py b/fastdeploy/model_executor/ops/__init__.py index ebd011e95..508e8707a 100644 --- a/fastdeploy/model_executor/ops/__init__.py +++ b/fastdeploy/model_executor/ops/__init__.py @@ -17,5 +17,6 @@ from . import cpu from . import xpu from . import npu from . import iluvatar +from . import gcu -__all__ = ["gpu", "cpu", "xpu", "npu", "iluvatar"] +__all__ = ["gpu", "cpu", "xpu", "npu", "iluvatar", "gcu"] diff --git a/fastdeploy/model_executor/ops/gcu/__init__.py b/fastdeploy/model_executor/ops/gcu/__init__.py new file mode 100644 index 000000000..04dab4c85 --- /dev/null +++ b/fastdeploy/model_executor/ops/gcu/__init__.py @@ -0,0 +1,116 @@ +# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +""" fastdeploy gcu ops """ +from fastdeploy.platforms import current_platform + +from fastdeploy.import_ops import import_custom_ops, rename_imported_op + +PACKAGE = "fastdeploy.model_executor.ops.gcu" + +import_custom_ops(PACKAGE, ".fastdeploy_ops", globals()) + +if current_platform.is_gcu(): + from paddle_custom_device.gcu.ops import (invoke_fused_moe_kernel, # noqa: F401,E402 + moe_align_block_size, top_p_sampling, # noqa: F401 + topk_softmax, # noqa: F401 + weight_quantize_custom_rtn, # noqa: F401 + weight_quantize_rtn) # noqa: F401 + +# ###################### Ops from PaddleCustomDevice #################### +rename_imported_op( + old_name="fused_rotary_embedding_gcu", + new_name="fused_rotary_embedding", + global_ns=globals(), +) + +rename_imported_op( + old_name="reshape_and_cache_gcu", + new_name="reshape_and_cache", + global_ns=globals(), +) + +rename_imported_op( + old_name="paged_attention_gcu", + new_name="paged_attention", + global_ns=globals(), +) + +rename_imported_op( + old_name="mem_efficient_attention_gcu", + new_name="mem_efficient_attention", + global_ns=globals(), +) + +rename_imported_op( + old_name="flash_attn_var_len_gcu", + new_name="flash_attn_var_len", + global_ns=globals(), +) + +rename_imported_op( + old_name="rms_norm_gcu", + new_name="rms_norm", + global_ns=globals(), +) + +rename_imported_op( + old_name="fused_add_rms_norm_op", + new_name="fused_add_rms_norm", + global_ns=globals(), +) + +rename_imported_op( + old_name="linear_quant_gcu", + new_name="linear_quant", + global_ns=globals(), +) + + +# ###################### CPU OPS #################### +rename_imported_op( + old_name="get_padding_offset_gcu", + new_name="get_padding_offset", + global_ns=globals(), +) + +rename_imported_op( + old_name="update_inputs_gcu", + new_name="update_inputs", + global_ns=globals(), +) + +rename_imported_op( + old_name="rebuild_padding_gcu", + new_name="rebuild_padding", + global_ns=globals(), +) + +rename_imported_op( + old_name="get_token_penalty_multi_scores_gcu", + new_name="get_token_penalty_multi_scores", + global_ns=globals(), +) + +rename_imported_op( + old_name="set_stop_value_multi_ends_gcu", + new_name="set_stop_value_multi_ends", + global_ns=globals(), +) + +rename_imported_op( + old_name="set_value_by_flags_and_idx_gcu", + new_name="set_value_by_flags_and_idx", + global_ns=globals(), +) diff --git a/fastdeploy/model_executor/pre_and_post_process.py b/fastdeploy/model_executor/pre_and_post_process.py index 526197f2a..387ee6884 100644 --- a/fastdeploy/model_executor/pre_and_post_process.py +++ b/fastdeploy/model_executor/pre_and_post_process.py @@ -24,6 +24,11 @@ if current_platform.is_iluvatar(): from fastdeploy.model_executor.ops.iluvatar import ( get_padding_offset, save_output, set_stop_value_multi_ends, step_paddle, update_inputs) +elif current_platform.is_gcu(): + from fastdeploy.model_executor.ops.gcu import (get_padding_offset, + save_output, + set_stop_value_multi_ends, + update_inputs) else: from fastdeploy.model_executor.ops.gpu import ( get_padding_offset, save_output, set_stop_value_multi_ends, @@ -391,6 +396,17 @@ def rebuild_padding(tmp_out: paddle.Tensor, output_padding_offset, max_input_length, ) + elif current_platform.is_gcu(): + from fastdeploy.model_executor.ops.gcu import rebuild_padding + hidden_states = rebuild_padding( + tmp_out, + cum_offsets, + seq_len_this_time, + seq_lens_decoder, + seq_lens_encoder, + output_padding_offset, + max_input_length, + ) elif current_platform.is_cpu(): from fastdeploy.model_executor.ops.cpu import 
rebuild_padding_cpu hidden_states = rebuild_padding_cpu( diff --git a/fastdeploy/output/token_processor.py b/fastdeploy/output/token_processor.py index 136197f9c..ad7db57a6 100644 --- a/fastdeploy/output/token_processor.py +++ b/fastdeploy/output/token_processor.py @@ -124,6 +124,8 @@ class TokenProcessor(object): from fastdeploy.model_executor.ops.xpu import get_output elif current_platform.is_iluvatar(): from fastdeploy.model_executor.ops.iluvatar import get_output + elif current_platform.is_gcu(): + from fastdeploy.model_executor.ops.gcu import get_output else: from fastdeploy.model_executor.ops.gpu import (get_output, get_output_ep, diff --git a/fastdeploy/platforms/__init__.py b/fastdeploy/platforms/__init__.py index 5fbbc0d89..cdead0141 100644 --- a/fastdeploy/platforms/__init__.py +++ b/fastdeploy/platforms/__init__.py @@ -22,6 +22,7 @@ from .xpu import XPUPlatform from .npu import NPUPlatform from .dcu import DCUPlatform from .iluvatar import IluvatarPlatform +from .gcu import GCUPlatform from .base import _Backend # noqa: F401 _current_platform = None @@ -42,6 +43,8 @@ def __getattr__(name: str): _current_platform = DCUPlatform() elif paddle.is_compiled_with_custom_device("iluvatar_gpu"): _current_platform = IluvatarPlatform() + elif paddle.is_compiled_with_custom_device("gcu"): + _current_platform = GCUPlatform() else: _current_platform = CPUPlatform() return _current_platform diff --git a/fastdeploy/platforms/base.py b/fastdeploy/platforms/base.py index 6d93893fa..543f9284f 100644 --- a/fastdeploy/platforms/base.py +++ b/fastdeploy/platforms/base.py @@ -69,6 +69,12 @@ class Platform: """ return paddle.is_compiled_with_custom_device("iluvatar_gpu") + def is_gcu(self) -> bool: + """ + whether platform is gcu + """ + return paddle.is_compiled_with_custom_device("gcu") + @classmethod def get_attention_backend_cls(self, selected_backend): """Get the attention backend""" diff --git a/fastdeploy/platforms/gcu.py b/fastdeploy/platforms/gcu.py new file mode 100644 index 000000000..42b55d641 --- /dev/null +++ b/fastdeploy/platforms/gcu.py @@ -0,0 +1,61 @@ +""" +# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" + +import paddle + +from fastdeploy.utils import console_logger as logger + +from .base import Platform, _Backend + + +class GCUPlatform(Platform): + """ + gcu platform class + """ + device_name = "gcu" + + @classmethod + def available(self): + """ + Check whether GCU is available. + """ + try: + assert paddle.base.core.get_custom_device_count('gcu') > 0 + return True + except Exception as e: + logger.warning( + "You are using GCUPlatform, but there is no GCU " + "detected on your machine. Maybe GCU devices is not set properly." 
+ f"\n Original Error is {e}" + ) + return False + + @classmethod + def get_attention_backend_cls(cls, selected_backend: _Backend): + """ + get_attention_backend_cls + """ + if selected_backend == _Backend.NATIVE_ATTN: + logger.info("Using GCU mem_efficient ATTN backend.") + return ("fastdeploy.model_executor.layers.backends.gcu.attention.mem_efficient_attn_backend.GCUMemEfficientAttnBackend") + elif selected_backend == _Backend.APPEND_ATTN: + logger.info("Using GCU ATTN backend.") + return ("fastdeploy.model_executor.layers.backends.gcu.attention.flash_attn_backend.GCUFlashAttnBackend") + else: + raise ValueError( + "Invalid attention backend you specified.\n" + "Now only support [NATIVE_ATTN, APPEND_ATTN] in gcu place." + ) diff --git a/fastdeploy/worker/gcu_model_runner.py b/fastdeploy/worker/gcu_model_runner.py new file mode 100644 index 000000000..5dd8cef1b --- /dev/null +++ b/fastdeploy/worker/gcu_model_runner.py @@ -0,0 +1,1186 @@ +""" +# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License" +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +import os +import time +from typing import List, Optional + +import numpy as np +import paddle +import paddle.nn as nn +from paddleformers.utils.log import logger + +from fastdeploy.config import FDConfig +from fastdeploy.engine.request import Request +from fastdeploy.model_executor.guided_decoding import get_guided_backend +from fastdeploy.model_executor.guided_decoding.base_guided_decoding import \ + LogitsProcessorBase +from fastdeploy.model_executor.layers.attention import get_attention_backend +from fastdeploy.model_executor.layers.attention.base_attention_backend import \ + AttentionBackend +from fastdeploy.model_executor.layers.rotary_embedding import get_rope +from fastdeploy.model_executor.layers.sample.meta_data import SamplingMetadata +from fastdeploy.model_executor.layers.sample.sampler import ( + Sampler, SpeculativeSampler) +from fastdeploy.model_executor.model_loader import get_model_from_loader +from fastdeploy.model_executor.ops.gcu import set_value_by_flags_and_idx +from fastdeploy.model_executor.pre_and_post_process import (post_process, + pre_process, + rebuild_padding) +from fastdeploy.worker.forward_meta import ForwardMeta +from fastdeploy.worker.model_runner_base import ModelRunnerBase +from fastdeploy.worker.output import ModelOutputData, ModelRunnerOutput + + +class GCUModelRunner(ModelRunnerBase): + """ """ + + def __init__( + self, + fd_config: FDConfig, + device: str, # logic device + device_id: int, # physical device id + rank: int, + local_rank: int): + super().__init__(fd_config=fd_config, device=device) + self.rank = rank + self.local_rank = local_rank + self.device_id = device_id + self.speculative_method = self.fd_config.speculative_config.method + self.speculative_decoding = self.speculative_method is not None + + self.guided_backend = None + if self.fd_config.parallel_config.guided_decoding_backend != "off": + self.guided_backend = get_guided_backend(fd_config=self.fd_config) + + # Sampler + if not 
self.speculative_decoding: + self.sampler = Sampler() + else: + self.sampler = SpeculativeSampler(fd_config) + + # Cuda Graph + self.use_cudagraph = self.graph_opt_config.use_cudagraph + self.cudagraph_capture_sizes = list( + reversed(self.graph_opt_config.cudagraph_capture_sizes)) + self.cudagraph_num_of_warmups = self.graph_opt_config.cudagraph_num_of_warmups + self.input_ids = paddle.zeros(self.parallel_config.max_num_seqs, + dtype='int32') + + # Initialize share inputs + self._init_share_inputs(self.parallel_config.max_num_seqs) + self.infer_seed_increment = paddle.full( + shape=[self.parallel_config.max_num_seqs, 1], + fill_value=4, + dtype="int64") + self.restore_chunked_prefill_request = dict() + + # Initialize attention Backend + self.attn_backends: list[AttentionBackend] = [] + # self.attn_metadatas: list[AttentionMetadata] = [] + self.initialize_attn_backend() + + # Forward meta store the global meta information of the forward + self.forward_meta: ForwardMeta = None + + # Postprocess Env params + os.environ["INFERENCE_MSG_QUEUE_ID"] = str( + self.local_rank + + int(self.parallel_config.engine_worker_queue_port)) + + def prefill_finished(self): + """ + check whether prefill stage finished + """ + if int(paddle.max(self.share_inputs['seq_lens_encoder'])) != 0: + return 1 + else: + return 0 + + def init_speculative_proposer(self): + """ + Init speculative proposer + """ + if self.speculative_method == "ngram": + raise NotImplementedError( + "NgramProposer is not support by GCUModelRunner." + ) + elif self.speculative_method == "mtp": + raise NotImplementedError( + "MTPProposer is not support by GCUModelRunner." + ) + else: + self.proposer = None + + def _init_logits_processor(self, request): + """ + init logits processor for guided decoding + """ + assert self.guided_backend is not None, "guided_backend is None, use "\ + "--guided-decoding-backend to specify the backend at server startup." 
+ + if request.guided_json is not None: + schemata_key = ("json", request.guided_json) + elif request.guided_regex is not None: + schemata_key = ("regex", request.guided_regex) + elif request.guided_grammar is not None: + schemata_key = ("grammar", request.guided_grammar) + elif request.structural_tag is not None: + schemata_key = ("structural_tag", request.structural_tag) + + return self.guided_backend.get_logits_processor( + schemata_key=schemata_key), schemata_key + + def insert_prefill_inputs(self, req_dicts: List[Request]): + """ + Process inputs for prefill tasks and insert it to share_inputs buffer + """ + if "caches" not in self.share_inputs: + self.initialize_kv_cache() + + if req_dicts[-1].disaggregate_info is not None and req_dicts[ + -1].disaggregate_info["role"] == "prefill": + os.environ['PREFILL_NODE_ONE_STEP_STOP'] = "1" + + req_len = len(req_dicts) + for i in range(req_len): + request = req_dicts[i] + idx = request.idx + length = len(request.prompt_token_ids) + + prefill_tokens = [] + if (request.guided_json is not None + or request.guided_regex is not None + or request.structural_tag is not None + or request.guided_grammar is not None): + logits_info, schemata_key = self._init_logits_processor( + request) + request.logits_processor, request.logits_cached = logits_info + request.schemata_key = schemata_key + + # Is Decode Node + if req_dicts[i].disaggregate_info is not None and req_dicts[ + i].disaggregate_info["role"] == "decode": + prefill_tokens.append(request.prompt_token_ids[0]) + self.share_inputs["pre_ids"][idx:idx + + 1] = request.prompt_token_ids[-1] + self.share_inputs["input_ids"][idx:idx + 1, + 0] = request.prompt_token_ids[0] + self.share_inputs['seq_lens_encoder'][idx:idx + 1] = 0 + self.share_inputs['seq_lens_decoder'][idx:idx + 1] = length + self.share_inputs['seq_lens_this_time'][idx:idx + 1] = 1 + self.share_inputs['step_seq_lens_encoder'][idx:idx + 1] = 0 + self.share_inputs['step_seq_lens_decoder'][idx:idx + + 1] = length + self.share_inputs['step_idx'][idx:idx + 1] = 1 + + if self.speculative_decoding: + num_prefill_send_token = self.speculative_config.num_speculative_tokens + 1 + self.share_inputs['draft_tokens'][idx:idx + 1, 0:num_prefill_send_token] =\ + paddle.to_tensor(request.draft_token_ids[0:num_prefill_send_token], dtype="int64") + self.share_inputs['seq_lens_this_time'][ + idx:idx + 1] = num_prefill_send_token + else: + self.share_inputs["pre_ids"][idx:idx + 1] = -1 + self.share_inputs["step_idx"][idx:idx + 1] = 0 + self.share_inputs["input_ids"][idx:idx + + 1, :length] = np.array( + request.prompt_token_ids) + + # Use chunked prefill + if self.parallel_config.enable_chunked_prefill: + request.set("chunk_idx", 1) + logger.info( + f"prefill_chunk_info: {request.prefill_chunk_info}") + token_chunk_size = request.prefill_chunk_info[0] + self.share_inputs["seq_lens_this_time"][ + idx:idx + 1] = token_chunk_size + self.share_inputs['input_ids'][ + idx, :token_chunk_size] = np.array( + request.prompt_token_ids[:token_chunk_size]) + self.share_inputs['step_seq_lens_encoder'][ + idx:idx + 1] = token_chunk_size + self.share_inputs['seq_lens_encoder'][idx:idx + + 1] = token_chunk_size + self.share_inputs['seq_lens_decoder'][ + idx:idx + 1] = request.get("seq_lens_decoder", 0) + self.share_inputs['step_seq_lens_decoder'][ + idx:idx + 1] = request.get("seq_lens_decoder", 0) + else: + self.share_inputs['seq_lens_decoder'][ + idx:idx + 1] = request.get("seq_lens_decoder", 0) + self.share_inputs['step_seq_lens_decoder'][ + idx:idx + 1] = 
request.get("seq_lens_decoder", 0) + self.share_inputs['seq_lens_this_time'][idx:idx + + 1] = length + self.share_inputs['step_seq_lens_encoder'][idx:idx + + 1] = length + self.share_inputs['seq_lens_encoder'][idx:idx + 1] = length + + if len(request.eos_token_ids + ) < self.parallel_config.eos_tokens_lens: + request.eos_token_ids.append(request.eos_token_ids[0]) + self.share_inputs["eos_token_id"][:] = np.array( + request.eos_token_ids, dtype="int64").reshape(-1, 1) + + self.share_inputs["top_p"][idx:idx + 1] = request.get("top_p", 0.7) + self.share_inputs["temperature"][idx:idx + 1] = request.get( + "temperature", 0.95) + self.share_inputs["penalty_score"][idx:idx + 1] = request.get( + "repetition_penalty", 1.0) + self.share_inputs["frequency_score"][idx:idx + 1] = request.get( + "frequency_penalty", 0.0) + self.share_inputs["presence_score"][idx:idx + 1] = request.get( + "presence_penalty", 0.0) + + self.share_inputs["min_dec_len"][idx:idx + 1] = request.get( + "min_tokens", 1) + self.share_inputs["max_dec_len"][idx:idx + 1] = request.get( + "max_tokens", self.model_config.max_length) + self.share_inputs["stop_flags"][idx:idx + 1] = False + + self.share_inputs["first_token_ids"][ + idx:idx + 1] = self.share_inputs["input_ids"][idx:idx + 1, :1] + self.share_inputs["ori_seq_lens_encoder"][idx:idx + 1] = length + + if request.get("seed") is not None: + self.share_inputs["infer_seed"][idx:idx + + 1] = request.get("seed") + encoder_block_num = len(request.get("block_tables")) + self.share_inputs["encoder_block_lens"][idx:idx + + 1] = encoder_block_num + self.share_inputs["block_tables"][idx:idx + 1, :] = -1 + self.share_inputs["block_tables"][ + idx:idx + 1, :encoder_block_num] = np.array( + request.block_tables, dtype="int32") + + if request.get("stop_token_ids") is not None and request.get( + "stop_seqs_len") is not None: + stop_seqs_num = len(request.get("stop_seqs_len")) + for i in range(stop_seqs_num, + self.model_config.max_stop_seqs_num): + request.stop_seqs_len.append(0) + self.share_inputs["stop_seqs_len"][:] = np.array( + request.stop_seqs_len, dtype="int32") + self.share_inputs["stop_seqs"][:stop_seqs_num, :len( + request.get("stop_token_ids")[0])] = np.array( + request.get("stop_token_ids"), dtype="int64") + + self.sampler.apply_logits_processor( + idx, request.get("logits_processor"), prefill_tokens) + + self.share_inputs["not_need_stop"][0] = True + + if self.speculative_method in ["mtp"]: + self.proposer.insert_prefill_inputs(req_dicts) + + def _dummy_prefill_inputs(self, num_tokens: int, batch_size: int, + expected_decode_len: int): + """ Set dummy prefill inputs to share_inputs """ + max_dec_len = expected_decode_len + 1 + full_length = min(num_tokens // batch_size, + self.parallel_config.max_model_len - max_dec_len) + input_length = int(full_length * self.parallel_config.kv_cache_ratio) + block_num = ( + input_length + self.parallel_config.block_size - 1 + ) // self.parallel_config.block_size + self.parallel_config.enc_dec_block_num + + for i in range(batch_size): + idx = i + self.share_inputs["input_ids"][idx:idx + + 1, :input_length] = np.array( + [5] * input_length) + self.share_inputs["eos_token_id"][:] = np.array( + [2], dtype="int64").reshape(-1, 1) + self.share_inputs["seq_lens_this_time"][idx:idx + 1] = input_length + self.share_inputs["step_seq_lens_encoder"][idx:idx + + 1] = input_length + self.share_inputs["seq_lens_encoder"][idx:idx + 1] = input_length + self.share_inputs["seq_lens_decoder"][idx:idx + 1] = 0 + self.share_inputs["step_idx"][idx:idx + 1] = 0 + 
self.share_inputs["max_dec_len"][idx:idx + 1] = max_dec_len + self.share_inputs["stop_flags"][idx:idx + 1] = False + + self.share_inputs["first_token_ids"][ + idx:idx + 1] = self.share_inputs["input_ids"][idx:idx + 1, :1] + self.share_inputs["ori_seq_lens_encoder"][idx:idx + + 1] = input_length + + self.share_inputs["encoder_block_lens"][idx:idx + 1] = block_num + self.share_inputs["block_tables"][idx : idx + 1, :block_num] = np.arange(idx * block_num, \ + (idx + 1) * block_num, 1) + + def _init_share_inputs(self, max_num_seqs: int): + """Initialize all share buffers for model inputs. + Note: In the future, we may abandon share buffers. + """ + self.MAX_INFER_SEED = 9223372036854775806 + self.share_inputs = {} + + self.share_inputs["pre_ids"] = paddle.full( + [max_num_seqs, self.parallel_config.max_model_len], + -1, + dtype='int64') + self.share_inputs["input_ids"] = paddle.full( + [max_num_seqs, self.parallel_config.max_model_len], + self.parallel_config.pad_token_id, + dtype='int64') + self.share_inputs["eos_token_id"] = paddle.full( + [self.parallel_config.eos_tokens_lens, 1], 0, dtype='int64') + self.share_inputs["top_p"] = paddle.full([max_num_seqs, 1], + self.model_config.top_p, + dtype='float32') + self.share_inputs["temperature"] = paddle.full( + [max_num_seqs, 1], self.model_config.temperature, dtype='float32') + self.share_inputs["penalty_score"] = paddle.full( + [max_num_seqs, 1], + self.model_config.penalty_score, + dtype='float32') + self.share_inputs["frequency_score"] = paddle.full( + [max_num_seqs, 1], + self.model_config.frequency_score, + dtype='float32') + self.share_inputs["presence_score"] = paddle.full( + [max_num_seqs, 1], + self.model_config.presence_score, + dtype='float32') + + self.share_inputs["min_dec_len"] = paddle.full( + [max_num_seqs, 1], self.model_config.min_length, dtype='int64') + self.share_inputs["max_dec_len"] = paddle.full( + [max_num_seqs, 1], self.model_config.max_length, dtype='int64') + self.share_inputs["min_length"] = paddle.full( + [max_num_seqs, 1], self.model_config.min_length, dtype='int64') + self.share_inputs["max_length"] = paddle.full( + [max_num_seqs, 1], self.model_config.max_length, dtype='int64') + self.share_inputs["seq_lens_this_time"] = paddle.full(max_num_seqs, + 0, + dtype='int32') + self.share_inputs["seq_lens_encoder"] = paddle.full([max_num_seqs, 1], + 0, + dtype='int32') + self.share_inputs["seq_lens_decoder"] = paddle.full([max_num_seqs, 1], + 0, + dtype='int32') + self.share_inputs["step_seq_lens_encoder"] = paddle.full( + [max_num_seqs, 1], 0, dtype='int32') + self.share_inputs["step_seq_lens_decoder"] = paddle.full( + [max_num_seqs, 1], 0, dtype='int32') + self.share_inputs["step_idx"] = paddle.full([max_num_seqs, 1], + 0, + dtype='int64') + self.share_inputs["not_need_stop"] = paddle.full( + [1], False, + dtype='bool').cpu() + self.share_inputs["stop_flags"] = paddle.full([max_num_seqs, 1], + True, + dtype='bool') + self.share_inputs["stop_nums"] = paddle.full([1], + max_num_seqs, + dtype='int64') + + self.share_inputs["bad_tokens"] = paddle.full([1], -1, dtype='int64') + self.share_inputs["next_tokens"] = paddle.full([max_num_seqs, 1], + -1, + dtype='int64') + self.share_inputs["is_block_step"] = paddle.full([max_num_seqs], + False, + dtype='bool') + self.share_inputs["encoder_block_lens"] = paddle.full([max_num_seqs], + 0, + dtype='int32') + self.share_inputs["step_block_list"] = paddle.full([max_num_seqs], + -1, + dtype='int32') + self.share_inputs["step_lens"] = paddle.full([1], 0, dtype='int32') + 
self.share_inputs["recover_block_list"] = paddle.full([max_num_seqs], + -1, + dtype='int32') + self.share_inputs["recover_lens"] = paddle.full([1], 0, dtype='int32') + self.share_inputs["need_block_list"] = paddle.full([max_num_seqs], + -1, + dtype='int32') + self.share_inputs["need_block_len"] = paddle.full([1], + 0, + dtype='int32') + self.share_inputs["used_list_len"] = paddle.full([max_num_seqs], + 0, + dtype='int32') + self.share_inputs["infer_seed"] = paddle.full([max_num_seqs, 1], + 0, + dtype='int64') + self.share_inputs["first_token_ids"] = paddle.full([max_num_seqs, 1], + -1, + dtype='int64') + self.share_inputs["ori_seq_lens_encoder"] = paddle.full( + [max_num_seqs, 1], 0, dtype='int32') + self.share_inputs["system_lens"] = paddle.full([max_num_seqs, 1], + 0, + dtype='int32') + self.share_inputs["system_ids"] = paddle.full([max_num_seqs, 1], + -1, + dtype='int32') + + self.share_inputs["ids_remove_padding"] = paddle.full( + [max_num_seqs * self.parallel_config.max_model_len], + 0, + dtype='int64') + self.share_inputs["cum_offsets"] = paddle.full([max_num_seqs, 1], + 0, + dtype='int32') + self.share_inputs["padding_offset"] = paddle.full([max_num_seqs, 1], + 0, + dtype='int32') + self.share_inputs["cu_seqlens_q"] = paddle.full([max_num_seqs, 1], + 0, + dtype='int32') + self.share_inputs["cu_seqlens_k"] = paddle.full([max_num_seqs, 1], + 0, + dtype='int32') + # AttentionBackend buffers + self.share_inputs["decoder_batch_ids"] = paddle.full([max_num_seqs, 1], + 0, + dtype='int32') + self.share_inputs["decoder_tile_ids_per_batch"] = paddle.full( + [max_num_seqs, 1], 0, dtype='int32') + + # Initialize rotary position embedding + tmp_position_ids = paddle.arange( + self.parallel_config.max_model_len).reshape((1, -1)) + self.share_inputs["rope_emb"] = get_rope( + rotary_dim=self.model_config.head_dim, + position_ids=tmp_position_ids, + base=self.model_config.rope_theta, + model_config=self.model_config) + + # Set block tables + pre_max_block_num = ( + self.parallel_config.max_model_len + + self.parallel_config.block_size - 1 + ) // self.parallel_config.block_size + self.parallel_config.enc_dec_block_num + self.share_inputs["block_tables"] = paddle.full( + [max_num_seqs, pre_max_block_num], -1, dtype='int32') + + # Initialize free list + free_list = list( + range( + self.parallel_config.max_block_num - 1, + int(self.parallel_config.max_block_num * + self.parallel_config.kv_cache_ratio) - 1, -1)) + self.free_list_len = len(free_list) + self.share_inputs["free_list"] = paddle.to_tensor(free_list, + dtype="int32") + self.share_inputs["free_list_len"] = paddle.full([1], + self.free_list_len, + dtype="int32") + + # Initialize stop seqs + self.share_inputs["stop_seqs_len"] = paddle.full( + [self.model_config.max_stop_seqs_num], 0, dtype="int32") + self.share_inputs["stop_seqs"] = paddle.full([ + self.model_config.max_stop_seqs_num, + self.model_config.stop_seqs_max_len + ], + -1, + dtype="int32") + if self.speculative_decoding: + max_draft_token_num = self.speculative_config.num_speculative_tokens + self.share_inputs["input_ids_cpu"] = paddle.full( + shape=[max_num_seqs, self.parallel_config.max_model_len], + fill_value=1, + dtype='int64').cpu() + self.share_inputs['accept_tokens'] = paddle.full( + shape=[max_num_seqs, max_draft_token_num + 1], + fill_value=0, + dtype="int64") + self.share_inputs['accept_num'] = paddle.full(shape=[max_num_seqs], + fill_value=0, + dtype='int32') + self.share_inputs['draft_tokens'] = paddle.full( + shape=[max_num_seqs, max_draft_token_num + 1], + fill_value=0, 
+ dtype="int64") + + self.share_inputs['actual_draft_token_num'] = paddle.full( + shape=[max_num_seqs], + fill_value=max_draft_token_num, + dtype="int32") + self.share_inputs["output_cum_offsets"] = paddle.full( + shape=[max_num_seqs, 1], fill_value=0, dtype='int32') + self.share_inputs["output_padding_offset"] = paddle.full( + shape=[max_num_seqs * (max_draft_token_num + 1)], + fill_value=0, + dtype="int32") + + def _prepare_inputs(self) -> None: + """ prepare the model inputs """ + # Remove padding + ( + ids_remove_padding, + cum_offsets, + padding_offset, + cu_seqlens_q, + cu_seqlens_k, + output_cum_offsets, + output_padding_offset, + ) = pre_process( + self.parallel_config.max_model_len, self.share_inputs["input_ids"], + self.share_inputs["seq_lens_this_time"], self.speculative_decoding, + self.share_inputs["draft_tokens"] if self.speculative_decoding else + None, self.share_inputs["seq_lens_encoder"], + self.share_inputs["seq_lens_decoder"]) + + self.share_inputs["ids_remove_padding"].copy_(ids_remove_padding, + False) + self.share_inputs["cum_offsets"].copy_(cum_offsets, False) + self.share_inputs["padding_offset"].copy_(padding_offset, False) + self.share_inputs["cu_seqlens_q"].copy_(cu_seqlens_q, False) + self.share_inputs["cu_seqlens_k"].copy_(cu_seqlens_k, False) + + # For speculative decoding + if self.speculative_decoding: + self.share_inputs["output_cum_offsets"].copy_( + output_cum_offsets, False) + self.share_inputs["output_padding_offset"].copy_( + output_padding_offset, False) + + # Initialize forward meta data + self.initialize_forward_meta() + + # Get sampling metadata + self.sampling_metadata = SamplingMetadata( + temperature=self.share_inputs["temperature"], + top_p=self.share_inputs["top_p"], + step_idx=self.share_inputs["step_idx"], + pre_token_ids=self.share_inputs["pre_ids"], + frequency_penalties=self.share_inputs["frequency_score"], + presence_penalties=self.share_inputs["presence_score"], + repetition_penalties=self.share_inputs["penalty_score"], + min_dec_lens=self.share_inputs["min_dec_len"], + bad_words_token_ids=self.share_inputs["bad_tokens"], + eos_token_ids=self.share_inputs["eos_token_id"], + ) + + def load_model(self) -> None: + """ load or download model """ + logger.info( + f"Starting to load model {self.model_config.architectures[0]}") + time_before_load = time.perf_counter() + # 1. Load original model + self.model = get_model_from_loader(fd_config=self.fd_config) + # 1.1 Load RL dynamic model + if self.fd_config.load_config.dynamic_load_weight: + from fastdeploy.rl.dynamic_weight_manager import \ + DynamicWeightManager + self.dynamic_weight_manager = DynamicWeightManager(self.fd_config, self.model) + + # 2. Load lora model + + # 3. Load drafter model(for speculative decoding) + + time_after_load = time.perf_counter() + logger.info( + f"Model loading took {time_after_load - time_before_load} seconds") + + # 4. 
Init proposer for speculative method + self.init_speculative_proposer() + + def get_model(self) -> nn.Layer: + """ get current model """ + return self.model + + def initialize_forward_meta(self): + """ + Initialize forward meta and attention meta data + """ + # Initialize forward meta + self.forward_meta = ForwardMeta.init_forward_meta( + self.share_inputs, self.attn_backends[0]) + + # Initialzie attention meta data + for attn_backend in self.attn_backends: + attn_backend.init_attention_metadata(self.forward_meta) + + def clear_cache(self): + """Clear cached data from shared inputs and forward metadata.""" + self.share_inputs.pop("caches", None) + if self.forward_meta is not None: + self.forward_meta.clear_caches() + + def clear_parameters(self, pid): + """"dynamic model loader use to clear parameters use for RL""" + self.dynamic_weight_manager.clear_parameters(pid) + self.clear_cache() + self.dynamic_weight_manager._log_memory("dynamic weight manager clear all memory") + + def update_parameters(self, pid): + """"dynamic model loader use to update parameters use for RL""" + self.dynamic_weight_manager.update_parameters(pid) + self.initialize_kv_cache() + self.dynamic_weight_manager._log_memory("dynamic weight manager update all memory") + + def initialize_kv_cache(self) -> None: + """ + Initialize kv cache + """ + cache_kvs = {} + max_block_num = self.num_gcu_blocks + + # Get kv cache dtype + cache_type = self.parallel_config.dtype + + if (self.quant_config + and hasattr(self.quant_config, "kv_cache_quant_type") + and self.quant_config.kv_cache_quant_type is not None): + cache_type = 'uint8' + + # Get kv cache shape + kv_cache_shape = self.attn_backends[0].get_kv_cache_shape( + max_num_blocks=max_block_num) + # local_rank = self.local_rank % self.parallel_config.tensor_parallel_degree + + if not self.parallel_config.do_profile and ( + self.parallel_config.enable_prefix_caching \ + or self.parallel_config.splitwise_role != "mixed"): + raise NotImplementedError( + "prefix_caching is not support by GCUModelRunner." + ) + else: + for i in range(self.model_config.num_layers): + + cache_kvs["key_caches_{}".format(i)] = paddle.full( + shape=kv_cache_shape, + fill_value=0, + dtype=cache_type, + ) + cache_kvs["value_caches_{}".format(i)] = paddle.full( + shape=kv_cache_shape, + fill_value=0, + dtype=cache_type, + ) + self.share_inputs["caches"] = list(cache_kvs.values()) + for value in cache_kvs.values(): + del value + + def initialize_attn_backend(self) -> None: + """ + Initialize attention backends and forward metadata + """ + assert len(self.attn_backends) == 0 + + num_heads = self.model_config.num_attention_heads // self.parallel_config.tensor_parallel_degree + self.model_config.kv_num_heads = int( + self.model_config.num_key_value_heads + ) // self.parallel_config.tensor_parallel_degree + head_dim = self.model_config.head_dim + + # Get the attention backend + attn_cls = get_attention_backend() + attn_backend = attn_cls(self.fd_config, + kv_num_heads=self.model_config.kv_num_heads, + num_heads=num_heads, + head_dim=head_dim) + if attn_backend is None: + raise NotImplementedError( + "Attention backend which you specified is not supported, please set FD_ATTENTION_BACKEND correctly." + ) + self.attn_backends.append(attn_backend) + + def _dummy_run(self, + num_tokens: paddle.Tensor, + batch_size: paddle.Tensor, + expected_decode_len: int = 1, + in_capturing: bool = False) -> paddle.Tensor: + """ + Use dummy inputs to run before formal execution. 
+ Args: + num_tokens: + expected_decode_len: Expected number of tokens generated + """ + self._dummy_prefill_inputs(num_tokens=num_tokens, + batch_size=batch_size, + expected_decode_len=expected_decode_len) + if self.speculative_method in ["mtp"]: + self.proposer.dummy_prefill_inputs( + num_tokens=num_tokens, + batch_size=batch_size, + expected_decode_len=expected_decode_len) + while True: + + # 1. Compute real num_tokens + self._prepare_inputs() + + # 2. Initialize attention backend and forward meta data + + # 3. Prepare lora + + # 4. Run model + is_decode_batch = not ((self.share_inputs["seq_lens_this_time"] + > 1).sum() > 0) + self.forward_meta.step_use_cudagraph = is_decode_batch and in_capturing + self.forward_meta.is_decode_batch = is_decode_batch + model_output = self.model( + ids_remove_padding=self.share_inputs["ids_remove_padding"], + forward_meta=self.forward_meta) + + hiddden_states = rebuild_padding( + model_output, + self.share_inputs["cum_offsets"], + self.share_inputs["seq_lens_this_time"], + self.share_inputs["seq_lens_decoder"], + self.share_inputs["seq_lens_encoder"], + self.share_inputs["output_padding_offset"] + if self.speculative_decoding else + None, # speculative decoding requires + self.parallel_config.max_model_len, + ) + + # 5. Execute spec decode + logits = self.model.compute_logits(hiddden_states) + + if not self.speculative_decoding: + set_value_by_flags_and_idx( + self.share_inputs["pre_ids"], + self.share_inputs["input_ids"], + self.share_inputs["seq_lens_this_time"], + self.share_inputs["seq_lens_encoder"], + self.share_inputs["seq_lens_decoder"], + self.share_inputs["step_idx"], + self.share_inputs["stop_flags"], + ) + sampled_token_ids = self.sampler(logits, + self.sampling_metadata) + if self.parallel_config.tensor_parallel_degree > 1: + paddle.distributed.broadcast(sampled_token_ids, 0) + else: + self.sampler(logits, self.sampling_metadata, + self.parallel_config.max_model_len, + self.share_inputs) + sampled_token_ids = None + if self.parallel_config.tensor_parallel_degree > 1: + paddle.distributed.broadcast( + self.share_inputs["accept_tokens"], 0) + paddle.distributed.broadcast( + self.share_inputs["accept_num"], 0) + paddle.distributed.broadcast(self.share_inputs["step_idx"], + 0) + paddle.distributed.broadcast( + self.share_inputs["stop_flags"], 0) + + # 6. post process + model_output_data = ModelOutputData( + next_tokens=self.share_inputs["next_tokens"], + stop_flags=self.share_inputs["stop_flags"], + step_idx=self.share_inputs["step_idx"], + max_dec_len=self.share_inputs["max_dec_len"], + pre_ids=self.share_inputs["pre_ids"], + seq_lens_this_time=self.share_inputs["seq_lens_this_time"], + eos_token_id=self.share_inputs["eos_token_id"], + not_need_stop=self.share_inputs["not_need_stop"], + input_ids=self.share_inputs["input_ids"], + stop_nums=self.share_inputs["stop_nums"], + seq_lens_encoder=self.share_inputs["seq_lens_encoder"], + seq_lens_decoder=self.share_inputs["seq_lens_decoder"], + is_block_step=self.share_inputs["is_block_step"], + full_hidden_states=model_output, + msg_queue_id=self.parallel_config.msg_queue_id, + mp_rank=self.local_rank, + use_ep=self.parallel_config.use_ep, + draft_tokens=self.share_inputs["draft_tokens"] + if self.speculative_decoding else None, + actual_draft_token_num=self. 
+ share_inputs["actual_draft_token_num"] + if self.speculative_decoding else None, + accept_tokens=self.share_inputs["accept_tokens"] + if self.speculative_decoding else None, + accept_num=self.share_inputs["accept_num"] + if self.speculative_decoding else None) + + post_process(sampled_token_ids=sampled_token_ids, + model_output=model_output_data, + speculative_decoding=self.speculative_decoding, + skip_save_output=True) + + if self.speculative_decoding: + if self.speculative_method == "mtp": + self.proposer.run(full_hidden_states=model_output) + else: + self.proposer.run(share_inputs=self.share_inputs) + + # 7. Updata 'infer_seed' and step_cuda() + self.share_inputs["infer_seed"].add_(self.infer_seed_increment) + self.share_inputs["infer_seed"][:] %= self.MAX_INFER_SEED + + if int((self.share_inputs['seq_lens_this_time'] > 0).sum()) == 0: + break + + def _update_chunked_prefill(self, tasks): + """ + 更新chunked prefill相关参数 + """ + if not self.parallel_config.enable_chunked_prefill: + return + + for task in tasks: + if task.get("prefill_chunk_info", None) is None: + continue + + if task.chunk_idx > len(task.prefill_chunk_info): + continue + self.restore_chunked_prefill_request[task.request_id] = task + + for id, task in list(self.restore_chunked_prefill_request.items()): + idx = task.idx + logger.debug( + f"{task.request_id} chunked prefill {task.chunk_idx}/{len(task.prefill_chunk_info)}" + ) + start_idx = sum(task.prefill_chunk_info[:task.chunk_idx]) + if task.chunk_idx == len(task.prefill_chunk_info): + self.share_inputs["seq_lens_this_time"][idx:idx + 1] = 1 + self.share_inputs['seq_lens_encoder'][idx:idx + 1] = 0 + self.share_inputs["step_idx"][idx:idx + 1] = 1 + self.share_inputs["seq_lens_decoder"][ + idx:idx + 1] = start_idx + task.get("seq_lens_decoder", 0) + del self.restore_chunked_prefill_request[task.request_id] + else: + token_chunk_size = task.prefill_chunk_info[task.chunk_idx] + + self.share_inputs["seq_lens_this_time"][idx:idx + + 1] = token_chunk_size + self.share_inputs['input_ids'][ + idx, :token_chunk_size] = np.array( + task.prompt_token_ids[start_idx:start_idx + + token_chunk_size]) + self.share_inputs['seq_lens_encoder'][idx:idx + + 1] = token_chunk_size + self.share_inputs["step_idx"][idx:idx + 1] = 0 + self.share_inputs["seq_lens_decoder"][ + idx:idx + 1] = start_idx + task.get("seq_lens_decoder", 0) + if self.speculative_decoding and self.proposer.is_chunk_prefill_enabled( + ): + self.proposer.update_task_chunk_prefill(task) + task.chunk_idx += 1 + + def _dummy_sampler_run(self) -> paddle.Tensor: + """ """ + pass + + def capture_model(self) -> None: + """ + Trigger CUDA Graph capture for all shapes in 'CudaGraphConfig.cudagraph_capture_sizes' + """ + if not self.use_cudagraph: + logger.info( + "Skipping CUDA graph capture. 
Please check GraphOptimizationConfig" + ) + return + time_before_capture = time.perf_counter() + expected_decode_len = 1 + capture_sizes = self.cudagraph_capture_sizes.copy() + for batch_size in sorted(capture_sizes, reverse=True): + self._dummy_run(num_tokens=self.parallel_config.max_model_len, + batch_size=batch_size, + in_capturing=True, + expected_decode_len=expected_decode_len) + logger.info( + f"Warm up the model with the batch size:{batch_size}, num tokens:{expected_decode_len}" + ) + + time_after_capture = time.perf_counter() + logger.info( + f"Cuda Graph capturing took {time_after_capture - time_before_capture} seconds" + ) + + def _get_skip_idx(self, model_forward_batch): + """ + Get the index of the request that needs to be skipped during execution. + Args: + model_forward_batch: A list of requests to be executed by this runner. + Returns: + A list of indices corresponding to the requests that need to be skipped. + """ + skip_idx_list = [] + if not self.parallel_config.enable_chunked_prefill or self.guided_backend is None: + return skip_idx_list + + for task in model_forward_batch: + if task.get("prefill_chunk_info", + None) is None or task.chunk_idx >= len( + task.prefill_chunk_info): + continue + skip_idx_list.append(task.idx) + + for task in self.restore_chunked_prefill_request.values(): + if task.idx in skip_idx_list or task.chunk_idx >= len( + task.prefill_chunk_info): + continue + skip_idx_list.append(task.idx) + + return skip_idx_list + + def execute_model( + self, + model_forward_batch: Optional[List[Request]] = None, + ) -> Optional[ModelRunnerOutput]: + """ + The Entrance of model execute. + Args: + model_forward_batch: 'Request' contains information related to prompt and is an abstract + class at the server level, which is too granular for ModelRunner. + We plan to replace it with 'ModelForwardBatch'. + intermediate_tensors: + """ + # If `not_need_stop`` is False, it means the current worker is in an idle state. + # This logic is not used in TP (Tensor Parallelism) mode. However, in EP (Expert Parallelism) mode, + # when there is data on other runner, the current runner is required to execute part of the model. + if not self.not_need_stop(): + self._execute_empty_input() + return None + + # 1. Prepare inputs of model and decoder. + # sampler create async operation + skip_idx_list = self._get_skip_idx(model_forward_batch) + self._prepare_inputs() + self.sampler.pre_process(skip_idx_list) + + # 2. Padding inputs for cuda grph + + # 3. Execute model + is_decode_batch = not ((self.share_inputs["seq_lens_this_time"] + > 1).sum() > 0) + self.forward_meta.step_use_cudagraph = self.use_cudagraph and is_decode_batch + self.forward_meta.is_decode_batch = is_decode_batch + model_output = self.model( + ids_remove_padding=self.share_inputs["ids_remove_padding"], + forward_meta=self.forward_meta) + + hiddden_states = rebuild_padding( + model_output, + self.share_inputs["cum_offsets"], + self.share_inputs["seq_lens_this_time"], + self.share_inputs["seq_lens_decoder"], + self.share_inputs["seq_lens_encoder"], + self.share_inputs["output_padding_offset"] + if self.speculative_decoding else None, + self.parallel_config.max_model_len, + ) + + # 4. 
Compute logits, Sample + logits = self.model.compute_logits(hiddden_states) + + if not self.speculative_decoding: + set_value_by_flags_and_idx( + self.share_inputs["pre_ids"], + self.share_inputs["input_ids"], + self.share_inputs["seq_lens_this_time"], + self.share_inputs["seq_lens_encoder"], + self.share_inputs["seq_lens_decoder"], + self.share_inputs["step_idx"], + self.share_inputs["stop_flags"], + ) + sampled_token_ids = self.sampler( + logits, + self.sampling_metadata, + skip_idx_list, + ) + if self.parallel_config.tensor_parallel_degree > 1: + paddle.distributed.broadcast(sampled_token_ids, 0) + + else: + self.sampler(logits, self.sampling_metadata, + self.parallel_config.max_model_len, self.share_inputs) + sampled_token_ids = None + if self.parallel_config.tensor_parallel_degree > 1: + paddle.distributed.broadcast( + self.share_inputs["accept_tokens"], 0) + paddle.distributed.broadcast(self.share_inputs["accept_num"], + 0) + paddle.distributed.broadcast(self.share_inputs["step_idx"], 0) + paddle.distributed.broadcast(self.share_inputs["stop_flags"], + 0) + + # 5. Post Process + model_output_data = ModelOutputData( + next_tokens=self.share_inputs["next_tokens"], + stop_flags=self.share_inputs["stop_flags"], + step_idx=self.share_inputs["step_idx"], + max_dec_len=self.share_inputs["max_dec_len"], + pre_ids=self.share_inputs["pre_ids"], + seq_lens_this_time=self.share_inputs["seq_lens_this_time"], + eos_token_id=self.share_inputs["eos_token_id"], + not_need_stop=self.share_inputs["not_need_stop"], + input_ids=self.share_inputs["input_ids"], + stop_nums=self.share_inputs["stop_nums"], + seq_lens_encoder=self.share_inputs["seq_lens_encoder"], + seq_lens_decoder=self.share_inputs["seq_lens_decoder"], + is_block_step=self.share_inputs["is_block_step"], + full_hidden_states=model_output, + msg_queue_id=self.parallel_config.msg_queue_id, + mp_rank=self.local_rank, + use_ep=self.parallel_config.use_ep, + draft_tokens=self.share_inputs["draft_tokens"] + if self.speculative_decoding else None, + actual_draft_token_num=self.share_inputs["actual_draft_token_num"] + if self.speculative_decoding else None, + accept_tokens=self.share_inputs["accept_tokens"] + if self.speculative_decoding else None, + accept_num=self.share_inputs["accept_num"] + if self.speculative_decoding else None) + + if self.speculative_config.method in ["mtp"] and \ + self.parallel_config.splitwise_role == "prefill": + skip_save_output = True + else: + skip_save_output = False + post_process(sampled_token_ids=sampled_token_ids, + model_output=model_output_data, + save_each_rank=self.parallel_config.use_ep, + speculative_decoding=self.speculative_decoding, + skip_save_output=skip_save_output) + + # 6. Speculative decode + if self.speculative_decoding: + if self.speculative_method == "mtp": + self.proposer.run(full_hidden_states=model_output) + else: + self.proposer.run(share_inputs=self.share_inputs) + + # 7. Updata 'infer_seed' and step_cuda() + self.share_inputs["infer_seed"].add_(self.infer_seed_increment) + self.share_inputs["infer_seed"][:] %= self.MAX_INFER_SEED + + self._update_chunked_prefill(model_forward_batch) + self._add_cache(model_forward_batch) + return None + + def _add_cache(self, model_forward_batch) -> None: + """ + Add cache for guided decoding. 
+ """ + if self.guided_backend is None: + return + + for request in model_forward_batch: + logits_cached = request.get("logits_cached", None) + if logits_cached is None or logits_cached: + continue + + request.logits_cached = True + if isinstance(request.logits_processor, LogitsProcessorBase): + self.guided_backend.add_cache(request.schemata_key, + request.logits_processor) + else: + self.guided_backend.add_cache( + request.schemata_key, request.logits_processor.result()) + + def _execute_empty_input(self) -> None: + """ + In certain scenarios, such as during EP, + the runner needs to execute partial modules of the model without input data. + This requires the model to implement the `empty_input_forward` method. + """ + if hasattr(self.model, "empty_input_forward"): + self.model.empty_input_forward() + else: + raise ValueError( + f"{type(self.model)} has no attribute 'empty_input_forward") + + def profile_run(self) -> None: + """Execute a forward pass with dummy inputs to profile the memory usage of the model.""" + + # Initialize kv cache for profile run. After profile run kv cache will be reset. + self.num_gcu_blocks = self.parallel_config.max_block_num + self.initialize_kv_cache() + + # 1. Profile with multimodal encoder & encoder cache + + # 2. Dummy run + self._dummy_run(num_tokens=self.parallel_config.max_num_batched_tokens, + batch_size=min(self.parallel_config.max_num_seqs, 3)) + + # 3. gc + self.clear_cache() + + if self.speculative_method in ["mtp"]: + self.proposer.clear_dummy_input() + # paddle.device.cuda.synchronize() + + def update_share_input_block_num(self, num_gpu_blocks: int) -> None: + """ + Set a globally unified block number and update the model's shared input. + Args: + num_gpu_blocks: + """ + self.num_gcu_blocks = num_gpu_blocks + + # Reset block table and kv cache with global block num + if not (self.parallel_config.enable_prefix_caching \ + or self.parallel_config.splitwise_role != "mixed"): + self.initialize_kv_cache() + + # Reset free list + free_list = list( + range( + self.num_gcu_blocks - 1, + int(self.num_gcu_blocks * self.parallel_config.kv_cache_ratio) + - 1, -1)) + self.free_list_len = len(free_list) + self.share_inputs.update({ + "free_list": + paddle.to_tensor(free_list, dtype="int32"), + "free_list_len": + paddle.full([1], self.free_list_len, dtype="int32"), + }) + + self.parallel_config.do_profile = False + + if self.speculative_method in ["mtp"]: + self.proposer.update_block_num(num_gpu_blocks) + + def cal_theortical_kvcache(self): + """ + Calculate the total block memory required at the model level + """ + """ + Byte of dtype: + - default(bf16): 2 + - cache_int8: 1 + - cache_int4: + """ + cache_quant_dtype = None + if (self.quant_config + and hasattr(self.quant_config, "kv_cache_quant_type") + and self.quant_config.kv_cache_quant_type is not None): + cache_quant_dtype = self.quant_config.kv_cache_quant_type + + if cache_quant_dtype is not None: # int8, int8_zp, fp8, fp8_zp + byte_of_dtype = 1 + else: # default + byte_of_dtype = 2 + + hidden_dim = self.model_config.head_dim * self.model_config.kv_num_heads + num_layers = self.model_config.num_layers + \ + self.speculative_config.num_gpu_block_expand_ratio if \ + self.speculative_method in [ + "mtp" + ] else self.model_config.num_layers + required_memory = ( + byte_of_dtype * 2 * # k + v + (self.parallel_config.block_size * hidden_dim) * num_layers) + return required_memory + + def not_need_stop(self) -> bool: + """ """ + return self.share_inputs["not_need_stop"][0] diff --git 
diff --git a/fastdeploy/worker/gcu_worker.py b/fastdeploy/worker/gcu_worker.py
new file mode 100644
index 000000000..f280084de
--- /dev/null
+++ b/fastdeploy/worker/gcu_worker.py
@@ -0,0 +1,142 @@
+"""
+# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+import gc
+from typing import List, Optional
+
+import paddle
+import paddle.nn as nn
+
+from fastdeploy.config import FDConfig
+from fastdeploy.engine.request import Request
+from fastdeploy.utils import get_logger
+from fastdeploy.worker.gcu_model_runner import GCUModelRunner
+from fastdeploy.worker.output import ModelRunnerOutput
+from fastdeploy.worker.worker_base import WorkerBase
+
+logger = get_logger("gcu_worker", "gcu_worker.log")
+
+
+class GcuWorker(WorkerBase):
+    """ """
+
+    def __init__(
+        self,
+        fd_config: FDConfig,
+        local_rank: int,
+        rank: int,
+    ):
+        super().__init__(
+            fd_config=fd_config,
+            local_rank=local_rank,
+            rank=rank,
+        )
+        pass
+
+    def init_device(self):
+        """ Initialize the device and construct the model runner.
+        """
+        if paddle.is_compiled_with_custom_device("gcu"):
+            # Set environment variables
+            self.device_ids = self.parallel_config.device_ids.split(",")
+            self.device = f"gcu:{self.local_rank}"
+            paddle.device.set_device(self.device)
+            paddle.set_default_dtype(self.parallel_config.dtype)
+            logger.info(f"GcuWorker init_device:{self.device}, device_ids:{self.device_ids}")
+
+            gc.collect()
+        else:
+            raise RuntimeError(
+                f"Unsupported device type: {self.device_config.device}")
+
+        # Construct model runner
+        self.model_runner: GCUModelRunner = GCUModelRunner(
+            fd_config=self.fd_config,
+            device=self.device,
+            device_id=self.device_ids[self.local_rank],
+            rank=self.rank,
+            local_rank=self.local_rank)
+
+    def prefill_finished(self):
+        """
+        Check whether the prefill stage has finished.
+        """
+        return self.model_runner.prefill_finished()
+
+    def determine_available_memory(self) -> int:
+        """
+        Profiles the peak memory usage of the model to determine how much
+        memory can be used for KV cache without OOMs.
+
+        The engine will first conduct a profiling of the existing memory usage.
+        Then, it calculates the maximum possible number of GCU and CPU blocks
+        that can be allocated with the remaining free memory.
+
+        Tip:
+            You may limit the usage of GCU memory
+            by adjusting the `gcu_memory_utilization` parameter.
+ """ + raise NotImplementedError + + def load_model(self) -> None: + """ """ + self.model_runner.load_model() + + def get_model(self) -> nn.Layer: + """ """ + return self.model_runner.get_model() + + def initialize_cache(self, num_gpu_blocks: int, + num_cpu_blocks: int) -> None: + """ """ + pass + + def execute_model( + self, + model_forward_batch: Optional[List[Request]] = None, + ) -> Optional[ModelRunnerOutput]: + """ """ + output = self.model_runner.execute_model(model_forward_batch) + return output + + def preprocess_new_task(self, req_dicts: List[Request]) -> None: + """ Process new requests and then start the decode loop + TODO(gongshaotian):The scheduler should schedule the handling of prefill, + and workers and modelrunners should not perceive it. + """ + self.model_runner.insert_prefill_inputs(req_dicts=req_dicts) + + def graph_optimize_and_warm_up_model(self) -> None: + """ + Perform the warm-up and the graph optimization + """ + # 1. Warm up model + # NOTE(gongshaotian): may be not need warm_up at this place + + # 2. Triger cuda grpah capture + self.model_runner.capture_model() + + def check_health(self) -> bool: + """ """ + return True + + def cal_theortical_kvcache(self) -> int: + """ """ + return self.model_runner.cal_theortical_kvcache() + + def reinitialize_kv_cache(self, num_gpu_blocks: int) -> None: + """ """ + self.model_runner.update_share_input_block_num( + num_gpu_blocks=num_gpu_blocks) diff --git a/fastdeploy/worker/worker_process.py b/fastdeploy/worker/worker_process.py index 19082a682..0591b7b91 100644 --- a/fastdeploy/worker/worker_process.py +++ b/fastdeploy/worker/worker_process.py @@ -53,6 +53,9 @@ def get_worker(fd_config: FDConfig, local_rank: int, rank: int) -> WorkerBase: return IluvatarWorker(fd_config=fd_config, local_rank=local_rank, rank=rank) + if current_platform.is_gcu(): + from fastdeploy.worker.gcu_worker import GcuWorker + return GcuWorker(fd_config=fd_config, local_rank=local_rank, rank=rank) class PaddleDisWorkerProc(): diff --git a/setup.py b/setup.py index 0447388a4..f25f7237f 100644 --- a/setup.py +++ b/setup.py @@ -167,6 +167,8 @@ def get_device_type(): return "npu" elif paddle.is_compiled_with_custom_device('iluvatar_gpu'): return "iluvatar-gpu" + elif paddle.is_compiled_with_custom_device('gcu'): + return "gcu" else: return "cpu" @@ -199,7 +201,7 @@ setup( "model_executor/ops/xpu/libs/*", "model_executor/ops/npu/*", "model_executor/ops/base/*", "model_executor/ops/iluvatar/*", "model_executor/models/*", "model_executor/layers/*", - "input/mm_processor/utils/*", + "input/mm_processor/utils/*", "model_executor/ops/gcu/*", "version.txt" ] },