Mirror of https://github.com/PaddlePaddle/FastDeploy.git (synced 2025-09-26 20:41:53 +08:00)
[Intel HPU] Support intel hpu platform (#4161)
* [Intel HPU] Support intel hpu platform
* fix some issues
* apply precommit and move AttentionBackend_HPU
* fix format issue
* correct ops import
* fix ci issue
* update code in layers
* fix code style issue
* remove dense tp moe ep mode
* fix enc_dec_block_num
* fix rebase issue
* rename hpu to gaudi in readme
* rename ForwardMeta_HPU to HPUForwardMeta
@@ -43,7 +43,7 @@ English | [简体中文](README_CN.md)
 - 🤝 **OpenAI API Server and vLLM Compatible**: One-command deployment with [vLLM](https://github.com/vllm-project/vllm/) interface compatibility.
 - 🧮 **Comprehensive Quantization Format Support**: W8A16, W8A8, W4A16, W4A8, W2A16, FP8, and more.
 - ⏩ **Advanced Acceleration Techniques**: Speculative decoding, Multi-Token Prediction (MTP) and Chunked Prefill.
-- 🖥️ **Multi-Hardware Support**: NVIDIA GPU, Kunlunxin XPU, Hygon DCU, Ascend NPU, Iluvatar GPU, Enflame GCU, MetaX GPU etc.
+- 🖥️ **Multi-Hardware Support**: NVIDIA GPU, Kunlunxin XPU, Hygon DCU, Ascend NPU, Iluvatar GPU, Enflame GCU, MetaX GPU, Intel Gaudi etc.
 
 ## Requirements
 
@@ -60,6 +60,7 @@ FastDeploy supports inference deployment on **NVIDIA GPUs**, **Kunlunxin XPUs**,
 - [Enflame GCU](./docs/get_started/installation/Enflame_gcu.md)
 - [Hygon DCU](./docs/get_started/installation/hygon_dcu.md)
 - [MetaX GPU](./docs/get_started/installation/metax_gpu.md)
+- [Intel Gaudi](./docs/get_started/installation/intel_gaudi.md)
 
 **Note:** We are actively working on expanding hardware support. Additional hardware platforms including Ascend NPU are currently under development and testing. Stay tuned for updates!
 
@@ -41,7 +41,7 @@
 - 🤝 **OpenAI API服务与vLLM兼容**:单命令部署,兼容[vLLM](https://github.com/vllm-project/vllm/)接口
 - 🧮 **全量化格式支持**:W8A16、W8A8、W4A16、W4A8、W2A16、FP8等
 - ⏩ **高级加速技术**:推测解码、多令牌预测(MTP)及分块预填充
-- 🖥️ **多硬件支持**:NVIDIA GPU、昆仑芯XPU、海光DCU、昇腾NPU、天数智芯GPU、燧原GCU、沐曦GPU等
+- 🖥️ **多硬件支持**:NVIDIA GPU、昆仑芯XPU、海光DCU、昇腾NPU、天数智芯GPU、燧原GCU、沐曦GPU、英特尔Gaudi等
 
 ## 要求
 
@@ -58,6 +58,7 @@ FastDeploy 支持在**英伟达(NVIDIA)GPU**、**昆仑芯(Kunlunxin)XPU
 - [燧原 S60](./docs/zh/get_started/installation/Enflame_gcu.md)
 - [海光 DCU](./docs/zh/get_started/installation/hygon_dcu.md)
 - [沐曦 GPU](./docs/zh/get_started/installation/metax_gpu.md)
+- [英特尔 Gaudi](./docs/zh/get_started/installation/intel_gaudi.md)
 
 **注意:** 我们正在积极拓展硬件支持范围。目前,包括昇腾(Ascend)NPU 等其他硬件平台正在开发测试中。敬请关注更新!
 
build.sh (10 lines changed)

@@ -128,6 +128,12 @@ function copy_ops(){
         echo -e "MACA ops have been copy to fastdeploy"
         return
     fi
+    is_intel_hpu=`$python -c "import paddle; print(paddle.is_compiled_with_custom_device('intel_hpu'))"`
+    if [ "$is_intel_hpu" = "True" ]; then
+        DEVICE_TYPE="intel-hpu"
+        echo -e "intel_hpu ops have been copy to fastdeploy"
+        return
+    fi
 
     DEVICE_TYPE="cpu"
     cd ../../../../

@@ -159,7 +165,9 @@ function build_and_install_ops() {
     else
         FD_BUILDING_ARCS=${FD_BUILDING_ARCS} ${python} setup_ops.py install --install-lib ${OPS_TMP_DIR}
     fi
-    find ${OPS_TMP_DIR} -type f -name "*.o" -exec rm -f {} \;
+    if [ -d "${OPS_TMP_DIR}" ]; then
+        find ${OPS_TMP_DIR} -type f -name "*.o" -exec rm -f {} \;
+    fi
 else
     echo "Error: Invalid parameter '$FD_CPU_USE_BF16'. Please use true or false."
     exit 1
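The new branch in `copy_ops()` relies on Paddle reporting the `intel_hpu` custom device. The same detection can be run interactively before invoking `build.sh`; a minimal sketch, reusing the two checks that appear elsewhere in this commit (build.sh and fastdeploy/platforms/intel_hpu.py):

```python
# Minimal sketch: confirm the installed PaddlePaddle build exposes the intel_hpu
# custom device before building the FastDeploy ops.
import paddle

compiled = paddle.is_compiled_with_custom_device("intel_hpu")
print(f"Paddle compiled with intel_hpu support: {compiled}")

if compiled:
    # Number of Gaudi devices visible to this process (0 means none detected).
    print("Visible intel_hpu devices:", paddle.base.core.get_custom_device_count("intel_hpu"))
```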
@@ -623,6 +623,8 @@ elif paddle.device.is_compiled_with_custom_device("metax_gpu"):
             ],
         ),
     )
+elif paddle.is_compiled_with_custom_device("intel_hpu"):
+    pass
 else:
     use_bf16 = envs.FD_CPU_USE_BF16 == "True"
 
@@ -7,3 +7,4 @@ FastDeploy currently supports installation on the following hardware platforms:
 - [Enflame S60 GCU Installation](Enflame_gcu.md)
 - [Iluvatar GPU Installation](iluvatar_gpu.md)
 - [Hygon DCU Installation](hygon_dcu.md)
+- [Intel Gaudi Installation](intel_gaudi.md)
docs/get_started/installation/intel_gaudi.md (new file, 75 lines)

@@ -0,0 +1,75 @@
# Intel Gaudi Installation for running ERNIE 4.5 Series Models

The following installation methods are available when your environment meets these requirements:

- Python 3.10
- Intel Gaudi 2
- Intel Gaudi software version 1.22.0
- Linux X86_64

## 1. Run Docker Container

Use the following commands to run a Docker container. Make sure to update the versions below as listed in the [Support Matrix](https://docs.habana.ai/en/latest/Support_Matrix/Support_Matrix.html):

```{.console}
$ docker pull vault.habana.ai/gaudi-docker/1.22.0/ubuntu22.04/habanalabs/pytorch-installer-2.7.1:latest
$ docker run -it --runtime=habana -e HABANA_VISIBLE_DEVICES=all -e OMPI_MCA_btl_vader_single_copy_mechanism=none --cap-add=sys_nice --net=host --ipc=host vault.habana.ai/gaudi-docker/1.22.0/ubuntu22.04/habanalabs/pytorch-installer-2.7.1:latest
```

## 2. Install PaddlePaddle

```bash
python -m pip install paddlepaddle==3.1.1 -i https://www.paddlepaddle.org.cn/packages/stable/cpu/
```

## 3. Install PaddleCustomDevice

```shell
git clone https://github.com/PaddlePaddle/PaddleCustomDevice
cd PaddleCustomDevice/backends/intel_hpu/
mkdir -p build
cd build
cmake ..
make -j
pip install --force-reinstall dist/paddle_intel_hpu*.whl
cd ../custom_ops
python setup.py install
```

## 4. Install FastDeploy

```shell
git clone https://github.com/PaddlePaddle/FastDeploy
cd FastDeploy
bash build.sh
```

## Prepare the inference demo

### 1. Start the inference service

```shell
export GC_KERNEL_PATH=/usr/lib/habanalabs/libtpc_kernels.so
export GC_KERNEL_PATH=/usr/local/lib/python3.10/dist-packages/paddle_custom_device/intel_hpu/libcustom_tpc_perf_lib.so:$GC_KERNEL_PATH
export INTEL_HPU_VISIBLE_DEVICES=0,1,2,3,4,5,6,7
export PADDLE_DISTRI_BACKEND=xccl
export PADDLE_XCCL_BACKEND=intel_hpu
export HABANA_PROFILE=0
export HPU_VISIBLE_DEVICES=0

HPU_WARMUP_BUCKET=1 HPU_WARMUP_MODEL_LEN=4096 FD_ATTENTION_BACKEND=HPU_ATTN python -m fastdeploy.entrypoints.openai.api_server --model ERNIE-4.5-21B-A3B-Paddle --tensor-parallel-size 1 --max-model-len 32768 --max-num-seqs 128
```

### 2. Send a request

```bash
curl -X POST "http://0.0.0.0:8188/v1/chat/completions" \
  -H "Content-Type: application/json" \
  -d '{
    "messages": [
      {"role": "user", "content": "What is AI?"}
    ], "max_tokens": 24
  }'
```

### 3. Example of a successful response

```json
{"id":"chatcmpl-3bd98ae2-fafe-46ae-a552-d653a8526503","object":"chat.completion","created":1757653575,"model":"ERNIE-4.5-21B-A3B-Paddle","choices":[{"index":0,"message":{"role":"assistant","content":"**AI (Artificial Intelligence)** refers to the development of computer systems that can perform tasks typically requiring human intelligence.","multimodal_content":null,"reasoning_content":null,"tool_calls":null,"prompt_token_ids":null,"completion_token_ids":null,"text_after_process":null,"raw_prediction":null,"prompt_tokens":null,"completion_tokens":null},"logprobs":null,"finish_reason":"length"}],"usage":{"prompt_tokens":11,"total_tokens":35,"completion_tokens":24,"prompt_tokens_details":{"cached_tokens":0}}}
```
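For reference, the curl request in the new installation guide above can also be issued from Python. A minimal sketch using the `requests` package; the endpoint, headers, and payload simply mirror the curl example (port 8188 is the one used by the example service):

```python
# Minimal sketch: send the same chat completion request as the curl example above.
import requests

resp = requests.post(
    "http://0.0.0.0:8188/v1/chat/completions",
    headers={"Content-Type": "application/json"},
    json={
        "messages": [{"role": "user", "content": "What is AI?"}],
        "max_tokens": 24,
    },
    timeout=60,
)
resp.raise_for_status()
# Print only the assistant's reply from the OpenAI-compatible response body.
print(resp.json()["choices"][0]["message"]["content"])
```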
@@ -7,3 +7,4 @@ FastDeploy支持如下硬件平台:
 - [Enflame S60 GCU Installation](Enflame_gcu.md)
 - [Iluvatar GPU Installation](iluvatar_gpu.md)
 - [Hygon DCU Installation](hygon_dcu.md)
+- [Intel Gaudi Installation](intel_gaudi.md)
docs/zh/get_started/installation/intel_gaudi.md (new file, 75 lines)

@@ -0,0 +1,75 @@
# 使用 Intel Gaudi 运行 ERNIE 4.5 系列模型

在环境满足如下条件的前提下:

- Python 3.10
- Intel Gaudi 2
- Intel Gaudi software version 1.22.0
- Linux X86_64

## 1. 运行 Docker 容器

使用下面的命令运行 Docker 容器,并确保所用版本在 [Support Matrix](https://docs.habana.ai/en/latest/Support_Matrix/Support_Matrix.html) 支持列表中:

```{.console}
$ docker pull vault.habana.ai/gaudi-docker/1.22.0/ubuntu22.04/habanalabs/pytorch-installer-2.7.1:latest
$ docker run -it --runtime=habana -e HABANA_VISIBLE_DEVICES=all -e OMPI_MCA_btl_vader_single_copy_mechanism=none --cap-add=sys_nice --net=host --ipc=host vault.habana.ai/gaudi-docker/1.22.0/ubuntu22.04/habanalabs/pytorch-installer-2.7.1:latest
```

## 2. 安装 PaddlePaddle

```bash
python -m pip install paddlepaddle==3.1.1 -i https://www.paddlepaddle.org.cn/packages/stable/cpu/
```

## 3. 安装 PaddleCustomDevice

```shell
git clone https://github.com/PaddlePaddle/PaddleCustomDevice
cd PaddleCustomDevice/backends/intel_hpu/
mkdir -p build
cd build
cmake ..
make -j
pip install --force-reinstall dist/paddle_intel_hpu*.whl
cd ../custom_ops
python setup.py install
```

## 4. 安装 FastDeploy

```shell
git clone https://github.com/PaddlePaddle/FastDeploy
cd FastDeploy
bash build.sh
```

## 准备推理示例

### 1. 启动推理服务

```shell
export GC_KERNEL_PATH=/usr/lib/habanalabs/libtpc_kernels.so
export GC_KERNEL_PATH=/usr/local/lib/python3.10/dist-packages/paddle_custom_device/intel_hpu/libcustom_tpc_perf_lib.so:$GC_KERNEL_PATH
export INTEL_HPU_VISIBLE_DEVICES=0,1,2,3,4,5,6,7
export PADDLE_DISTRI_BACKEND=xccl
export PADDLE_XCCL_BACKEND=intel_hpu
export HABANA_PROFILE=0
export HPU_VISIBLE_DEVICES=0

HPU_WARMUP_BUCKET=1 HPU_WARMUP_MODEL_LEN=4096 FD_ATTENTION_BACKEND=HPU_ATTN python -m fastdeploy.entrypoints.openai.api_server --model ERNIE-4.5-21B-A3B-Paddle --tensor-parallel-size 1 --max-model-len 32768 --max-num-seqs 128
```

### 2. 发送请求

```bash
curl -X POST "http://0.0.0.0:8188/v1/chat/completions" \
  -H "Content-Type: application/json" \
  -d '{
    "messages": [
      {"role": "user", "content": "What is AI?"}
    ], "max_tokens": 24
  }'
```

### 3. 成功返回结果

```json
{"id":"chatcmpl-3bd98ae2-fafe-46ae-a552-d653a8526503","object":"chat.completion","created":1757653575,"model":"ERNIE-4.5-21B-A3B-Paddle","choices":[{"index":0,"message":{"role":"assistant","content":"**AI (Artificial Intelligence)** refers to the development of computer systems that can perform tasks typically requiring human intelligence.","multimodal_content":null,"reasoning_content":null,"tool_calls":null,"prompt_token_ids":null,"completion_token_ids":null,"text_after_process":null,"raw_prediction":null,"prompt_tokens":null,"completion_tokens":null},"logprobs":null,"finish_reason":"length"}],"usage":{"prompt_tokens":11,"total_tokens":35,"completion_tokens":24,"prompt_tokens_details":{"cached_tokens":0}}}
```
@@ -1498,6 +1498,8 @@ class FDConfig:
         self.device_ids = os.getenv("CUDA_VISIBLE_DEVICES", self.device_ids)
         if current_platform.is_xpu():
             self.device_ids = os.getenv("XPU_VISIBLE_DEVICES", self.device_ids)
+        if current_platform.is_intel_hpu():
+            self.device_ids = os.getenv("HPU_VISIBLE_DEVICES", self.device_ids)
 
         self.read_from_config()
         self.postprocess()
@@ -66,3 +66,26 @@ try:
 
 except:
     tensor_model_parallel_all_reduce = None
+
+from paddle.distributed.communication import stream
+from paddle.distributed.communication.reduce import ReduceOp
+
+
+def all_reduce(
+    tensor,
+    op,
+    group,
+    sync_op: bool = True,
+):
+    return stream.all_reduce(tensor, op=op, group=group, sync_op=sync_op, use_calc_stream=True)
+
+
+@paddle.jit.marker.unified
+def tensor_model_parallel_all_reduce_custom(input_: paddle.Tensor) -> paddle.Tensor:
+    """All-reduce the input tensor across model parallel group on calc stream."""
+    if paddle.in_dynamic_mode():
+        hcg = dist.fleet.get_hybrid_communicate_group()
+        mp_group = hcg.get_model_parallel_group()
+        all_reduce(input_, op=ReduceOp.SUM, group=mp_group)
+    else:
+        dist.all_reduce(input_)
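The new `tensor_model_parallel_all_reduce_custom` helper sums a tensor in place across the tensor-parallel group on the calc stream. A hedged usage sketch, assuming a multi-rank launch where the fleet hybrid communicate group has already been initialized by the launcher:

```python
# Hypothetical usage sketch: only meaningful under a distributed launch
# (e.g. paddle.distributed.fleet) with the hybrid communicate group set up.
import paddle

from fastdeploy.distributed.communication import tensor_model_parallel_all_reduce_custom

partial = paddle.ones([2, 4], dtype="bfloat16")  # each rank holds a partial result
tensor_model_parallel_all_reduce_custom(partial)  # in-place SUM across the TP group
print(partial)  # every rank now holds the summed tensor
```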
@@ -17,12 +17,14 @@
 import logging
 from dataclasses import dataclass
 from enum import IntEnum, auto
-from typing import Optional
+from typing import TYPE_CHECKING, Dict, Optional
 
 import paddle
 
 from fastdeploy.model_executor.layers.attention import AttentionBackend
 
+if TYPE_CHECKING:
+    from fastdeploy.model_executor.layers.attention import AttentionBackend_HPU
 logger = logging.getLogger(__name__)
 
 
@@ -240,3 +242,116 @@ class DCUForwardMeta(ForwardMeta):
 
     # Accumulated offset
     cum_offsets: Optional[paddle.Tensor] = None
+
+
+@dataclass
+class HPUForwardMeta:
+    """
+    HPUForwardMeta is used to store the global meta information of the forward on intel HPU.
+    """
+
+    #
+    input_ids: paddle.Tensor
+
+    # attention meta
+    forward_mode: ForwardMode = ForwardMode.MIXED
+
+    #
+    ids_remove_padding: paddle.Tensor = None
+
+    #
+    seq_lens_encoder: Optional[paddle.Tensor] = None
+
+    #
+    seq_lens_decoder: Optional[paddle.Tensor] = None
+
+    #
+    seq_lens_this_time: Optional[paddle.Tensor] = None
+
+    #
+    cum_offsets: Optional[paddle.Tensor] = None
+
+    #
+    block_tables: Optional[paddle.Tensor] = None
+
+    #
+    block_groups: Optional[paddle.Tensor] = None
+
+    #
+    block_list: Optional[paddle.Tensor] = None
+
+    #
+    block_indices: Optional[paddle.Tensor] = None
+
+    #
+    block_offsets: Optional[paddle.Tensor] = None
+
+    #
+    block_mapping: Optional[paddle.Tensor] = None
+
+    #
+    attention_mask: Optional[paddle.Tensor] = None
+
+    #
+    block_size: Optional[paddle.Tensor] = None
+
+    #
+    batch_ids: Optional[paddle.Tensor] = None
+
+    #
+    total_batch: Optional[paddle.Tensor] = None
+
+    #
+    is_prompt: Optional[paddle.Tensor] = None
+
+    #
+    attn_backend: "AttentionBackend_HPU" = None
+
+    #
+    rotary_embs: Optional[paddle.Tensor] = None
+
+    #
+    caches: Optional[paddle.Tensor] = None
+
+    #
+    attn_mask: Optional[paddle.Tensor] = None
+
+    #
+    pre_caches_length: int = 0
+
+    @classmethod
+    def init_forward_meta(cls, share_inputs: Dict, attn_backend: "AttentionBackend_HPU"):
+        """init forward meta"""
+        # TODO(gongshaotian): delete this func
+        is_prompt = share_inputs["is_prompt"]
+        forward_mode = ForwardMode.DECODE
+        if is_prompt:
+            forward_mode = ForwardMode.EXTEND
+        ret = cls(
+            forward_mode=forward_mode,
+            input_ids=share_inputs["input_ids"],
+            ids_remove_padding=share_inputs["ids_remove_padding"],
+            seq_lens_encoder=share_inputs["seq_lens_encoder"],
+            seq_lens_decoder=share_inputs["seq_lens_decoder"],
+            seq_lens_this_time=share_inputs["seq_lens_this_time"],
+            block_tables=share_inputs["block_tables"],
+            block_groups=share_inputs["block_groups"],
+            block_list=share_inputs["block_list"],
+            block_indices=share_inputs["block_indices"],
+            block_offsets=share_inputs["block_offsets"],
+            block_mapping=share_inputs["block_mapping"],
+            attention_mask=share_inputs["block_bias"],
+            block_size=share_inputs["block_size"],
+            total_batch=share_inputs["total_batch"],
+            batch_ids=share_inputs["batch_ids"],
+            is_prompt=share_inputs["is_prompt"],
+            attn_backend=attn_backend,
+            rotary_embs=share_inputs["rotary_embs"],
+            caches=share_inputs["caches"],
+        )
+        return ret
+
+    def clear_caches(self):
+        """safe clear caches"""
+        if self.caches:
+            del self.caches
|
|||||||
self.forward = self.forward_cuda
|
self.forward = self.forward_cuda
|
||||||
elif current_platform.is_gcu():
|
elif current_platform.is_gcu():
|
||||||
self.forward = self.forward_gcu
|
self.forward = self.forward_gcu
|
||||||
|
elif current_platform.is_intel_hpu():
|
||||||
|
self.forward = self.forward_intel_hpu
|
||||||
else:
|
else:
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
@@ -147,6 +149,16 @@ class SiluAndMul(nn.Layer):
|
|||||||
out = out + self.bias
|
out = out + self.bias
|
||||||
return out
|
return out
|
||||||
|
|
||||||
|
def forward_intel_hpu(self, x):
|
||||||
|
"""
|
||||||
|
Forward propagation of the custom activation layer.
|
||||||
|
Args:
|
||||||
|
x (Tensor): Input tensor to the activation layer.
|
||||||
|
Returns:
|
||||||
|
Tensor: Output tensor.
|
||||||
|
"""
|
||||||
|
return
|
||||||
|
|
||||||
|
|
||||||
def get_act_fn(act_fn_name: str) -> nn.Layer:
|
def get_act_fn(act_fn_name: str) -> nn.Layer:
|
||||||
"""Get an activation function by name."""
|
"""Get an activation function by name."""
|
||||||
|
@@ -55,3 +55,10 @@ if current_platform.is_maca():
     if hasattr(metax, "__all__"):
         globals().update({name: getattr(metax, name) for name in metax.__all__})
         __all__.extend(metax.__all__)
+
+if current_platform.is_intel_hpu():
+    from . import intel_hpu
+
+    if hasattr(intel_hpu, "__all__"):
+        globals().update({name: getattr(intel_hpu, name) for name in intel_hpu.__all__})
+        __all__.extend(intel_hpu.__all__)
@@ -0,0 +1,26 @@ (new file)
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""
intel_hpu backend methods
"""

from .attention.hpu_attn_backend import HPUAttentionBackend
from .moe.fused_moe_hpu_backend import HpuMoEMethod, HpuTensorWiseFP8MoEMethod

__all__ = [
    "HPUAttentionBackend",
    "HpuMoEMethod",
    "HpuTensorWiseFP8MoEMethod",
]
@@ -0,0 +1,19 @@ (new file)
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from .hpu_attn_backend import HPUAttentionBackend

__all__ = [
    "HPUAttentionBackend",
]
@@ -0,0 +1,314 @@ (new file)
"""
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""

from __future__ import annotations

import os
from abc import abstractmethod
from dataclasses import dataclass, field
from typing import TYPE_CHECKING, List, Optional

import paddle

if TYPE_CHECKING:
    from paddle._typing.dtype_like import _DTypeLiteral

from fastdeploy.config import FDConfig
from fastdeploy.model_executor.layers.attention.attention import Attention
from fastdeploy.model_executor.layers.attention.base_attention_backend import (
    AttentionBackend,
    AttentionMetadata,
)

if TYPE_CHECKING:
    from fastdeploy.model_executor.forward_meta import HPUForwardMeta

from fastdeploy.model_executor.layers.linear import QKVParallelLinear, RowParallelLinear


class AttentionBackend_HPU(AttentionBackend):
    """The base class of attention backends"""

    @abstractmethod
    def init_attention_metadata(self, forward_meta: HPUForwardMeta):
        """Initialize the forward metadata."""
        raise NotImplementedError()

    def forward(
        self,
        src: paddle.Tensor,
        qkv_proj: QKVParallelLinear,
        o_proj: RowParallelLinear,
        layer: paddle.nn.Layer,
        forward_meta: HPUForwardMeta,
    ):
        """
        Run a forward pass.
        Args:
            src: the hidden states tensor
            qkv_proj: the fused QKV projection linear layer
            o_proj: the attention output projection linear layer
            layer: the attention layer used for the forward
            forward_meta: the forward metadata
        """
        if forward_meta.forward_mode.is_mixed():
            return self.forward_mixed(src, qkv_proj, o_proj, layer, forward_meta)
        elif forward_meta.forward_mode.is_decode():
            return self.forward_decode(src, qkv_proj, o_proj, layer, forward_meta)
        else:
            return self.forward_extend(src, qkv_proj, o_proj, layer, forward_meta)

    def forward_mixed(
        self,
        src: paddle.Tensor,
        qkv_proj: QKVParallelLinear,
        o_proj: RowParallelLinear,
        layer: paddle.nn.Layer,
        forward_meta: HPUForwardMeta,
    ):
        """Run a forward for mixed batches."""
        raise NotImplementedError()

    def forward_decode(
        self,
        src: paddle.Tensor,
        qkv_proj: QKVParallelLinear,
        o_proj: RowParallelLinear,
        layer: paddle.nn.Layer,
        forward_meta: HPUForwardMeta,
    ):
        """Run a forward for decode."""
        raise NotImplementedError()

    def forward_extend(
        self,
        src: paddle.Tensor,
        qkv_proj: QKVParallelLinear,
        o_proj: RowParallelLinear,
        layer: paddle.nn.Layer,
        forward_meta: HPUForwardMeta,
    ):
        """Run a forward for extend."""
        raise NotImplementedError()


@dataclass
class HPUAttentionMetadata(AttentionMetadata):
    """
    HPUAttentionMetadata
    """

    max_len_kv: paddle.Tensor = None
    set_max_lengths: int = -1
    encoder_batch_ids: paddle.Tensor = None
    encoder_tile_ids_per_batch: paddle.Tensor = None
    encoder_num_blocks: paddle.Tensor = None
    kv_batch_ids: paddle.Tensor = None
    kv_tile_ids_per_batch: paddle.Tensor = None
    kv_num_blocks: paddle.Tensor = None
    decoder_batch_ids: paddle.Tensor = None
    decoder_tile_ids_per_batch: paddle.Tensor = None
    decoder_num_blocks: paddle.Tensor = None

    _dtype: _DTypeLiteral = paddle.bfloat16
    encoder_max_partition_size: int = 32768
    max_partition_size: int = 32768
    block_tables: Optional[paddle.Tensor] = None
    rotary_embs: Optional[paddle.Tensor] = None
    attn_mask: Optional[paddle.Tensor] = None
    encoder_block_shape_q: Optional[paddle.Tensor] = None
    decoder_block_shape_q: Optional[paddle.Tensor] = None
    _fuse_kernel_compute_dtype: str = "bf16"

    # pd_disaggregation
    kv_signal_metadata: Optional[paddle.Tensor] = None
    kv_signal_data_list: List[paddle.Tensor] = field(default_factory=list)


class HPUAttentionBackend(AttentionBackend_HPU):
    """
    HPUAttentionBackend backend implementation.
    """

    def __init__(self, llm_config: FDConfig, kv_num_heads: int, num_heads: int, head_dim: int):
        """
        HPUAttentionBackend __init__
        """
        super().__init__()
        self.attention_metadata: HPUAttentionMetadata = None
        # TODO(gongshaotian): Use llm_config parameters in the correct location
        self.block_size = llm_config.parallel_config.block_size
        self.max_seq_len = llm_config.parallel_config.max_model_len
        self.rope_theta = 10000.0 if llm_config.model_config.rope_theta is None else llm_config.model_config.rope_theta
        self.rope_3d = getattr(llm_config.model_config, "rope_3d", False)
        self.causal = getattr(llm_config.model_config, "causal", True)
        self.speculative_method: str = llm_config.speculative_config.method
        self.use_speculate: bool = self.speculative_method is not None
        self.speculate_max_draft_token_num: int = llm_config.speculative_config.num_speculative_tokens
        self.keep_pd_step_flag: bool = llm_config.speculative_config.model_type == "mtp"
        self.rank: int = llm_config.parallel_config.tensor_parallel_rank
        self.nranks = llm_config.parallel_config.tensor_parallel_size

        self.kv_num_heads = kv_num_heads
        self.num_heads = num_heads
        self.head_dim = head_dim
        self.num_layers = llm_config.model_config.num_hidden_layers

        # pd_disaggregation
        self.use_pd_disaggregation = int(os.getenv("FLAGS_use_pd_disaggregation", 0))
        self.start_layer_index = llm_config.model_config.start_layer_index

    def init_attention_metadata(self, forward_meta):
        """Initialize attention metadata so that all layers in the forward pass can reuse it."""
        metadata = HPUAttentionMetadata()
        metadata.encoder_block_shape_q = 64
        metadata.decoder_block_shape_q = 16
        metadata.max_partition_size = 32768
        metadata.encoder_max_partition_size = 32768
        metadata._dtype = paddle.get_default_dtype()
        if metadata._dtype == "bfloat16":
            metadata._fuse_kernel_compute_dtype = "bf16"
        elif metadata._dtype == "float16":
            metadata._fuse_kernel_compute_dtype = "fp16"
        elif metadata._dtype == "float32":
            metadata._fuse_kernel_compute_dtype = "fp32"
        metadata.block_tables = forward_meta.block_tables
        metadata.rotary_embs = forward_meta.rotary_embs
        metadata.attn_mask = forward_meta.attn_mask

        # pd_disaggregation
        metadata.kv_signal_data_list = [None] * self.num_layers
        self.attention_metadata = metadata

    def get_kv_cache_shape(
        self,
        max_num_blocks: int,
    ):
        """
        Calculate the KV cache shape.
        """
        return (max_num_blocks, self.block_size, self.kv_num_heads, self.head_dim)

    def forward_extend(
        self, src, qkv_proj: QKVParallelLinear, o_proj: RowParallelLinear, layer: Attention, forward_meta
    ):
        """
        forward_extend
        """
        # metadata = self.attention_metadata

        from fastdeploy.model_executor.ops.intel_hpu import (
            fused_qkv_rope,
            fused_sdpa_proj_t,
            index_copy_,
        )

        query_states, key_value_states = fused_qkv_rope(
            src,
            qkv_proj.weight,
            qkv_proj.bias,
            forward_meta.rotary_embs,
            self.head_dim,
            self.num_heads,
            forward_meta.total_batch,
            transpose=False,
            use_neox_style=layer.use_neox_rotary_style,
        )

        kv, B, BP_BS, M, H = key_value_states.shape
        key_value_states_reshape = key_value_states.reshape([kv, -1, forward_meta.block_size, M, H])
        key_states = key_value_states_reshape[0]
        value_states = key_value_states_reshape[1]
        k_cache = forward_meta.caches[2 * layer.layer_id]
        v_cache = forward_meta.caches[2 * layer.layer_id + 1]
        index_copy_(k_cache, forward_meta.block_indices, key_states, 0)
        index_copy_(v_cache, forward_meta.block_indices, value_states, 0)

        out_linear_out = fused_sdpa_proj_t(
            query_states,
            key_value_states,
            forward_meta.attn_mask,
            None,
            o_proj.weight,
            scaling_factor=self.head_dim**-0.5,
            causal=True,
            softmax_mode=0,
        )

        if self.nranks > 1:
            from fastdeploy.distributed.communication import (
                tensor_model_parallel_all_reduce_custom,
            )

            tensor_model_parallel_all_reduce_custom(out_linear_out)

        return out_linear_out

    def forward_decode(
        self, src, qkv_proj: QKVParallelLinear, o_proj: RowParallelLinear, layer: Attention, forward_meta
    ):
        """
        forward_decode
        """
        # metadata = self.attention_metadata
        from fastdeploy.model_executor.ops.intel_hpu import fused_block_attention

        res = fused_block_attention(
            src,
            forward_meta.rotary_embs,
            forward_meta.caches[2 * layer.layer_id],
            forward_meta.caches[2 * layer.layer_id + 1],
            forward_meta.block_groups,
            forward_meta.block_list,
            forward_meta.block_mapping,
            forward_meta.attention_mask,
            forward_meta.block_indices,
            forward_meta.block_offsets,
            qkv_proj.weight,
            qkv_proj.bias,
            o_proj.weight,
            self.head_dim,
            self.num_heads,
            scaling_factor=self.head_dim**-0.5,
            transpose=False,
            use_neox_style=layer.use_neox_rotary_style,
        )

        # all_reduce
        if self.nranks > 1:
            from fastdeploy.distributed.communication import (
                tensor_model_parallel_all_reduce_custom,
            )

            tensor_model_parallel_all_reduce_custom(res)
        return res
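`get_kv_cache_shape` fixes the per-layer cache layout as `(max_num_blocks, block_size, kv_num_heads, head_dim)`, and two such tensors (key and value) are referenced per layer via `forward_meta.caches[2 * layer_id]` and `[2 * layer_id + 1]`. A small worked sketch of the resulting memory footprint; the numbers are illustrative only (the real values come from the FDConfig at runtime):

```python
# Illustrative numbers only; actual values are read from FDConfig at runtime.
max_num_blocks = 2048
block_size = 128
kv_num_heads = 4
head_dim = 128

shape = (max_num_blocks, block_size, kv_num_heads, head_dim)
elems = max_num_blocks * block_size * kv_num_heads * head_dim
bytes_bf16 = elems * 2  # bf16 is 2 bytes per element

# Two caches (K and V) are allocated per layer, so per-layer memory is 2 * bytes_bf16.
print(shape, f"{2 * bytes_bf16 / 2**20:.0f} MiB per layer")
```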
@@ -0,0 +1,16 @@ (new file)
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""
intel_hpu moe
"""
@@ -0,0 +1,249 @@ (new file)
"""
# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""

import paddle
from paddle import nn

from fastdeploy.distributed.communication import tensor_model_parallel_all_reduce_custom
from fastdeploy.model_executor.layers.moe.fused_moe_backend_base import MoEMethodBase


class HpuMoEMethod(MoEMethodBase):
    """
    HPU fused MoE method (bf16 weights), implemented with the intel_hpu fused_gate_moe op.
    """

    def create_weights(self, layer: nn.Layer, **extra_weight_attrs):
        # TODO: split create_parameter from process_loaded_weights
        return NotImplemented

    def process_loaded_weights(self, layer: nn.Layer, state_dict):
        """
        Paddle HPU load weight process.
        """
        # bf16
        up_gate_proj_weights, down_proj_weights, _, _ = layer.extract_moe_ffn_weights(state_dict)

        for idx, weights_tensor in enumerate([up_gate_proj_weights, down_proj_weights]):
            weights_list = []
            for i in range(layer.num_local_experts):
                weight_tensor = weights_tensor[i]
                weight = layer.create_parameter(
                    shape=weight_tensor.shape,
                    dtype=weight_tensor.dtype,
                    default_initializer=paddle.nn.initializer.Constant(0),
                )
                weight.set_value(weight_tensor)
                weights_list.append(weight)
            weights_name = self.added_weight_attrs[idx]
            setattr(layer, weights_name, weights_list)

    def apply_ep_prefill(
        self,
        layer: nn.Layer,
        x: paddle.Tensor,
        gate_out: paddle.Tensor,
    ) -> paddle.Tensor:
        """
        Apply the EP prefill method.
        """
        raise NotImplementedError

    def apply_ep_decode(
        self,
        layer: nn.Layer,
        x: paddle.Tensor,
        gate_out: paddle.Tensor,
    ) -> paddle.Tensor:
        """
        Apply the EP decoder method.
        """
        raise NotImplementedError

    def apply_tp(
        self,
        layer: nn.Layer,
        x: paddle.Tensor,
        gate: nn.Layer,
    ) -> paddle.Tensor:
        """
        Paddle HPU fused MoE.
        """
        if layer.topk_method == "noaux_tc":
            raise NotImplementedError

        # norm_topk_prob = False if layer.topk_method == "noaux_tc" else True
        """
        weights = paddle.nn.functional.softmax(gate_out, axis=-1)
        if layer.moe_use_gate_correction_bias:
            scores = weights + layer.gate_correction_bias
            _, selected_experts = paddle.topk(scores, layer.top_k, axis=-1)
            routing_weights = paddle.index_sample(weights, selected_experts)
        else:
            routing_weights, selected_experts = paddle.topk(weights, layer.top_k, axis=-1)
            routing_weights /= paddle.sum(routing_weights, axis=-1, keepdim=True)

        common_inputs = (x, selected_experts, routing_weights.cast("bfloat16"))

        common_params = (
            False,  # permuted_weights
            "silu",  # activation
            0,
            layer.num_experts - 1,
        )

        weights = (
            layer.moe_ffn1_weight,
            layer.moe_ffn2_weight,
        )

        fused_moe_out, _ = mixture_of_experts(
            *common_inputs, *weights, *common_params, False
        )

        # if norm_topk_prob:
        #     routing_weights_norm = paddle.sum(routing_weights, axis=-1, keepdim=True).cast("bfloat16")
        #     fused_moe_out = fused_moe_out / routing_weights_norm
        """
        chunk_size = 64
        from fastdeploy.model_executor.ops.intel_hpu import fused_gate_moe

        # TODO: fuse matmul to gate_moe
        gate_out = paddle.matmul(x.cast("float32"), gate.weight)
        fused_moe_out = fused_gate_moe(
            x,
            gate_out,
            layer.gate_correction_bias,
            layer.up_gate_proj_weight,
            layer.down_proj_weight,
            layer.top_k,
            layer.moe_use_gate_correction_bias,
            norm_topk_prob=True,
            permuted_weights=False,
            activation="silu",
            experts_min=layer.expert_id_offset,
            experts_max=layer.expert_id_offset + layer.num_local_experts - 1,
            chunk_size=chunk_size,
        )
        if layer.reduce_results and layer.tp_size > 1:
            tensor_model_parallel_all_reduce_custom(fused_moe_out)

        return fused_moe_out


class HpuTensorWiseFP8MoEMethod(HpuMoEMethod):
    """
    HPU fused MoE method with tensor-wise FP8 quantized weights.
    """

    def create_weights(self, layer: nn.Layer, **extra_weight_attrs):
        # TODO: split create_parameter from process_loaded_weights
        return NotImplemented

    def process_loaded_weights(self, layer: nn.Layer, state_dict):
        """
        Paddle HPU load weight process.
        """
        # bf16
        up_gate_proj_weights, down_proj_weights, _, _ = layer.extract_moe_ffn_weights(state_dict)

        from fastdeploy.model_executor.ops.intel_hpu import fused_quant

        self.quant_fn = fused_quant
        self.moe_quant_type = "tensor_wise_fp8"

        for idx, weights_tensor in enumerate([up_gate_proj_weights, down_proj_weights]):
            weights_name = self.added_weight_attrs[idx]
            scales_name = self.added_scale_attrs[idx]

            weights_list = []
            scales_list = []

            for i in range(layer.num_local_experts):
                # quantize loaded weights
                quant_weight, scale = self.quant_fn(weights_tensor[i])
                weights_list.append(quant_weight)
                scales_list.append(scale)

            setattr(layer, weights_name, weights_list)
            setattr(layer, scales_name, scales_list)

    def apply_ep_prefill(
        self,
        layer: nn.Layer,
        x: paddle.Tensor,
        gate_out: paddle.Tensor,
    ) -> paddle.Tensor:
        """
        Apply the EP prefill method.
        """
        raise NotImplementedError

    def apply_ep_decode(
        self,
        layer: nn.Layer,
        x: paddle.Tensor,
        gate_out: paddle.Tensor,
    ) -> paddle.Tensor:
        """
        Apply the EP decoder method.
        """
        raise NotImplementedError

    def apply_tp(
        self,
        layer: nn.Layer,
        x: paddle.Tensor,
        gate: nn.Layer,
    ) -> paddle.Tensor:
        """
        Paddle HPU fused MoE.
        """
        if layer.topk_method == "noaux_tc":
            raise NotImplementedError

        # norm_topk_prob = False if layer.topk_method == "noaux_tc" else True

        chunk_size = 64
        from fastdeploy.model_executor.ops.intel_hpu import fused_gate_moe_fp8

        # TODO: fuse matmul to gate_moe
        gate_out = paddle.matmul(x.cast("float32"), gate.weight)
        fused_moe_out = fused_gate_moe_fp8(
            x,
            gate_out,
            layer.gate_correction_bias,
            layer.up_gate_proj_weight,
            layer.down_proj_weight,
            None,  # intermediate_hidden_states_scales
            layer.up_gate_proj_weight_scale,
            layer.down_proj_weight_scale,
            layer.top_k,
            layer.moe_use_gate_correction_bias,
            norm_topk_prob=True,
            permuted_weights=False,
            activation="silu",
            experts_min=layer.expert_id_offset,
            experts_max=layer.expert_id_offset + layer.num_local_experts - 1,
            chunk_size=chunk_size,
        )

        if layer.reduce_results and layer.tp_size > 1:
            tensor_model_parallel_all_reduce_custom(fused_moe_out)

        return fused_moe_out
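Both `apply_tp` variants above restrict the fused kernel to the experts owned by the current rank via `experts_min`/`experts_max`. A small arithmetic sketch of how that window would look; the `expert_id_offset = ep_rank * num_local_experts` relation is an assumption for illustration and is not part of this diff:

```python
# Illustrative numbers only: the expert window passed to fused_gate_moe for one rank.
num_experts = 64
ep_size = 4
ep_rank = 2  # hypothetical rank

num_local_experts = num_experts // ep_size               # 16, as in FusedMoE.__init__
expert_id_offset = ep_rank * num_local_experts           # 32 (assumed layout)
experts_min = expert_id_offset                           # 32
experts_max = expert_id_offset + num_local_experts - 1   # 47
print(experts_min, experts_max)
```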
@@ -116,6 +116,7 @@ class LinearBase(nn.Layer):
             or current_platform.is_gcu()
             or current_platform.is_dcu()
             or current_platform.is_maca()
+            or current_platform.is_intel_hpu()
         ):
             self.forward = self.forward_cuda
         else:
@@ -56,6 +56,11 @@ def get_moe_method():
         )
 
         return MetaxTritonWeightOnlyMoEMethod(None)
+    elif current_platform.is_intel_hpu():
+        from fastdeploy.model_executor.layers.backends import HpuMoEMethod
+
+        return HpuMoEMethod(None)
+        # return HpuTensorWiseFP8MoEMethod(None)
     raise NotImplementedError
 
 
@@ -139,6 +144,7 @@ class FusedMoE(nn.Layer):
 
         self.hidden_size = fd_config.model_config.hidden_size
         self.num_experts = num_experts
 
         self.num_local_experts = self.num_experts // self.ep_size
 
         self.moe_intermediate_size = moe_intermediate_size // self.tp_size
|
|||||||
.transpose([0, 1, 2, 4, 3])
|
.transpose([0, 1, 2, 4, 3])
|
||||||
.reshape([2, bsz, max_seq_len, 1, self.rotary_dim])
|
.reshape([2, bsz, max_seq_len, 1, self.rotary_dim])
|
||||||
)
|
)
|
||||||
|
if paddle.is_compiled_with_custom_device("intel_hpu"):
|
||||||
|
return (
|
||||||
|
paddle.concat([rot_emb, rot_emb], axis=3)
|
||||||
|
.transpose([0, 1, 2, 4, 3])
|
||||||
|
.reshape([2, bsz, max_seq_len, 1, self.rotary_dim])
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
return rot_emb
|
return rot_emb
|
||||||
|
|
||||||
|
@@ -54,3 +54,6 @@ class SamplingMetadata:
     temp_scaled_logprobs: Optional[paddle.Tensor] = None
     top_p_normalized_logprobs: Optional[paddle.Tensor] = None
     share_inputs: Optional[Dict[str, paddle.Tensor]] = None
+    # Add for HPU post-processing
+    seq_lens_encoder: Optional[paddle.Tensor] = None
+    seq_lens_decoder: Optional[paddle.Tensor] = None
@@ -136,6 +136,23 @@ def apply_penalty_multi_scores(
             min_dec_lens,
             eos_token_ids,
         )
+    elif current_platform.is_intel_hpu():
+        from fastdeploy.model_executor.ops.intel_hpu import (
+            get_token_penalty_multi_scores,
+        )
+
+        logits = get_token_penalty_multi_scores(
+            pre_token_ids,
+            logits,
+            repetition_penalties,
+            frequency_penalties,
+            presence_penalties,
+            temperature,
+            bad_words_token_ids,
+            step_idx,
+            min_dec_lens,
+            eos_token_ids,
+        )
     else:
         raise NotImplementedError
 
@@ -209,6 +209,8 @@ class Sampler(nn.Layer):
             or current_platform.is_maca()
         ):
             self.forward = self.forward_cuda
+        elif current_platform.is_intel_hpu():
+            self.forward = self.forward_intel_hpu
         else:
             raise NotImplementedError
 
@@ -377,6 +379,49 @@ class Sampler(nn.Layer):
 
         return sampler_output
 
+    def forward_intel_hpu(
+        self,
+        logits: paddle.Tensor,
+        sampling_metadata: SamplingMetadata,
+        batch_ids: paddle.Tensor,
+        max_batch: int,
+        rank: int,
+        local_rank: int,
+    ) -> paddle.Tensor:
+        if logits.dtype != paddle.float32:
+            logits = paddle.cast(logits, paddle.float32)
+
+        from fastdeploy.model_executor.ops.intel_hpu import fused_sampler
+
+        _, next_tokens = fused_sampler(
+            sampling_metadata.pre_token_ids,
+            sampling_metadata.prompt_ids,
+            sampling_metadata.seq_lens_encoder,
+            sampling_metadata.seq_lens_decoder,
+            sampling_metadata.step_idx,
+            sampling_metadata.stop_flags,
+            logits,
+            sampling_metadata.repetition_penalties,
+            sampling_metadata.frequency_penalties,
+            sampling_metadata.presence_penalties,
+            sampling_metadata.temperature,
+            sampling_metadata.bad_words_token_ids,
+            sampling_metadata.step_idx,
+            sampling_metadata.min_dec_lens,
+            sampling_metadata.eos_token_ids,
+            sampling_metadata.top_p,
+            rank,
+            local_rank,
+        )
+
+        if next_tokens.shape[0] != max_batch:
+            dim = next_tokens.shape[-1]
+            tmp_tokens = paddle.full((max_batch, dim), -1, dtype=next_tokens.dtype)
+            tmp_tokens = paddle.scatter(tmp_tokens, batch_ids, next_tokens[: batch_ids.shape[0], :])
+            return tmp_tokens
+
+        return next_tokens
+
+
 class SpeculativeSampler(nn.Layer):
     """
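The tail of `forward_intel_hpu` pads the sampled tokens back to the full batch: slots that were not scheduled this step are filled with -1, and the sampled rows are scattered to their original batch positions. A tiny standalone sketch of that `paddle.scatter` pattern, with made-up sizes:

```python
# Standalone illustration of the padding/scatter step (sizes are made up).
import paddle

max_batch = 6
next_tokens = paddle.to_tensor([[101], [202], [303]], dtype="int64")  # 3 active requests
batch_ids = paddle.to_tensor([0, 2, 5], dtype="int64")  # their slots in the full batch

tmp_tokens = paddle.full((max_batch, next_tokens.shape[-1]), -1, dtype=next_tokens.dtype)
tmp_tokens = paddle.scatter(tmp_tokens, batch_ids, next_tokens)
print(tmp_tokens.numpy().tolist())  # [[101], [-1], [202], [-1], [-1], [303]]
```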
@@ -12,6 +12,6 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 """fastdeploy module"""
-from . import cpu, gcu, gpu, iluvatar, npu, xpu
+from . import cpu, gcu, gpu, iluvatar, intel_hpu, npu, xpu
 
-__all__ = ["gpu", "cpu", "xpu", "npu", "iluvatar", "gcu"]
+__all__ = ["gpu", "cpu", "xpu", "npu", "iluvatar", "gcu", "intel_hpu"]
fastdeploy/model_executor/ops/intel_hpu/__init__.py (new file, 21 lines)

@@ -0,0 +1,21 @@
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""fastdeploy intel_hpu ops."""

from fastdeploy.import_ops import import_custom_ops

# PACKAGE = "fastdeploy.model_executor.ops.intel_hpu"
PACKAGE = "paddlenlp_ops"

import_custom_ops(PACKAGE, "paddlenlp_ops", globals())
@@ -56,6 +56,8 @@ elif current_platform.is_maca():
         update_inputs,
         update_inputs_v1,
     )
+elif current_platform.is_intel_hpu():
+    pass
 else:
     from fastdeploy.model_executor.ops.gpu import (
         get_padding_offset,
@@ -284,6 +284,8 @@ class TokenProcessor:
             from fastdeploy.model_executor.ops.iluvatar import get_output
         elif current_platform.is_gcu():
             from fastdeploy.model_executor.ops.gcu import get_output
+        elif current_platform.is_intel_hpu():
+            from fastdeploy.model_executor.ops.intel_hpu import get_output
         else:
             from fastdeploy.model_executor.ops.gpu import (
                 get_output,
@@ -23,6 +23,7 @@ from .cuda import CUDAPlatform
 from .dcu import DCUPlatform
 from .gcu import GCUPlatform
 from .iluvatar import IluvatarPlatform
+from .intel_hpu import INTEL_HPUPlatform
 from .maca import MACAPlatform
 from .npu import NPUPlatform
 from .xpu import XPUPlatform

@@ -43,6 +44,8 @@ def __getattr__(name: str):
         _current_platform = XPUPlatform()
     elif paddle.is_compiled_with_custom_device("npu"):
         _current_platform = NPUPlatform()
+    elif paddle.is_compiled_with_custom_device("intel_hpu"):
+        _current_platform = INTEL_HPUPlatform()
     elif paddle.is_compiled_with_custom_device("iluvatar_gpu"):
         _current_platform = IluvatarPlatform()
     elif paddle.is_compiled_with_custom_device("gcu"):
@@ -27,6 +27,7 @@ class _Backend(enum.Enum):
     FLASH_ATTN = enum.auto()
     BLOCK_ATTN = enum.auto()
     PLAS_ATTN = enum.auto()
+    HPU_ATTN = enum.auto()


 class Platform:
@@ -54,6 +55,12 @@ class Platform:
         """
         return paddle.is_compiled_with_xpu()

+    def is_intel_hpu(self) -> bool:
+        """
+        whether platform is intel_hpu
+        """
+        return paddle.is_compiled_with_custom_device("intel_hpu")
+
     def is_cpu(self) -> bool:
         """
         whether platform is cpu
fastdeploy/platforms/intel_hpu.py (new file, 52 lines)
@@ -0,0 +1,52 @@
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import paddle

from fastdeploy.utils import console_logger as logger

from .base import Platform, _Backend


class INTEL_HPUPlatform(Platform):
    device_name = "intel_hpu"

    @classmethod
    def available(cls):
        """
        Check whether Intel HPU is available.
        """
        try:
            assert paddle.base.core.get_custom_device_count("intel_hpu") > 0
            return True
        except Exception as e:
            logger.warning(
                "You are using the Intel HPU platform, but no Intel HPU was "
                "detected on your machine. Maybe the Intel HPU devices are not set up properly."
                f"\n Original Error is {e}"
            )
            return False

    @classmethod
    def get_attention_backend_cls(cls, selected_backend):
        """
        Return the attention backend class path for the selected backend.
        """
        if selected_backend == _Backend.NATIVE_ATTN:
            logger.info("Using NATIVE ATTN backend.")
            return "fastdeploy.model_executor.layers.attention.PaddleNativeAttnBackend"
        elif selected_backend == _Backend.HPU_ATTN:
            logger.info("Using HPU ATTN backend.")
            return "fastdeploy.model_executor.layers.backends.intel_hpu.attention.HPUAttentionBackend"
        else:
            logger.warning("Other backends are not supported for now.")
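get_attention_backend_cls returns the backend as a dotted class path rather than a class object. A hedged sketch of how a caller could resolve that string (resolve_attention_backend is a hypothetical helper, not part of this diff):

import importlib

def resolve_attention_backend(path: str):
    """Resolve a dotted 'module.ClassName' string into the class object."""
    module_name, _, class_name = path.rpartition(".")
    return getattr(importlib.import_module(module_name), class_name)

# e.g. resolve_attention_backend(INTEL_HPUPlatform.get_attention_backend_cls(_Backend.HPU_ATTN))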
fastdeploy/worker/hpu_model_runner.py (new file, 1463 lines)
File diff suppressed because it is too large.

fastdeploy/worker/hpu_worker.py (new file, 213 lines)
@@ -0,0 +1,213 @@
"""
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""

import gc
import os
import time
from typing import List, Optional

import paddle
import paddle.nn as nn
from paddle.base import core

from fastdeploy.config import FDConfig
from fastdeploy.engine.request import Request
from fastdeploy.utils import get_logger, set_random_seed
from fastdeploy.worker.hpu_model_runner import HPUModelRunner
from fastdeploy.worker.output import ModelRunnerOutput
from fastdeploy.worker.worker_base import WorkerBase

logger = get_logger("hpu_worker", "hpu_worker.log")


def max_memory_allocated(device_id: int) -> int:
    return core.device_memory_stat_peak_value("Allocated", device_id)


def max_memory_reserved(device_id: int) -> int:
    return core.device_memory_stat_peak_value("Reserved", device_id)


def reset_max_memory_allocated(device_id: int) -> None:
    core.device_memory_stat_reset_peak_value("Allocated", device_id)


def reset_max_memory_reserved(device_id: int) -> None:
    core.device_memory_stat_reset_peak_value("Reserved", device_id)


class HpuWorker(WorkerBase):
    def __init__(
        self,
        fd_config: FDConfig,
        local_rank: int,
        rank: int,
    ):
        super().__init__(
            fd_config=fd_config,
            local_rank=local_rank,
            rank=rank,
        )
        pass

    def init_device(self):
        """
        Initialize device and construct model runner
        """
        if paddle.is_compiled_with_custom_device("intel_hpu"):
            # Set environment variable
            self.device_ids = self.parallel_config.device_ids.split(",")
            logger.info(
                f"Using Intel HPU device with local rank => device id: {int(self.device_ids[self.local_rank])} as module id"
            )
            intel_hpus_module_id = int(self.device_ids[self.local_rank])
            self.device = f"intel_hpu:{intel_hpus_module_id}"
            paddle.device.set_device(self.device)
            paddle.set_default_dtype(self.parallel_config.dtype)

            gc.collect()
            paddle.device.cuda.empty_cache()
        else:
            raise RuntimeError(f"Unsupported device type: {self.device_config.device}")

        set_random_seed(self.fd_config.model_config.seed)
        # Construct model runner
        self.model_runner: HPUModelRunner = HPUModelRunner(
            fd_config=self.fd_config,
            device=self.device,
            device_id=self.device_ids[self.local_rank],
            rank=self.rank,
            local_rank=self.local_rank,
        )

    def exist_prefill(self):
        """
        check whether a prefill stage exists
        """
        return self.model_runner.exist_prefill()

    def determine_available_memory(self) -> int:
        """
        Profiles the peak memory usage of the model to determine how much
        memory can be used for KV cache without OOMs.

        The engine will first conduct a profiling of the existing memory usage.
        Then, it calculates the maximum possible number of HPU and CPU blocks
        that can be allocated with the remaining free memory.

        Tip:
            You may limit the usage of HPU memory
            by adjusting the `gpu_memory_utilization` parameter.
        """
        # 1. Record memory state before profile run
        start_time = time.perf_counter()
        module_id = int(self.device_ids[self.local_rank])
        reset_max_memory_allocated(module_id)
        reset_max_memory_reserved(module_id)
        paddle_reserved_mem_before_run = max_memory_reserved(module_id)
        paddle_allocated_mem_before_run = max_memory_allocated(module_id)  # not reserved

        logger.info(
            (
                "Before running the profile, the memory usage info is as follows:",
                f"\nPaddle reserved memory: {paddle_reserved_mem_before_run}",
                f"\nPaddle allocated memory: {paddle_allocated_mem_before_run}",
            )
        )

        # 2. Profile run
        self.model_runner.profile_run()

        # 3. Statistical memory information
        paddle_reserved_mem_after_run = max_memory_reserved(module_id)
        paddle_allocated_mem_after_run = max_memory_allocated(module_id)

        one_mb = 1024 * 1024
        one_gb = 1024 * one_mb
        hpu_reserved_memory = 768 * one_mb  # 768MB reserved for memory not managed by Paddle
        hpu_total_memory = 96 * one_gb  # 96GB HPU memory
        peak_memory = paddle_allocated_mem_after_run + hpu_reserved_memory
        available_kv_cache_memory = hpu_total_memory * self.cache_config.gpu_memory_utilization - peak_memory

        end_time = time.perf_counter()
        logger.info(
            (
                "After running the profile, the memory usage info is as follows:",
                f"\nPaddle reserved memory: {paddle_reserved_mem_after_run}",
                f"\nPaddle allocated memory: {paddle_allocated_mem_after_run}",
                f"\nAvailable KV Cache memory: {available_kv_cache_memory}",
                f"Profile time: {end_time - start_time}",
            )
        )

        return available_kv_cache_memory  # returned to calculate the block num on this device

    def load_model(self) -> None:
        """Load model"""
        self.model_runner.load_model()

    def get_model(self) -> nn.Layer:
        """Get current model"""
        return self.model_runner.get_model()

    def initialize_cache(self, num_gpu_blocks: int) -> None:
        """Initialize the KV Cache with accurate num_gpu_blocks"""
        # accurate cache size
        self.model_runner.update_share_input_block_num(num_gpu_blocks=num_gpu_blocks)

    def execute_model(
        self,
        model_forward_batch: Optional[List[Request]] = None,
        num_running_request: Optional[int] = None,
    ) -> Optional[ModelRunnerOutput]:
        """ """
        output = self.model_runner.execute_model(model_forward_batch)
        return output

    def preprocess_new_task(self, req_dicts: List[Request], num_running_requests: int) -> None:
        """Process new requests and then start the decode loop
        TODO(gongshaotian): The scheduler should schedule the handling of prefill,
        and workers and modelrunners should not perceive it.
        """
        self.model_runner.insert_prefill_inputs(req_dicts=req_dicts, num_running_requests=num_running_requests)

    def graph_optimize_and_warm_up_model(self) -> None:
        """
        Perform the warm-up and the graph optimization
        """
        # wait for all cards to finish loading the model.
        if self.rank > 1:
            paddle.distributed.barrier()
        # 1. Warm up model
        # NOTE(gongshaotian): warm-up may not be needed at this place
        if int(os.environ.get("HPU_WARMUP_BUCKET", 0)) == 1:
            logger.info("Warmup bucket is enabled, starting warmup bucket")
            self.model_runner.is_warmuping = True
            self.model_runner.warm_up_bucket()
            self.model_runner.is_warmuping = False
        else:
            logger.info("Skipping warmup bucket, please set HPU_WARMUP_BUCKET=1 to enable it.")

        # 2. Trigger graph capture
        self.model_runner.capture_model()

    def check_health(self) -> bool:
        """ """
        return True

    def cal_theortical_kvcache(self) -> int:
        """Calculate the block memory required"""
        return self.model_runner.cal_theortical_kvcache()
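Worked example of the KV-cache budget computed in determine_available_memory above (the utilization ratio and profiled peak are illustrative assumptions, not values from this commit):

one_gb = 1024**3
hpu_total_memory = 96 * one_gb                  # 96GB HPU memory, as hard-coded above
hpu_reserved_memory = 768 * 1024**2             # 768MB held back for non-Paddle usage
gpu_memory_utilization = 0.9                    # assumed config value
paddle_allocated_mem_after_run = 40 * one_gb    # hypothetical profiled peak

peak_memory = paddle_allocated_mem_after_run + hpu_reserved_memory
available_kv_cache_memory = hpu_total_memory * gpu_memory_utilization - peak_memory
print(available_kv_cache_memory / one_gb)       # ~45.65 GB left for KV-cache blocks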
@@ -82,6 +82,10 @@ def get_worker(fd_config: FDConfig, local_rank: int, rank: int) -> WorkerBase:
         from fastdeploy.worker.metax_worker import MetaxWorker

         return MetaxWorker(fd_config=fd_config, local_rank=local_rank, rank=rank)
+    if current_platform.is_intel_hpu():
+        from fastdeploy.worker.hpu_worker import HpuWorker
+
+        return HpuWorker(fd_config=fd_config, local_rank=local_rank, rank=rank)


 def init_distributed_environment(seed: int = 20) -> Tuple[int, int]:
@@ -89,21 +93,22 @@ def init_distributed_environment(seed: int = 20) -> Tuple[int, int]:
     # Global rank
     ranks = dist.get_world_size()
     dist_strategy = fleet.DistributedStrategy()
-    dist_strategy.hybrid_configs = {
-        "dp_degree": 1,
-        "mp_degree": ranks,
-        "pp_degree": 1,
-        "sharding_degree": 1,
-    }
-
-    # Set control in tensor parallel
-    dist_strategy.tensor_parallel_configs = {"tensor_init_seed": seed}
-    fleet.init(is_collective=True, strategy=dist_strategy)
-
-    # Local rank
-    local_rank = fleet.worker_index()
+    if ranks > 0:
+        dist_strategy.hybrid_configs = {
+            "dp_degree": 1,
+            "mp_degree": ranks,
+            "pp_degree": 1,
+            "sharding_degree": 1,
+        }
+
+        # Set control in tensor parallel
+        dist_strategy.tensor_parallel_configs = {"tensor_init_seed": seed}
+        fleet.init(is_collective=True, strategy=dist_strategy)
+
+        # Local rank
+        local_rank = fleet.worker_index()
+    else:
+        local_rank = 0
     return ranks, local_rank
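A minimal call-site sketch for the refactored helper (the call site itself is assumed; the seed value is the function's own default):

ranks, local_rank = init_distributed_environment(seed=20)
# ranks      -> tensor-parallel world size used for mp_degree
# local_rank -> fleet.worker_index() when the hybrid strategy is initialized, else 0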
setup.py
@@ -174,6 +174,8 @@ def get_device_type():
         return "gcu"
     elif paddle.device.is_compiled_with_custom_device("metax_gpu"):
         return "metax-gpu"
+    elif paddle.is_compiled_with_custom_device("intel_hpu"):
+        return "intel-hpu"
     else:
         return "cpu"
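A quick sanity check (illustrative, not part of setup.py) for confirming that the installed Paddle build exposes the custom device that get_device_type() probes:

import paddle

# True only when Paddle was built with the intel_hpu custom-device plugin
print(paddle.is_compiled_with_custom_device("intel_hpu"))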