[Intel HPU] Support intel hpu platform (#4161)

* [Intel HPU] Support intel hpu platform

* fix some issues

* apply precommit and move AttentionBackend_HPU

* fix format issue

* correct ops import

* fix ci issue

* update code in layers

* fix code style issue

* remove dense tp moe ep mode

* fix enc_dec_block_num

* fix rebase issue

* rename hpu to gaudi in readme

* rename ForwardMeta_HPU to HPUForwardMeta
Authored by fmiao2372 on 2025-09-24 12:27:50 +08:00, committed by GitHub.
Parent a1c5d930bb, commit f1b5392e20.
35 changed files with 2814 additions and 19 deletions.


@@ -43,7 +43,7 @@ English | [简体中文](README_CN.md)
- 🤝 **OpenAI API Server and vLLM Compatible**: One-command deployment with [vLLM](https://github.com/vllm-project/vllm/) interface compatibility.
- 🧮 **Comprehensive Quantization Format Support**: W8A16, W8A8, W4A16, W4A8, W2A16, FP8, and more.
- **Advanced Acceleration Techniques**: Speculative decoding, Multi-Token Prediction (MTP) and Chunked Prefill.
- 🖥️ **Multi-Hardware Support**: NVIDIA GPU, Kunlunxin XPU, Hygon DCU, Ascend NPU, Iluvatar GPU, Enflame GCU, MetaX GPU etc.
- 🖥️ **Multi-Hardware Support**: NVIDIA GPU, Kunlunxin XPU, Hygon DCU, Ascend NPU, Iluvatar GPU, Enflame GCU, MetaX GPU, Intel Gaudi etc.
## Requirements
@@ -60,6 +60,7 @@ FastDeploy supports inference deployment on **NVIDIA GPUs**, **Kunlunxin XPUs**,
- [Enflame GCU](./docs/get_started/installation/Enflame_gcu.md)
- [Hygon DCU](./docs/get_started/installation/hygon_dcu.md)
- [MetaX GPU](./docs/get_started/installation/metax_gpu.md)
- [Intel Gaudi](./docs/get_started/installation/intel_gaudi.md)
**Note:** We are actively working on expanding hardware support. Additional hardware platforms including Ascend NPU are currently under development and testing. Stay tuned for updates!


@@ -41,7 +41,7 @@
- 🤝 **OpenAI API Server and vLLM Compatible**: One-command deployment, compatible with the [vLLM](https://github.com/vllm-project/vllm/) interface
- 🧮 **Comprehensive Quantization Format Support**: W8A16, W8A8, W4A16, W4A8, W2A16, FP8, and more
- **Advanced Acceleration Techniques**: Speculative decoding, Multi-Token Prediction (MTP), and Chunked Prefill
- 🖥️ **Multi-Hardware Support**: NVIDIA GPU, Kunlunxin XPU, Hygon DCU, Ascend NPU, Iluvatar GPU, Enflame GCU, MetaX GPU, etc.
- 🖥️ **Multi-Hardware Support**: NVIDIA GPU, Kunlunxin XPU, Hygon DCU, Ascend NPU, Iluvatar GPU, Enflame GCU, MetaX GPU, Intel Gaudi, etc.
## Requirements
@@ -58,6 +58,7 @@ FastDeploy supports inference deployment on **NVIDIA GPUs**, **Kunlunxin XPUs**,
- [Enflame S60](./docs/zh/get_started/installation/Enflame_gcu.md)
- [Hygon DCU](./docs/zh/get_started/installation/hygon_dcu.md)
- [MetaX GPU](./docs/zh/get_started/installation/metax_gpu.md)
- [Intel Gaudi](./docs/zh/get_started/installation/intel_gaudi.md)
**Note:** We are actively working on expanding hardware support. Additional hardware platforms, including Ascend NPU, are currently under development and testing. Stay tuned for updates!


@@ -128,6 +128,12 @@ function copy_ops(){
echo -e "MACA ops have been copy to fastdeploy"
return
fi
is_intel_hpu=`$python -c "import paddle; print(paddle.is_compiled_with_custom_device('intel_hpu'))"`
if [ "$is_intel_hpu" = "True" ]; then
DEVICE_TYPE="intel-hpu"
echo -e "intel_hpu ops have been copy to fastdeploy"
return
fi
DEVICE_TYPE="cpu"
cd ../../../../
@@ -159,7 +165,9 @@ function build_and_install_ops() {
else
FD_BUILDING_ARCS=${FD_BUILDING_ARCS} ${python} setup_ops.py install --install-lib ${OPS_TMP_DIR}
fi
find ${OPS_TMP_DIR} -type f -name "*.o" -exec rm -f {} \;
if [ -d "${OPS_TMP_DIR}" ]; then
find ${OPS_TMP_DIR} -type f -name "*.o" -exec rm -f {} \;
fi
else
echo "Error: Invalid parameter '$FD_CPU_USE_BF16'. Please use true or false."
exit 1


@@ -623,6 +623,8 @@ elif paddle.device.is_compiled_with_custom_device("metax_gpu"):
],
),
)
elif paddle.is_compiled_with_custom_device("intel_hpu"):
pass
else:
use_bf16 = envs.FD_CPU_USE_BF16 == "True"


@@ -7,3 +7,4 @@ FastDeploy currently supports installation on the following hardware platforms:
- [Enflame S60 GCU Installation](Enflame_gcu.md)
- [Iluvatar GPU Installation](iluvatar_gpu.md)
- [Hygon DCU Installation](hygon_dcu.md)
- [Intel Gaudi Installation](intel_gaudi.md)


@@ -0,0 +1,75 @@
# Intel Gaudi Installation for running ERNIE 4.5 Series Models
The following installation methods are available when your environment meets these requirements:
- Python 3.10
- Intel Gaudi 2
- Intel Gaudi software version 1.22.0
- Linux X86_64
## 1. Run Docker Container
Use the following commands to run a Docker container. Make sure to update the versions below as listed in the [Support Matrix](https://docs.habana.ai/en/latest/Support_Matrix/Support_Matrix.html):
```{.console}
$ docker pull vault.habana.ai/gaudi-docker/1.22.0/ubuntu22.04/habanalabs/pytorch-installer-2.7.1:latest
$ docker run -it --runtime=habana -e HABANA_VISIBLE_DEVICES=all -e OMPI_MCA_btl_vader_single_copy_mechanism=none --cap-add=sys_nice --net=host --ipc=host vault.habana.ai/gaudi-docker/1.22.0/ubuntu22.04/habanalabs/pytorch-installer-2.7.1:latest
```
## 2. Install PaddlePaddle
```bash
python -m pip install paddlepaddle==3.1.1 -i https://www.paddlepaddle.org.cn/packages/stable/cpu/
```
## 3. Install PaddleCustomDevice
```shell
git clone https://github.com/PaddlePaddle/PaddleCustomDevice
cd PaddleCustomDevice/backends/intel_hpu/
mkdir -p build
cd build
cmake ..
make -j
pip install --force-reinstall dist/paddle_intel_hpu*.whl
cd ../custom_ops
python setup.py install
```
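Optionally, confirm that the intel_hpu plugin registered correctly before moving on. The sketch below only reuses the two checks that appear elsewhere in this commit (the build script and the platform probe); the expected outputs assume a correctly installed Gaudi machine.
```python
# Optional sanity check: verify Paddle can see the intel_hpu custom device.
import paddle

print(paddle.is_compiled_with_custom_device("intel_hpu"))      # expected: True
print(paddle.base.core.get_custom_device_count("intel_hpu"))   # expected: > 0 on a Gaudi host
```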
## 4. Install FastDeploy
```shell
git clone https://github.com/PaddlePaddle/FastDeploy
cd FastDeploy
bash build.sh
```
## Prepare the inference demo
### 1. Start inference service
```shell
export GC_KERNEL_PATH=/usr/lib/habanalabs/libtpc_kernels.so
export GC_KERNEL_PATH=/usr/local/lib/python3.10/dist-packages/paddle_custom_device/intel_hpu/libcustom_tpc_perf_lib.so:$GC_KERNEL_PATH
export INTEL_HPU_VISIBLE_DEVICES=0,1,2,3,4,5,6,7
export PADDLE_DISTRI_BACKEND=xccl
export PADDLE_XCCL_BACKEND=intel_hpu
export HABANA_PROFILE=0
export HPU_VISIBLE_DEVICES=0
HPU_WARMUP_BUCKET=1 HPU_WARMUP_MODEL_LEN=4096 FD_ATTENTION_BACKEND=HPU_ATTN python -m fastdeploy.entrypoints.openai.api_server --model ERNIE-4.5-21B-A3B-Paddle --tensor-parallel-size 1 --max-model-len 32768 --max-num-seqs 128
```
### 2. Send a request
```bash
curl -X POST "http://0.0.0.0:8188/v1/chat/completions" \
-H "Content-Type: application/json" \
-d '{
"messages": [
{"role": "user", "content": "What is AI?"}
], "max_tokens": 24
}'
```
### 3. Example of a successful response
```json
{"id":"chatcmpl-3bd98ae2-fafe-46ae-a552-d653a8526503","object":"chat.completion","created":1757653575,"model":"ERNIE-4.5-21B-A3B-Paddle","choices":[{"index":0,"message":{"role":"assistant","content":"**AI (Artificial Intelligence)** refers to the development of computer systems that can perform tasks typically requiring human intelligence.","multimodal_content":null,"reasoning_content":null,"tool_calls":null,"prompt_token_ids":null,"completion_token_ids":null,"text_after_process":null,"raw_prediction":null,"prompt_tokens":null,"completion_tokens":null},"logprobs":null,"finish_reason":"length"}],"usage":{"prompt_tokens":11,"total_tokens":35,"completion_tokens":24,"prompt_tokens_details":{"cached_tokens":0}}}
```
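Because the service exposes an OpenAI-compatible API, the same request can also be issued from Python. This is a minimal sketch assuming the `openai` Python package is installed and the service from step 1 is listening on port 8188; adjust the base URL and model name to your deployment.
```python
# Hypothetical Python client for the service started above (OpenAI-compatible API).
from openai import OpenAI

client = OpenAI(base_url="http://0.0.0.0:8188/v1", api_key="EMPTY")  # placeholder key, assumed unused by this server
resp = client.chat.completions.create(
    model="ERNIE-4.5-21B-A3B-Paddle",
    messages=[{"role": "user", "content": "What is AI?"}],
    max_tokens=24,
)
print(resp.choices[0].message.content)
```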


@@ -7,3 +7,4 @@ FastDeploy支持如下硬件平台:
- [Enflame S60 GCU Installation](Enflame_gcu.md)
- [Iluvatar GPU Installation](iluvatar_gpu.md)
- [Hygon DCU Installation](hygon_dcu.md)
- [Intel Gaudi Installation](intel_gaudi.md)


@@ -0,0 +1,75 @@
# Running ERNIE 4.5 Series Models on Intel Gaudi
The following installation methods are available when your environment meets these requirements:
- Python 3.10
- Intel Gaudi 2
- Intel Gaudi software version 1.22.0
- Linux X86_64
## 1. Run the Docker Container
Use the following commands to run a Docker container. Make sure the versions you use are listed in the [Support Matrix](https://docs.habana.ai/en/latest/Support_Matrix/Support_Matrix.html):
```{.console}
$ docker pull vault.habana.ai/gaudi-docker/1.22.0/ubuntu22.04/habanalabs/pytorch-installer-2.7.1:latest
$ docker run -it --runtime=habana -e HABANA_VISIBLE_DEVICES=all -e OMPI_MCA_btl_vader_single_copy_mechanism=none --cap-add=sys_nice --net=host --ipc=host vault.habana.ai/gaudi-docker/1.22.0/ubuntu22.04/habanalabs/pytorch-installer-2.7.1:latest
```
## 2. Install PaddlePaddle
```bash
python -m pip install paddlepaddle==3.1.1 -i https://www.paddlepaddle.org.cn/packages/stable/cpu/
```
## 3. Install PaddleCustomDevice
```shell
git clone https://github.com/PaddlePaddle/PaddleCustomDevice
cd PaddleCustomDevice/backends/intel_hpu/
mkdir -p build
cd build
cmake ..
make -j
pip install --force-reinstall dist/paddle_intel_hpu*.whl
cd ../custom_ops
python setup.py install
```
## 4. Install FastDeploy
```shell
git clone https://github.com/PaddlePaddle/FastDeploy
cd FastDeploy
bash build.sh
```
## Prepare the inference demo
### 1. Start the inference service
```shell
export GC_KERNEL_PATH=/usr/lib/habanalabs/libtpc_kernels.so
export GC_KERNEL_PATH=/usr/local/lib/python3.10/dist-packages/paddle_custom_device/intel_hpu/libcustom_tpc_perf_lib.so:$GC_KERNEL_PATH
export INTEL_HPU_VISIBLE_DEVICES=0,1,2,3,4,5,6,7
export PADDLE_DISTRI_BACKEND=xccl
export PADDLE_XCCL_BACKEND=intel_hpu
export HABANA_PROFILE=0
export HPU_VISIBLE_DEVICES=0
HPU_WARMUP_BUCKET=1 HPU_WARMUP_MODEL_LEN=4096 FD_ATTENTION_BACKEND=HPU_ATTN python -m fastdeploy.entrypoints.openai.api_server --model ERNIE-4.5-21B-A3B-Paddle --tensor-parallel-size 1 --max-model-len 32768 --max-num-seqs 128
```
### 2. Send a request
```bash
curl -X POST "http://0.0.0.0:8188/v1/chat/completions" \
-H "Content-Type: application/json" \
-d '{
"messages": [
{"role": "user", "content": "What is AI?"}
], "max_tokens": 24
}'
```
### 3. Example of a successful response
```json
{"id":"chatcmpl-3bd98ae2-fafe-46ae-a552-d653a8526503","object":"chat.completion","created":1757653575,"model":"ERNIE-4.5-21B-A3B-Paddle","choices":[{"index":0,"message":{"role":"assistant","content":"**AI (Artificial Intelligence)** refers to the development of computer systems that can perform tasks typically requiring human intelligence.","multimodal_content":null,"reasoning_content":null,"tool_calls":null,"prompt_token_ids":null,"completion_token_ids":null,"text_after_process":null,"raw_prediction":null,"prompt_tokens":null,"completion_tokens":null},"logprobs":null,"finish_reason":"length"}],"usage":{"prompt_tokens":11,"total_tokens":35,"completion_tokens":24,"prompt_tokens_details":{"cached_tokens":0}}}
```


@@ -1498,6 +1498,8 @@ class FDConfig:
self.device_ids = os.getenv("CUDA_VISIBLE_DEVICES", self.device_ids)
if current_platform.is_xpu():
self.device_ids = os.getenv("XPU_VISIBLE_DEVICES", self.device_ids)
if current_platform.is_intel_hpu():
self.device_ids = os.getenv("HPU_VISIBLE_DEVICES", self.device_ids)
self.read_from_config()
self.postprocess()


@@ -66,3 +66,26 @@ try:
except:
tensor_model_parallel_all_reduce = None
from paddle.distributed.communication import stream
from paddle.distributed.communication.reduce import ReduceOp
def all_reduce(
tensor,
op,
group,
sync_op: bool = True,
):
"""All-reduce `tensor` across `group` on the calc stream (use_calc_stream=True)."""
return stream.all_reduce(tensor, op=op, group=group, sync_op=sync_op, use_calc_stream=True)
@paddle.jit.marker.unified
def tensor_model_parallel_all_reduce_custom(input_: paddle.Tensor) -> paddle.Tensor:
"""All-reduce the input tensor across model parallel group on calc stream."""
if paddle.in_dynamic_mode():
hcg = dist.fleet.get_hybrid_communicate_group()
mp_group = hcg.get_model_parallel_group()
all_reduce(input_, op=ReduceOp.SUM, group=mp_group)
else:
dist.all_reduce(input_)


@@ -17,12 +17,14 @@
import logging
from dataclasses import dataclass
from enum import IntEnum, auto
from typing import Optional
from typing import TYPE_CHECKING, Dict, Optional
import paddle
from fastdeploy.model_executor.layers.attention import AttentionBackend
if TYPE_CHECKING:
from fastdeploy.model_executor.layers.attention import AttentionBackend_HPU
logger = logging.getLogger(__name__)
@@ -240,3 +242,116 @@ class DCUForwardMeta(ForwardMeta):
# Accumulated offset
cum_offsets: Optional[paddle.Tensor] = None
@dataclass
class HPUForwardMeta:
"""
HPUForwardMeta is used to store the global meta information of the forward on intel HPU.
"""
#
input_ids: paddle.Tensor
# attention meta
forward_mode: ForwardMode = ForwardMode.MIXED
#
ids_remove_padding: paddle.Tensor = None
#
seq_lens_encoder: Optional[paddle.Tensor] = None
#
seq_lens_decoder: Optional[paddle.Tensor] = None
#
seq_lens_this_time: Optional[paddle.Tensor] = None
#
cum_offsets: Optional[paddle.Tensor] = None
#
block_tables: Optional[paddle.Tensor] = None
#
block_groups: Optional[paddle.Tensor] = None
#
block_list: Optional[paddle.Tensor] = None
#
block_indices: Optional[paddle.Tensor] = None
#
block_offsets: Optional[paddle.Tensor] = None
#
block_mapping: Optional[paddle.Tensor] = None
#
attention_mask: Optional[paddle.Tensor] = None
#
block_size: Optional[paddle.Tensor] = None
#
batch_ids: Optional[paddle.Tensor] = None
#
total_batch: Optional[paddle.Tensor] = None
#
is_prompt: Optional[paddle.Tensor] = None
#
attn_backend: "AttentionBackend_HPU" = None
#
rotary_embs: Optional[paddle.Tensor] = None
#
caches: Optional[paddle.Tensor] = None
#
attn_mask: Optional[paddle.Tensor] = None
#
pre_caches_length: int = 0
@classmethod
def init_forward_meta(cls, share_inputs: Dict, attn_backend: "AttentionBackend_HPU"):
"""init forward meta"""
# TODO(gongshaotian): delete this func
is_prompt = share_inputs["is_prompt"]
forward_mode = ForwardMode.DECODE
if is_prompt:
forward_mode = ForwardMode.EXTEND
ret = cls(
forward_mode=forward_mode,
input_ids=share_inputs["input_ids"],
ids_remove_padding=share_inputs["ids_remove_padding"],
seq_lens_encoder=share_inputs["seq_lens_encoder"],
seq_lens_decoder=share_inputs["seq_lens_decoder"],
seq_lens_this_time=share_inputs["seq_lens_this_time"],
block_tables=share_inputs["block_tables"],
block_groups=share_inputs["block_groups"],
block_list=share_inputs["block_list"],
block_indices=share_inputs["block_indices"],
block_offsets=share_inputs["block_offsets"],
block_mapping=share_inputs["block_mapping"],
attention_mask=share_inputs["block_bias"],
block_size=share_inputs["block_size"],
total_batch=share_inputs["total_batch"],
batch_ids=share_inputs["batch_ids"],
is_prompt=share_inputs["is_prompt"],
attn_backend=attn_backend,
rotary_embs=share_inputs["rotary_embs"],
caches=share_inputs["caches"],
)
return ret
def clear_caches(self):
"""safe clear caches"""
if self.caches:
del self.caches


@@ -72,6 +72,8 @@ class SiluAndMul(nn.Layer):
self.forward = self.forward_cuda
elif current_platform.is_gcu():
self.forward = self.forward_gcu
elif current_platform.is_intel_hpu():
self.forward = self.forward_intel_hpu
else:
raise NotImplementedError
@@ -147,6 +149,16 @@ class SiluAndMul(nn.Layer):
out = out + self.bias
return out
def forward_intel_hpu(self, x):
"""
Placeholder forward for Intel HPU; currently returns None.
Args:
x (Tensor): Input tensor to the activation layer.
"""
return
def get_act_fn(act_fn_name: str) -> nn.Layer:
"""Get an activation function by name."""


@@ -55,3 +55,10 @@ if current_platform.is_maca():
if hasattr(metax, "__all__"):
globals().update({name: getattr(metax, name) for name in metax.__all__})
__all__.extend(metax.__all__)
if current_platform.is_intel_hpu():
from . import intel_hpu
if hasattr(intel_hpu, "__all__"):
globals().update({name: getattr(intel_hpu, name) for name in intel_hpu.__all__})
__all__.extend(intel_hpu.__all__)


@@ -0,0 +1,26 @@
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
intel_hpu backend methods
"""
from .attention.hpu_attn_backend import HPUAttentionBackend
from .moe.fused_moe_hpu_backend import HpuMoEMethod, HpuTensorWiseFP8MoEMethod
__all__ = [
"HPUAttentionBackend",
"HpuMoEMethod",
"HpuTensorWiseFP8MoEMethod",
]


@@ -0,0 +1,19 @@
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .hpu_attn_backend import HPUAttentionBackend
__all__ = [
"HPUAttentionBackend",
]


@@ -0,0 +1,314 @@
"""
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
from __future__ import annotations
import os
from abc import abstractmethod
from dataclasses import dataclass, field
from typing import TYPE_CHECKING, List, Optional
import paddle
if TYPE_CHECKING:
from paddle._typing.dtype_like import _DTypeLiteral
from fastdeploy.config import FDConfig
from fastdeploy.model_executor.layers.attention.attention import Attention
from fastdeploy.model_executor.layers.attention.base_attention_backend import (
AttentionBackend,
AttentionMetadata,
)
if TYPE_CHECKING:
from fastdeploy.model_executor.forward_meta import HPUForwardMeta
from fastdeploy.model_executor.layers.linear import QKVParallelLinear, RowParallelLinear
class AttentionBackend_HPU(AttentionBackend):
"""The base class of attention backends on Intel HPU."""
@abstractmethod
def init_attention_metadata(self, forward_meta: HPUForwardMeta):
"""Initialize the forward metadata."""
raise NotImplementedError()
def forward(
self,
src: paddle.Tensor,
qkv_proj: QKVParallelLinear,
o_proj: RowParallelLinear,
layer: paddle.nn.Layer,
forward_meta: HPUForwardMeta,
):
"""
Run a forward.
args:
src: the hidden states tensor
residual_input: the residual tensor
layer: The layer that will be used for the forward.
forward_meta: The forward metadata.
"""
if forward_meta.forward_mode.is_mixed():
return self.forward_mixed(
src,
qkv_proj,
o_proj,
layer,
forward_meta,
)
elif forward_meta.forward_mode.is_decode():
return self.forward_decode(
src,
qkv_proj,
o_proj,
layer,
forward_meta,
)
else:
return self.forward_extend(
src,
qkv_proj,
o_proj,
layer,
forward_meta,
)
def forward_mixed(
self,
src: paddle.Tensor,
qkv_proj: QKVParallelLinear,
o_proj: RowParallelLinear,
layer: paddle.nn.Layer,
forward_meta: HPUForwardMeta,
):
"""Run a forward for mix."""
raise NotImplementedError()
def forward_decode(
self,
src: paddle.Tensor,
qkv_proj: QKVParallelLinear,
o_proj: RowParallelLinear,
layer: paddle.nn.Layer,
forward_meta: HPUForwardMeta,
):
"""Run a forward for decode."""
raise NotImplementedError()
def forward_extend(
self,
src: paddle.Tensor,
qkv_proj: QKVParallelLinear,
o_proj: RowParallelLinear,
layer: paddle.nn.Layer,
forward_meta: HPUForwardMeta,
):
"""Run a forward for extend."""
raise NotImplementedError()
@dataclass
class HPUAttentionMetadata(AttentionMetadata):
"""
HPUAttentionMetadata
"""
max_len_kv: paddle.Tensor = None
set_max_lengths: int = -1
encoder_batch_ids: paddle.Tensor = None
encoder_tile_ids_per_batch: paddle.Tensor = None
encoder_num_blocks: paddle.Tensor = None
kv_batch_ids: paddle.Tensor = None
kv_tile_ids_per_batch: paddle.Tensor = None
kv_num_blocks: paddle.Tensor = None
decoder_batch_ids: paddle.Tensor = None
decoder_tile_ids_per_batch: paddle.Tensor = None
decoder_num_blocks: paddle.Tensor = None
_dtype: _DTypeLiteral = paddle.bfloat16
encoder_max_partition_size: int = 32768
max_partition_size: int = 32768
block_tables: Optional[paddle.Tensor] = None
rotary_embs: Optional[paddle.Tensor] = None
attn_mask: Optional[paddle.Tensor] = None
encoder_block_shape_q: Optional[paddle.Tensor] = None
decoder_block_shape_q: Optional[paddle.Tensor] = None
_fuse_kernel_compute_dtype: str = "bf16"
# pd_disaggregation
kv_signal_metadata: Optional[paddle.Tensor] = None
kv_signal_data_list: List[paddle.Tensor] = field(default_factory=list)
class HPUAttentionBackend(AttentionBackend_HPU):
"""
HPUAttentionBackend backend implementation.
"""
def __init__(self, llm_config: FDConfig, kv_num_heads: int, num_heads: int, head_dim: int):
"""
HPUAttentionBackend __init__
"""
super().__init__()
self.attention_metadata: HPUAttentionMetadata = None
# TODO(gongshaotian): Use llm_config parameters in the correct location
self.block_size = llm_config.parallel_config.block_size
self.max_seq_len = llm_config.parallel_config.max_model_len
self.rope_theta = 10000.0 if llm_config.model_config.rope_theta is None else llm_config.model_config.rope_theta
self.rope_3d = getattr(llm_config.model_config, "rope_3d", False)
self.causal = getattr(llm_config.model_config, "causal", True)
self.speculative_method: str = llm_config.speculative_config.method
self.use_speculate: bool = self.speculative_method is not None
self.speculate_max_draft_token_num: int = llm_config.speculative_config.num_speculative_tokens
self.keep_pd_step_flag: bool = llm_config.speculative_config.model_type == "mtp"
self.rank: int = llm_config.parallel_config.tensor_parallel_rank
self.nranks = llm_config.parallel_config.tensor_parallel_size
self.kv_num_heads = kv_num_heads
self.num_heads = num_heads
self.head_dim = head_dim
self.num_layers = llm_config.model_config.num_hidden_layers
# pd_disaggregation
self.use_pd_disaggregation = int(os.getenv("FLAGS_use_pd_disaggregation", 0))
self.start_layer_index = llm_config.model_config.start_layer_index
def init_attention_metadata(self, forward_meta):
"""Initialize attntion metadata hence all layers in the forward pass can reuse it."""
metadata = HPUAttentionMetadata()
metadata.encoder_block_shape_q = 64
metadata.decoder_block_shape_q = 16
metadata.max_partition_size = 32768
metadata.encoder_max_partition_size = 32768
metadata._dtype = paddle.get_default_dtype()
if metadata._dtype == "bfloat16":
metadata._fuse_kernel_compute_dtype = "bf16"
elif metadata._dtype == "float16":
metadata._fuse_kernel_compute_dtype = "fp16"
elif metadata._dtype == "float32":
metadata._fuse_kernel_compute_dtype = "fp32"
metadata.block_tables = forward_meta.block_tables
metadata.rotary_embs = forward_meta.rotary_embs
metadata.attn_mask = forward_meta.attn_mask
# pd_disaggregation
metadata.kv_signal_data_list = [None] * self.num_layers
self.attention_metadata = metadata
def get_kv_cache_shape(
self,
max_num_blocks: int,
):
"""
Calculate the KV cache shape
"""
return (max_num_blocks, self.block_size, self.kv_num_heads, self.head_dim)
def forward_extend(
self, src, qkv_proj: QKVParallelLinear, o_proj: RowParallelLinear, layer: Attention, forward_meta
):
"""
forward_extend
"""
# metadata = self.attention_metadata
from fastdeploy.model_executor.ops.intel_hpu import (
fused_qkv_rope,
fused_sdpa_proj_t,
index_copy_,
)
query_states, key_value_states = fused_qkv_rope(
src,
qkv_proj.weight,
qkv_proj.bias,
forward_meta.rotary_embs,
self.head_dim,
self.num_heads,
forward_meta.total_batch,
transpose=False,
use_neox_style=layer.use_neox_rotary_style,
)
kv, B, BP_BS, M, H = key_value_states.shape
key_value_states_reshape = key_value_states.reshape([kv, -1, forward_meta.block_size, M, H])
key_states = key_value_states_reshape[0]
value_states = key_value_states_reshape[1]
k_cache = forward_meta.caches[2 * layer.layer_id]
v_cache = forward_meta.caches[2 * layer.layer_id + 1]
index_copy_(k_cache, forward_meta.block_indices, key_states, 0)
index_copy_(v_cache, forward_meta.block_indices, value_states, 0)
out_linear_out = fused_sdpa_proj_t(
query_states,
key_value_states,
forward_meta.attn_mask,
None,
o_proj.weight,
scaling_factor=self.head_dim**-0.5,
causal=True,
softmax_mode=0,
)
if self.nranks > 1:
from fastdeploy.distributed.communication import (
tensor_model_parallel_all_reduce_custom,
)
tensor_model_parallel_all_reduce_custom(out_linear_out)
return out_linear_out
def forward_decode(
self, src, qkv_proj: QKVParallelLinear, o_proj: RowParallelLinear, layer: Attention, forward_meta
):
"""
forward_decode
"""
# metadata = self.attention_metadata
from fastdeploy.model_executor.ops.intel_hpu import fused_block_attention
res = fused_block_attention(
src,
forward_meta.rotary_embs,
forward_meta.caches[2 * layer.layer_id],
forward_meta.caches[2 * layer.layer_id + 1],
forward_meta.block_groups,
forward_meta.block_list,
forward_meta.block_mapping,
forward_meta.attention_mask,
forward_meta.block_indices,
forward_meta.block_offsets,
qkv_proj.weight,
qkv_proj.bias,
o_proj.weight,
self.head_dim,
self.num_heads,
scaling_factor=self.head_dim**-0.5,
transpose=False,
use_neox_style=layer.use_neox_rotary_style,
)
# all_reduce
if self.nranks > 1:
from fastdeploy.distributed.communication import (
tensor_model_parallel_all_reduce_custom,
)
tensor_model_parallel_all_reduce_custom(res)
return res
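For a rough sense of what `get_kv_cache_shape` implies for memory, the sketch below multiplies the returned shape out; the block count, head counts, and head dim are illustrative values, not taken from a real ERNIE config.
```python
# Illustrative arithmetic for one layer's K (or V) cache, using the shape returned by
# HPUAttentionBackend.get_kv_cache_shape: (max_num_blocks, block_size, kv_num_heads, head_dim).
max_num_blocks, block_size, kv_num_heads, head_dim = 1024, 128, 4, 128  # example values
bytes_per_elem = 2  # bfloat16
cache_bytes = max_num_blocks * block_size * kv_num_heads * head_dim * bytes_per_elem
print(f"{cache_bytes / (1024 ** 2):.0f} MB per K or V cache per layer")  # 128 MB with these values
```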


@@ -0,0 +1,16 @@
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" "
intel_hpu moe
"""


@@ -0,0 +1,249 @@
"""
# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
import paddle
from paddle import nn
from fastdeploy.distributed.communication import tensor_model_parallel_all_reduce_custom
from fastdeploy.model_executor.layers.moe.fused_moe_backend_base import MoEMethodBase
class HpuMoEMethod(MoEMethodBase):
"""
Use Cutlass Group Gemm to compute Fused MoE.
This method is the oldest way to compute MoE in Paddle.
"""
def create_weights(self, layer: nn.Layer, **extra_weight_attrs):
# TODO: split create_parameter from process_loaded_weights
return NotImplemented
def process_loaded_weights(self, layer: nn.Layer, state_dict):
"""
Paddle HPU load weight process.
"""
# bf16
up_gate_proj_weights, down_proj_weights, _, _ = layer.extract_moe_ffn_weights(state_dict)
for idx, weights_tensor in enumerate([up_gate_proj_weights, down_proj_weights]):
weights_list = []
for i in range(layer.num_local_experts):
weight_tensor = weights_tensor[i]
weight = layer.create_parameter(
shape=weight_tensor.shape,
dtype=weight_tensor.dtype,
default_initializer=paddle.nn.initializer.Constant(0),
)
weight.set_value(weight_tensor)
weights_list.append(weight)
weights_name = self.added_weight_attrs[idx]
setattr(layer, weights_name, weights_list)
def apply_ep_prefill(
self,
layer: nn.Layer,
x: paddle.Tensor,
gate_out: paddle.Tensor,
) -> paddle.Tensor:
"""
Apply the EP prefill method.
"""
raise NotImplementedError
def apply_ep_decode(
self,
layer: nn.Layer,
x: paddle.Tensor,
gate_out: paddle.Tensor,
) -> paddle.Tensor:
"""
Apply the EP decoder method.
"""
raise NotImplementedError
def apply_tp(
self,
layer: nn.Layer,
x: paddle.Tensor,
gate: nn.Layer,
) -> paddle.Tensor:
"""
Paddle hpu Fused MoE.
"""
if layer.topk_method == "noaux_tc":
raise NotImplementedError
# norm_topk_prob = False if layer.topk_method == "noaux_tc" else True
"""
weights = paddle.nn.functional.softmax(gate_out, axis=-1)
if layer.moe_use_gate_correction_bias:
scores = weights + layer.gate_correction_bias
_, selected_experts = paddle.topk(scores, layer.top_k, axis=-1)
routing_weights = paddle.index_sample(weights, selected_experts)
else:
routing_weights, selected_experts = paddle.topk(weights, layer.top_k, axis=-1)
routing_weights /= paddle.sum(routing_weights, axis=-1, keepdim=True)
common_inputs = (x, selected_experts, routing_weights.cast("bfloat16"))
common_params = (
False, #permuted_weights
"silu", #activation,
0,
layer.num_experts - 1,
)
weights = (
layer.moe_ffn1_weight,
layer.moe_ffn2_weight,
)
fused_moe_out, _ = mixture_of_experts(
*common_inputs, *weights, *common_params, False
)
# if norm_topk_prob:
# routing_weights_norm = paddle.sum(routing_weights, axis=-1, keepdim=True).cast("bfloat16")
# fused_moe_out = fused_moe_out / routing_weights_norm
"""
chunk_size = 64
from fastdeploy.model_executor.ops.intel_hpu import fused_gate_moe
# TODO: fuse matmul to gate_moe
gate_out = paddle.matmul(x.cast("float32"), gate.weight)
fused_moe_out = fused_gate_moe(
x,
gate_out,
layer.gate_correction_bias,
layer.up_gate_proj_weight,
layer.down_proj_weight,
layer.top_k,
layer.moe_use_gate_correction_bias,
norm_topk_prob=True,
permuted_weights=False,
activation="silu",
experts_min=layer.expert_id_offset,
experts_max=layer.expert_id_offset + layer.num_local_experts - 1,
chunk_size=chunk_size,
)
if layer.reduce_results and layer.tp_size > 1:
tensor_model_parallel_all_reduce_custom(fused_moe_out)
return fused_moe_out
class HpuTensorWiseFP8MoEMethod(HpuMoEMethod):
"""
Use Cutlass Group Gemm to compute Fused MoE.
This method is the oldest way to compute MoE in Paddle.
"""
def create_weights(self, layer: nn.Layer, **extra_weight_attrs):
# TODO: split create_parameter from process_loaded_weights
return NotImplemented
def process_loaded_weights(self, layer: nn.Layer, state_dict):
"""
Paddle HPU load weight process.
"""
# bf16
up_gate_proj_weights, down_proj_weights, _, _ = layer.extract_moe_ffn_weights(state_dict)
from fastdeploy.model_executor.ops.intel_hpu import fused_quant
self.quant_fn = fused_quant
self.moe_quant_type = "tensor_wise_fp8"
for idx, weights_tensor in enumerate([up_gate_proj_weights, down_proj_weights]):
weights_name = self.added_weight_attrs[idx]
scales_name = self.added_scale_attrs[idx]
weights_list = []
scales_list = []
for i in range(layer.num_local_experts):
# quantize loaded weights
quant_weight, scale = self.quant_fn(weights_tensor[i])
weights_list.append(quant_weight)
scales_list.append(scale)
setattr(layer, weights_name, weights_list)
setattr(layer, scales_name, scales_list)
def apply_ep_prefill(
self,
layer: nn.Layer,
x: paddle.Tensor,
gate_out: paddle.Tensor,
) -> paddle.Tensor:
"""
Apply the EP prefill method.
"""
raise NotImplementedError
def apply_ep_decode(
self,
layer: nn.Layer,
x: paddle.Tensor,
gate_out: paddle.Tensor,
) -> paddle.Tensor:
"""
Apply the EP decoder method.
"""
raise NotImplementedError
def apply_tp(
self,
layer: nn.Layer,
x: paddle.Tensor,
gate: nn.Layer,
) -> paddle.Tensor:
"""
Paddle hpu Fused MoE.
"""
if layer.topk_method == "noaux_tc":
raise NotImplementedError
# norm_topk_prob = False if layer.topk_method == "noaux_tc" else True
chunk_size = 64
from fastdeploy.model_executor.ops.intel_hpu import fused_gate_moe_fp8
# TODO: fuse matmul to gate_moe
gate_out = paddle.matmul(x.cast("float32"), gate.weight)
fused_moe_out = fused_gate_moe_fp8(
x,
gate_out,
layer.gate_correction_bias,
layer.up_gate_proj_weight,
layer.down_proj_weight,
None, # intermediate_hidden_states_scales
layer.up_gate_proj_weight_scale,
layer.down_proj_weight_scale,
layer.top_k,
layer.moe_use_gate_correction_bias,
norm_topk_prob=True,
permuted_weights=False,
activation="silu",
experts_min=layer.expert_id_offset,
experts_max=layer.expert_id_offset + layer.num_local_experts - 1,
chunk_size=chunk_size,
)
if layer.reduce_results and layer.tp_size > 1:
tensor_model_parallel_all_reduce_custom(fused_moe_out)
return fused_moe_out


@@ -116,6 +116,7 @@ class LinearBase(nn.Layer):
or current_platform.is_gcu()
or current_platform.is_dcu()
or current_platform.is_maca()
or current_platform.is_intel_hpu()
):
self.forward = self.forward_cuda
else:


@@ -56,6 +56,11 @@ def get_moe_method():
)
return MetaxTritonWeightOnlyMoEMethod(None)
elif current_platform.is_intel_hpu():
from fastdeploy.model_executor.layers.backends import HpuMoEMethod
return HpuMoEMethod(None)
# return HpuTensorWiseFP8MoEMethod(None)
raise NotImplementedError
@@ -139,6 +144,7 @@ class FusedMoE(nn.Layer):
self.hidden_size = fd_config.model_config.hidden_size
self.num_experts = num_experts
self.num_local_experts = self.num_experts // self.ep_size
self.moe_intermediate_size = moe_intermediate_size // self.tp_size


@@ -69,6 +69,12 @@ class ErnieRotaryEmbedding:
.transpose([0, 1, 2, 4, 3])
.reshape([2, bsz, max_seq_len, 1, self.rotary_dim])
)
if paddle.is_compiled_with_custom_device("intel_hpu"):
return (
paddle.concat([rot_emb, rot_emb], axis=3)
.transpose([0, 1, 2, 4, 3])
.reshape([2, bsz, max_seq_len, 1, self.rotary_dim])
)
else:
return rot_emb


@@ -54,3 +54,6 @@ class SamplingMetadata:
temp_scaled_logprobs: Optional[paddle.Tensor] = None
top_p_normalized_logprobs: Optional[paddle.Tensor] = None
share_inputs: Optional[Dict[str, paddle.Tensor]] = None
# Add for HPU post-processing
seq_lens_encoder: Optional[paddle.Tensor] = None
seq_lens_decoder: Optional[paddle.Tensor] = None


@@ -136,6 +136,23 @@ def apply_penalty_multi_scores(
min_dec_lens,
eos_token_ids,
)
elif current_platform.is_intel_hpu():
from fastdeploy.model_executor.ops.intel_hpu import (
get_token_penalty_multi_scores,
)
logits = get_token_penalty_multi_scores(
pre_token_ids,
logits,
repetition_penalties,
frequency_penalties,
presence_penalties,
temperature,
bad_words_token_ids,
step_idx,
min_dec_lens,
eos_token_ids,
)
else:
raise NotImplementedError


@@ -209,6 +209,8 @@ class Sampler(nn.Layer):
or current_platform.is_maca()
):
self.forward = self.forward_cuda
elif current_platform.is_intel_hpu():
self.forward = self.forward_intel_hpu
else:
raise NotImplementedError
@@ -377,6 +379,49 @@ class Sampler(nn.Layer):
return sampler_output
def forward_intel_hpu(
self,
logits: paddle.Tensor,
sampling_metadata: SamplingMetadata,
batch_ids: paddle.Tensor,
max_batch: int,
rank: int,
local_rank: int,
) -> paddle.Tensor:
"""Sample next tokens on Intel HPU with the intel_hpu fused_sampler op."""
if logits.dtype != paddle.float32:
logits = paddle.cast(logits, paddle.float32)
from fastdeploy.model_executor.ops.intel_hpu import fused_sampler
_, next_tokens = fused_sampler(
sampling_metadata.pre_token_ids,
sampling_metadata.prompt_ids,
sampling_metadata.seq_lens_encoder,
sampling_metadata.seq_lens_decoder,
sampling_metadata.step_idx,
sampling_metadata.stop_flags,
logits,
sampling_metadata.repetition_penalties,
sampling_metadata.frequency_penalties,
sampling_metadata.presence_penalties,
sampling_metadata.temperature,
sampling_metadata.bad_words_token_ids,
sampling_metadata.step_idx,
sampling_metadata.min_dec_lens,
sampling_metadata.eos_token_ids,
sampling_metadata.top_p,
rank,
local_rank,
)
# Scatter sampled tokens of the active requests back into a full (max_batch) tensor padded with -1
if next_tokens.shape[0] != max_batch:
dim = next_tokens.shape[-1]
tmp_tokens = paddle.full((max_batch, dim), -1, dtype=next_tokens.dtype)
tmp_tokens = paddle.scatter(tmp_tokens, batch_ids, next_tokens[: batch_ids.shape[0], :])
return tmp_tokens
return next_tokens
class SpeculativeSampler(nn.Layer):
"""


@@ -12,6 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""fastdeploy module"""
from . import cpu, gcu, gpu, iluvatar, npu, xpu
from . import cpu, gcu, gpu, iluvatar, intel_hpu, npu, xpu
__all__ = ["gpu", "cpu", "xpu", "npu", "iluvatar", "gcu"]
__all__ = ["gpu", "cpu", "xpu", "npu", "iluvatar", "gcu", "intel_hpu"]


@@ -0,0 +1,21 @@
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""fastdeploy intel_hpu ops."""
from fastdeploy.import_ops import import_custom_ops
# PACKAGE = "fastdeploy.model_executor.ops.intel_hpu"
PACKAGE = "paddlenlp_ops"
import_custom_ops(PACKAGE, "paddlenlp_ops", globals())


@@ -56,6 +56,8 @@ elif current_platform.is_maca():
update_inputs,
update_inputs_v1,
)
elif current_platform.is_intel_hpu():
pass
else:
from fastdeploy.model_executor.ops.gpu import (
get_padding_offset,


@@ -284,6 +284,8 @@ class TokenProcessor:
from fastdeploy.model_executor.ops.iluvatar import get_output
elif current_platform.is_gcu():
from fastdeploy.model_executor.ops.gcu import get_output
elif current_platform.is_intel_hpu():
from fastdeploy.model_executor.ops.intel_hpu import get_output
else:
from fastdeploy.model_executor.ops.gpu import (
get_output,


@@ -23,6 +23,7 @@ from .cuda import CUDAPlatform
from .dcu import DCUPlatform
from .gcu import GCUPlatform
from .iluvatar import IluvatarPlatform
from .intel_hpu import INTEL_HPUPlatform
from .maca import MACAPlatform
from .npu import NPUPlatform
from .xpu import XPUPlatform
@@ -43,6 +44,8 @@ def __getattr__(name: str):
_current_platform = XPUPlatform()
elif paddle.is_compiled_with_custom_device("npu"):
_current_platform = NPUPlatform()
elif paddle.is_compiled_with_custom_device("intel_hpu"):
_current_platform = INTEL_HPUPlatform()
elif paddle.is_compiled_with_custom_device("iluvatar_gpu"):
_current_platform = IluvatarPlatform()
elif paddle.is_compiled_with_custom_device("gcu"):


@@ -27,6 +27,7 @@ class _Backend(enum.Enum):
FLASH_ATTN = enum.auto()
BLOCK_ATTN = enum.auto()
PLAS_ATTN = enum.auto()
HPU_ATTN = enum.auto()
class Platform:
@@ -54,6 +55,12 @@ class Platform:
"""
return paddle.is_compiled_with_xpu()
def is_intel_hpu(self) -> bool:
"""
whether platform is intel_hpu
"""
return paddle.is_compiled_with_custom_device("intel_hpu")
def is_cpu(self) -> bool:
"""
whether platform is cpu


@@ -0,0 +1,52 @@
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import paddle
from fastdeploy.utils import console_logger as logger
from .base import Platform, _Backend
class INTEL_HPUPlatform(Platform):
device_name = "intel_hpu"
@classmethod
def available(self):
"""
Check whether Intel HPU is available.
"""
try:
assert paddle.base.core.get_custom_device_count("intel_hpu") > 0
return True
except Exception as e:
logger.warning(
"You are using Intel HPU platform, but there is no Intel HPU "
"detected on your machine. Maybe Intel HPU devices is not set properly."
f"\n Original Error is {e}"
)
return False
@classmethod
def get_attention_backend_cls(cls, selected_backend):
"""
get_attention_backend_cls
"""
if selected_backend == _Backend.NATIVE_ATTN:
logger.info("Using NATIVE ATTN backend.")
return "fastdeploy.model_executor.layers.attention.PaddleNativeAttnBackend"
elif selected_backend == _Backend.HPU_ATTN:
logger.info("Using HPU ATTN backend.")
return "fastdeploy.model_executor.layers.backends.intel_hpu.attention.HPUAttentionBackend"
else:
logger.warning("Other backends are not supported for now.")
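As a quick illustration of how the `FD_ATTENTION_BACKEND=HPU_ATTN` setting from the Gaudi guide flows through this hook (a sketch; the module paths follow the relative imports shown above and assume FastDeploy is installed with the intel_hpu plugin):
```python
# Resolve the attention backend class path for the HPU_ATTN selection.
from fastdeploy.platforms.base import _Backend
from fastdeploy.platforms.intel_hpu import INTEL_HPUPlatform

cls_path = INTEL_HPUPlatform.get_attention_backend_cls(_Backend.HPU_ATTN)
print(cls_path)
# -> fastdeploy.model_executor.layers.backends.intel_hpu.attention.HPUAttentionBackend
```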

File diff suppressed because it is too large.


@@ -0,0 +1,213 @@
"""
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
import gc
import os
import time
from typing import List, Optional
import paddle
import paddle.nn as nn
from paddle.base import core
from fastdeploy.config import FDConfig
from fastdeploy.engine.request import Request
from fastdeploy.utils import get_logger, set_random_seed
from fastdeploy.worker.hpu_model_runner import HPUModelRunner
from fastdeploy.worker.output import ModelRunnerOutput
from fastdeploy.worker.worker_base import WorkerBase
logger = get_logger("hpu_worker", "hpu_worker.log")
def max_memory_allocated(device_id: int) -> int:
return core.device_memory_stat_peak_value("Allocated", device_id)
def max_memory_reserved(device_id: int) -> int:
return core.device_memory_stat_peak_value("Reserved", device_id)
def reset_max_memory_allocated(device_id: int) -> None:
core.device_memory_stat_reset_peak_value("Allocated", device_id)
def reset_max_memory_reserved(device_id: int) -> None:
core.device_memory_stat_reset_peak_value("Reserved", device_id)
class HpuWorker(WorkerBase):
def __init__(
self,
fd_config: FDConfig,
local_rank: int,
rank: int,
):
super().__init__(
fd_config=fd_config,
local_rank=local_rank,
rank=rank,
)
pass
def init_device(self):
"""
Initialize device and construct model runner
"""
if paddle.is_compiled_with_custom_device("intel_hpu"):
# Set environment variable
self.device_ids = self.parallel_config.device_ids.split(",")
logger.info(
f"Using Intel HPU device with local rank => device id: {int(self.device_ids[self.local_rank])} as module id"
)
intel_hpus_module_id = int(self.device_ids[self.local_rank])
self.device = f"intel_hpu:{intel_hpus_module_id}"
paddle.device.set_device(self.device)
paddle.set_default_dtype(self.parallel_config.dtype)
gc.collect()
paddle.device.cuda.empty_cache()
else:
raise RuntimeError(f"Not support device type: {self.device_config.device}")
set_random_seed(self.fd_config.model_config.seed)
# Construct model runner
self.model_runner: HPUModelRunner = HPUModelRunner(
fd_config=self.fd_config,
device=self.device,
device_id=self.device_ids[self.local_rank],
rank=self.rank,
local_rank=self.local_rank,
)
def exist_prefill(self):
"""
check whether prefill stage exist
"""
return self.model_runner.exist_prefill()
def determine_available_memory(self) -> int:
"""
Profiles the peak memory usage of the model to determine how much
memory can be used for KV cache without OOMs.
The engine will first conduct a profiling of the existing memory usage.
Then, it calculates the maximum possible number of GPU and CPU blocks
that can be allocated with the remaining free memory.
Tip:
You may limit the usage of GPU memory
by adjusting the `gpu_memory_utilization` parameter.
"""
# 1. Record memory state before profile run
start_time = time.perf_counter()
module_id = int(self.device_ids[self.local_rank])
reset_max_memory_allocated(module_id)
reset_max_memory_reserved(module_id)
paddle_reserved_mem_before_run = max_memory_reserved(module_id)
paddle_allocated_mem_before_run = max_memory_allocated(module_id) # not reserved
logger.info(
(
"Before running the profile, the memory usage info is as follows:",
f"\nPaddle reserved memory: {paddle_reserved_mem_before_run}",
f"\nPaddle allocated memory: {paddle_allocated_mem_before_run}",
)
)
# 2. Profile run
self.model_runner.profile_run()
# 3. Statistical memory information
paddle_reserved_mem_after_run = max_memory_reserved(module_id)
paddle_allocated_mem_after_run = max_memory_allocated(module_id)
one_mb = 1024 * 1024
one_gb = 1024 * one_mb
hpu_reserved_memory = 768 * one_mb # 768 MB reserved for memory not tracked by Paddle
hpu_total_memory = 96 * one_gb # 96GB HPU memory
peak_memory = paddle_allocated_mem_after_run + hpu_reserved_memory
available_kv_cache_memory = hpu_total_memory * self.cache_config.gpu_memory_utilization - peak_memory
end_time = time.perf_counter()
logger.info(
(
"After running the profile, the memory usage info is as follows:",
f"\nPaddle reserved memory: {paddle_reserved_mem_after_run}",
f"\nPaddle allocated memory: {paddle_allocated_mem_after_run}",
f"\nAvailable KV Cache meomory: {available_kv_cache_memory}",
f"Profile time: {end_time - start_time}",
)
)
return available_kv_cache_memory # used by the caller to calculate the KV cache block num on this device
def load_model(self) -> None:
"""Load model"""
self.model_runner.load_model()
def get_model(self) -> nn.Layer:
"""Get current model"""
return self.model_runner.get_model()
def initialize_cache(self, num_gpu_blocks: int) -> None:
"""Initialize the KV Cache with accurate num_gpu_blocks"""
# accurate cache size
self.model_runner.update_share_input_block_num(num_gpu_blocks=num_gpu_blocks)
def execute_model(
self,
model_forward_batch: Optional[List[Request]] = None,
num_running_request: int = None,
) -> Optional[ModelRunnerOutput]:
""" """
output = self.model_runner.execute_model(model_forward_batch)
return output
def preprocess_new_task(self, req_dicts: List[Request], num_running_requests: int) -> None:
"""Process new requests and then start the decode loop
TODO(gongshaotian):The scheduler should schedule the handling of prefill,
and workers and modelrunners should not perceive it.
"""
self.model_runner.insert_prefill_inputs(req_dicts=req_dicts, num_running_requests=num_running_requests)
def graph_optimize_and_warm_up_model(self) -> None:
"""
Perform the warm-up and the graph optimization
"""
# wait for all cards loading model completely.
if self.rank > 1:
paddle.distributed.barrier()
# 1. Warm up model
# NOTE(gongshaotian): warm-up may not be needed at this place
if int(os.environ.get("HPU_WARMUP_BUCKET", 0)) == 1:
logger.info("Warmup bucket is enabled, start warmup bucket")
self.model_runner.is_warmuping = True
self.model_runner.warm_up_bucket()
self.model_runner.is_warmuping = False
else:
logger.info("Skipping warmup bucket, please set HPU_WARMUP_BUCKET=1 to enable it.")
# 2. Trigger CUDA graph capture
self.model_runner.capture_model()
def check_health(self) -> bool:
""" """
return True
def cal_theortical_kvcache(self) -> int:
"""Calculate the block memory required"""
return self.model_runner.cal_theortical_kvcache()
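To make the memory bookkeeping in `determine_available_memory` concrete, here is the same arithmetic with stand-in numbers (the peak allocation and `gpu_memory_utilization` values below are invented for illustration):
```python
# Stand-in numbers; only the formula mirrors determine_available_memory().
one_mb = 1024 * 1024
one_gb = 1024 * one_mb

hpu_total_memory = 96 * one_gb        # total HPU memory assumed by the worker
hpu_reserved_memory = 768 * one_mb    # headroom for memory not tracked by Paddle
gpu_memory_utilization = 0.9          # example cache_config.gpu_memory_utilization
peak_allocated = 30 * one_gb          # example peak allocation seen during profile_run()

peak_memory = peak_allocated + hpu_reserved_memory
available_kv_cache_memory = hpu_total_memory * gpu_memory_utilization - peak_memory
print(f"KV cache budget: {available_kv_cache_memory / one_gb:.2f} GB")  # about 55.65 GB here
```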


@@ -82,6 +82,10 @@ def get_worker(fd_config: FDConfig, local_rank: int, rank: int) -> WorkerBase:
from fastdeploy.worker.metax_worker import MetaxWorker
return MetaxWorker(fd_config=fd_config, local_rank=local_rank, rank=rank)
if current_platform.is_intel_hpu():
from fastdeploy.worker.hpu_worker import HpuWorker
return HpuWorker(fd_config=fd_config, local_rank=local_rank, rank=rank)
def init_distributed_environment(seed: int = 20) -> Tuple[int, int]:
@@ -89,21 +93,22 @@ def init_distributed_environment(seed: int = 20) -> Tuple[int, int]:
# Global rank
ranks = dist.get_world_size()
dist_strategy = fleet.DistributedStrategy()
if ranks > 0:
dist_strategy.hybrid_configs = {
"dp_degree": 1,
"mp_degree": ranks,
"pp_degree": 1,
"sharding_degree": 1,
}
dist_strategy.hybrid_configs = {
"dp_degree": 1,
"mp_degree": ranks,
"pp_degree": 1,
"sharding_degree": 1,
}
# Set control in tensor parallel
dist_strategy.tensor_parallel_configs = {"tensor_init_seed": seed}
fleet.init(is_collective=True, strategy=dist_strategy)
# Local rank
local_rank = fleet.worker_index()
# Set control in tensor parallel
dist_strategy.tensor_parallel_configs = {"tensor_init_seed": seed}
fleet.init(is_collective=True, strategy=dist_strategy)
# Local rank
local_rank = fleet.worker_index()
else:
local_rank = 0
return ranks, local_rank


@@ -174,6 +174,8 @@ def get_device_type():
return "gcu"
elif paddle.device.is_compiled_with_custom_device("metax_gpu"):
return "metax-gpu"
elif paddle.is_compiled_with_custom_device("intel_hpu"):
return "intel-hpu"
else:
return "cpu"