From f1b5392e2032d21518c5567edd55c5735ab9bfdd Mon Sep 17 00:00:00 2001 From: fmiao2372 Date: Wed, 24 Sep 2025 12:27:50 +0800 Subject: [PATCH] [Intel HPU] Support intel hpu platform (#4161) * [Intel HPU] Support intel hpu platform * fix some issues * apply precommit and move AttentionBackend_HPU * fix format issue * correct ops import * fix ci issue * update code in layers * fix code style issue * remove dense tp moe ep mode * fix enc_dec_block_num * fix rebase issue * rename hpu to gaudi in readme * rename ForwardMeta_HPU to HPUForwardMeta --- README.md | 3 +- README_CN.md | 3 +- build.sh | 10 +- custom_ops/setup_ops.py | 2 + docs/get_started/installation/README.md | 1 + docs/get_started/installation/intel_gaudi.md | 75 + docs/zh/get_started/installation/README.md | 1 + .../get_started/installation/intel_gaudi.md | 75 + fastdeploy/config.py | 2 + fastdeploy/distributed/communication.py | 23 + fastdeploy/model_executor/forward_meta.py | 117 +- .../model_executor/layers/activation.py | 12 + .../layers/backends/__init__.py | 7 + .../layers/backends/intel_hpu/__init__.py | 26 + .../backends/intel_hpu/attention/__init__.py | 19 + .../intel_hpu/attention/hpu_attn_backend.py | 314 ++++ .../layers/backends/intel_hpu/moe/__init__.py | 16 + .../intel_hpu/moe/fused_moe_hpu_backend.py | 249 +++ fastdeploy/model_executor/layers/linear.py | 1 + fastdeploy/model_executor/layers/moe/moe.py | 6 + .../model_executor/layers/rotary_embedding.py | 6 + .../model_executor/layers/sample/meta_data.py | 3 + .../sample/ops/apply_penalty_multi_scores.py | 17 + .../model_executor/layers/sample/sampler.py | 45 + fastdeploy/model_executor/ops/__init__.py | 4 +- .../model_executor/ops/intel_hpu/__init__.py | 21 + .../model_executor/pre_and_post_process.py | 2 + fastdeploy/output/token_processor.py | 2 + fastdeploy/platforms/__init__.py | 3 + fastdeploy/platforms/base.py | 7 + fastdeploy/platforms/intel_hpu.py | 52 + fastdeploy/worker/hpu_model_runner.py | 1463 +++++++++++++++++ fastdeploy/worker/hpu_worker.py | 213 +++ fastdeploy/worker/worker_process.py | 31 +- setup.py | 2 + 35 files changed, 2814 insertions(+), 19 deletions(-) create mode 100644 docs/get_started/installation/intel_gaudi.md create mode 100644 docs/zh/get_started/installation/intel_gaudi.md create mode 100644 fastdeploy/model_executor/layers/backends/intel_hpu/__init__.py create mode 100644 fastdeploy/model_executor/layers/backends/intel_hpu/attention/__init__.py create mode 100644 fastdeploy/model_executor/layers/backends/intel_hpu/attention/hpu_attn_backend.py create mode 100644 fastdeploy/model_executor/layers/backends/intel_hpu/moe/__init__.py create mode 100644 fastdeploy/model_executor/layers/backends/intel_hpu/moe/fused_moe_hpu_backend.py create mode 100644 fastdeploy/model_executor/ops/intel_hpu/__init__.py create mode 100644 fastdeploy/platforms/intel_hpu.py create mode 100644 fastdeploy/worker/hpu_model_runner.py create mode 100644 fastdeploy/worker/hpu_worker.py diff --git a/README.md b/README.md index c6d1ccaad..dae9dca6b 100644 --- a/README.md +++ b/README.md @@ -43,7 +43,7 @@ English | [简体中文](README_CN.md) - 🤝 **OpenAI API Server and vLLM Compatible**: One-command deployment with [vLLM](https://github.com/vllm-project/vllm/) interface compatibility. - 🧮 **Comprehensive Quantization Format Support**: W8A16, W8A8, W4A16, W4A8, W2A16, FP8, and more. - ⏩ **Advanced Acceleration Techniques**: Speculative decoding, Multi-Token Prediction (MTP) and Chunked Prefill. 
-- 🖥️ **Multi-Hardware Support**: NVIDIA GPU, Kunlunxin XPU, Hygon DCU, Ascend NPU, Iluvatar GPU, Enflame GCU, MetaX GPU etc. +- 🖥️ **Multi-Hardware Support**: NVIDIA GPU, Kunlunxin XPU, Hygon DCU, Ascend NPU, Iluvatar GPU, Enflame GCU, MetaX GPU, Intel Gaudi etc. ## Requirements @@ -60,6 +60,7 @@ FastDeploy supports inference deployment on **NVIDIA GPUs**, **Kunlunxin XPUs**, - [Enflame GCU](./docs/get_started/installation/Enflame_gcu.md) - [Hygon DCU](./docs/get_started/installation/hygon_dcu.md) - [MetaX GPU](./docs/get_started/installation/metax_gpu.md) +- [Intel Gaudi](./docs/get_started/installation/intel_gaudi.md) **Note:** We are actively working on expanding hardware support. Additional hardware platforms including Ascend NPU are currently under development and testing. Stay tuned for updates! diff --git a/README_CN.md b/README_CN.md index 607acd8c6..0f6460fa5 100644 --- a/README_CN.md +++ b/README_CN.md @@ -41,7 +41,7 @@ - 🤝 **OpenAI API服务与vLLM兼容**:单命令部署,兼容[vLLM](https://github.com/vllm-project/vllm/)接口 - 🧮 **全量化格式支持**:W8A16、W8A8、W4A16、W4A8、W2A16、FP8等 - ⏩ **高级加速技术**:推测解码、多令牌预测(MTP)及分块预填充 -- 🖥️ **多硬件支持**:NVIDIA GPU、昆仑芯XPU、海光DCU、昇腾NPU、天数智芯GPU、燧原GCU、沐曦GPU等 +- 🖥️ **多硬件支持**:NVIDIA GPU、昆仑芯XPU、海光DCU、昇腾NPU、天数智芯GPU、燧原GCU、沐曦GPU、英特尔Gaudi等 ## 要求 @@ -58,6 +58,7 @@ FastDeploy 支持在**英伟达(NVIDIA)GPU**、**昆仑芯(Kunlunxin)XPU - [燧原 S60](./docs/zh/get_started/installation/Enflame_gcu.md) - [海光 DCU](./docs/zh/get_started/installation/hygon_dcu.md) - [沐曦 GPU](./docs/zh/get_started/installation/metax_gpu.md) +- [英特尔 Gaudi](./docs/zh/get_started/installation/intel_gaudi.md) **注意:** 我们正在积极拓展硬件支持范围。目前,包括昇腾(Ascend)NPU 等其他硬件平台正在开发测试中。敬请关注更新! diff --git a/build.sh b/build.sh index d8b27d03b..0596d8f99 100644 --- a/build.sh +++ b/build.sh @@ -128,6 +128,12 @@ function copy_ops(){ echo -e "MACA ops have been copy to fastdeploy" return fi + is_intel_hpu=`$python -c "import paddle; print(paddle.is_compiled_with_custom_device('intel_hpu'))"` + if [ "$is_intel_hpu" = "True" ]; then + DEVICE_TYPE="intel-hpu" + echo -e "intel_hpu ops have been copy to fastdeploy" + return + fi DEVICE_TYPE="cpu" cd ../../../../ @@ -159,7 +165,9 @@ function build_and_install_ops() { else FD_BUILDING_ARCS=${FD_BUILDING_ARCS} ${python} setup_ops.py install --install-lib ${OPS_TMP_DIR} fi - find ${OPS_TMP_DIR} -type f -name "*.o" -exec rm -f {} \; + if [ -d "${OPS_TMP_DIR}" ]; then + find ${OPS_TMP_DIR} -type f -name "*.o" -exec rm -f {} \; + fi else echo "Error: Invalid parameter '$FD_CPU_USE_BF16'. Please use true or false." 
exit 1 diff --git a/custom_ops/setup_ops.py b/custom_ops/setup_ops.py index 4fa6316c5..3ca8c3c3f 100644 --- a/custom_ops/setup_ops.py +++ b/custom_ops/setup_ops.py @@ -623,6 +623,8 @@ elif paddle.device.is_compiled_with_custom_device("metax_gpu"): ], ), ) +elif paddle.is_compiled_with_custom_device("intel_hpu"): + pass else: use_bf16 = envs.FD_CPU_USE_BF16 == "True" diff --git a/docs/get_started/installation/README.md b/docs/get_started/installation/README.md index ba7042e26..76dba9d00 100644 --- a/docs/get_started/installation/README.md +++ b/docs/get_started/installation/README.md @@ -7,3 +7,4 @@ FastDeploy currently supports installation on the following hardware platforms: - [Enflame S60 GCU Installation](Enflame_gcu.md) - [Iluvatar GPU Installation](iluvatar_gpu.md) - [Hygon DCU Installation](hygon_dcu.md) +- [Intel Gaudi Installation](intel_gaudi.md) diff --git a/docs/get_started/installation/intel_gaudi.md b/docs/get_started/installation/intel_gaudi.md new file mode 100644 index 000000000..93c5504fd --- /dev/null +++ b/docs/get_started/installation/intel_gaudi.md @@ -0,0 +1,75 @@ +# Intel Gaudi Installation for running ERNIE 4.5 Series Models + +The following installation methods are available when your environment meets these requirements: + +- Python 3.10 +- Intel Gaudi 2 +- Intel Gaudi software version 1.22.0 +- Linux X86_64 + +## 1. Run Docker Container + +Use the following commands to run a Docker container. Make sure to update the versions below as listed in the [Support Matrix](https://docs.habana.ai/en/latest/Support_Matrix/Support_Matrix.html): + +```{.console} +$ docker pull vault.habana.ai/gaudi-docker/1.22.0/ubuntu22.04/habanalabs/pytorch-installer-2.7.1:latest +$ docker run -it --runtime=habana -e HABANA_VISIBLE_DEVICES=all -e OMPI_MCA_btl_vader_single_copy_mechanism=none --cap-add=sys_nice --net=host --ipc=host vault.habana.ai/gaudi-docker/1.22.0/ubuntu22.04/habanalabs/pytorch-installer-2.7.1:latest +``` + +### 2. Install PaddlePaddle + +```bash +python -m pip install paddlepaddle==3.1.1 -i https://www.paddlepaddle.org.cn/packages/stable/cpu/ +``` + +### 3. Install PaddleCustomDevice +```shell +git clone https://github.com/PaddlePaddle/PaddleCustomDevice +cd PaddleCustomDevice/backends/intel_hpu/ +mkdir -p build +cd build +cmake .. +make -j +pip install --force-reinstall dist/paddle_intel_hpu*.whl +cd PaddleCustomDevice/backends/intel_hpu/custom_ops +python setup.py install +``` + +### 4. Install FastDeploy + +```shell +git clone https://github.com/PaddlePaddle/FastDeploy +cd FastDeploy +bash build.sh +``` + +## Prepare the inference demo + +### 1. Start inference service +```shell +export GC_KERNEL_PATH=/usr/lib/habanalabs/libtpc_kernels.so +export GC_KERNEL_PATH=/usr/local/lib/python3.10/dist-packages/paddle_custom_device/intel_hpu/libcustom_tpc_perf_lib.so:$GC_KERNEL_PATH +export INTEL_HPU_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 +export PADDLE_DISTRI_BACKEND=xccl +export PADDLE_XCCL_BACKEND=intel_hpu +export HABANA_PROFILE=0 +export HPU_VISIBLE_DEVICES=0 + +HPU_WARMUP_BUCKET=1 HPU_WARMUP_MODEL_LEN=4096 FD_ATTENTION_BACKEND=HPU_ATTN python -m fastdeploy.entrypoints.openai.api_server --model ERNIE-4.5-21B-A3B-Paddle --tensor-parallel-size 1 --max-model-len 32768 --max-num-seqs 128 +``` + +### 2. Launch the request +```bash +curl -X POST "http://0.0.0.0:8188/v1/chat/completions" \ +-H "Content-Type: application/json" \ +-d '{ + "messages": [ + {"role": "user", "content": "What is AI?"} + ], "max_tokens": 24 +}' +``` + +### 3. 
Example of a successful response
+```json
+{"id":"chatcmpl-3bd98ae2-fafe-46ae-a552-d653a8526503","object":"chat.completion","created":1757653575,"model":"ERNIE-4.5-21B-A3B-Paddle","choices":[{"index":0,"message":{"role":"assistant","content":"**AI (Artificial Intelligence)** refers to the development of computer systems that can perform tasks typically requiring human intelligence.","multimodal_content":null,"reasoning_content":null,"tool_calls":null,"prompt_token_ids":null,"completion_token_ids":null,"text_after_process":null,"raw_prediction":null,"prompt_tokens":null,"completion_tokens":null},"logprobs":null,"finish_reason":"length"}],"usage":{"prompt_tokens":11,"total_tokens":35,"completion_tokens":24,"prompt_tokens_details":{"cached_tokens":0}}}
+```
diff --git a/docs/zh/get_started/installation/README.md b/docs/zh/get_started/installation/README.md
index 68fdbec52..4c1b6016d 100644
--- a/docs/zh/get_started/installation/README.md
+++ b/docs/zh/get_started/installation/README.md
@@ -7,3 +7,4 @@ FastDeploy支持如下硬件平台:
 - [Enflame S60 GCU Installation](Enflame_gcu.md)
 - [Iluvatar GPU Installation](iluvatar_gpu.md)
 - [Hygon DCU Installation](hygon_dcu.md)
+- [Intel Gaudi Installation](intel_gaudi.md)
diff --git a/docs/zh/get_started/installation/intel_gaudi.md b/docs/zh/get_started/installation/intel_gaudi.md
new file mode 100644
index 000000000..e8b46aa0b
--- /dev/null
+++ b/docs/zh/get_started/installation/intel_gaudi.md
@@ -0,0 +1,75 @@
+# 使用 Intel Gaudi 运行 ERNIE 4.5 系列模型
+
+在环境满足如下条件的前提下:
+
+- Python 3.10
+- Intel Gaudi 2
+- Intel Gaudi software version 1.22.0
+- Linux X86_64
+
+## 1. 运行 Docker 容器
+
+使用以下命令运行 Docker 容器,并确保所用版本在 [Support Matrix](https://docs.habana.ai/en/latest/Support_Matrix/Support_Matrix.html) 支持列表中:
+
+```{.console}
+$ docker pull vault.habana.ai/gaudi-docker/1.22.0/ubuntu22.04/habanalabs/pytorch-installer-2.7.1:latest
+$ docker run -it --runtime=habana -e HABANA_VISIBLE_DEVICES=all -e OMPI_MCA_btl_vader_single_copy_mechanism=none --cap-add=sys_nice --net=host --ipc=host vault.habana.ai/gaudi-docker/1.22.0/ubuntu22.04/habanalabs/pytorch-installer-2.7.1:latest
+```
+
+### 2. 安装 PaddlePaddle
+
+```bash
+python -m pip install paddlepaddle==3.1.1 -i https://www.paddlepaddle.org.cn/packages/stable/cpu/
+```
+
+### 3. 安装 PaddleCustomDevice
+```shell
+git clone https://github.com/PaddlePaddle/PaddleCustomDevice
+cd PaddleCustomDevice/backends/intel_hpu/
+mkdir -p build
+cd build
+cmake ..
+make -j
+pip install --force-reinstall dist/paddle_intel_hpu*.whl
+cd PaddleCustomDevice/backends/intel_hpu/custom_ops
+python setup.py install
+```
+
+### 4. 安装 FastDeploy
+
+```shell
+git clone https://github.com/PaddlePaddle/FastDeploy
+cd FastDeploy
+bash build.sh
+```
+
+## 准备推理示例
+
+### 1. 启动推理服务
+```shell
+export GC_KERNEL_PATH=/usr/lib/habanalabs/libtpc_kernels.so
+export GC_KERNEL_PATH=/usr/local/lib/python3.10/dist-packages/paddle_custom_device/intel_hpu/libcustom_tpc_perf_lib.so:$GC_KERNEL_PATH
+export INTEL_HPU_VISIBLE_DEVICES=0,1,2,3,4,5,6,7
+export PADDLE_DISTRI_BACKEND=xccl
+export PADDLE_XCCL_BACKEND=intel_hpu
+export HABANA_PROFILE=0
+export HPU_VISIBLE_DEVICES=0
+
+HPU_WARMUP_BUCKET=1 HPU_WARMUP_MODEL_LEN=4096 FD_ATTENTION_BACKEND=HPU_ATTN python -m fastdeploy.entrypoints.openai.api_server --model ERNIE-4.5-21B-A3B-Paddle --tensor-parallel-size 1 --max-model-len 32768 --max-num-seqs 128
+```
+
+### 2.
发送请求 +```bash +curl -X POST "http://0.0.0.0:8188/v1/chat/completions" \ +-H "Content-Type: application/json" \ +-d '{ + "messages": [ + {"role": "user", "content": "What is AI?"} + ], "max_tokens": 24 +}' +``` + +### 3. 成功返回结果 +```json +{"id":"chatcmpl-3bd98ae2-fafe-46ae-a552-d653a8526503","object":"chat.completion","created":1757653575,"model":"ERNIE-4.5-21B-A3B-Paddle","choices":[{"index":0,"message":{"role":"assistant","content":"**AI (Artificial Intelligence)** refers to the development of computer systems that can perform tasks typically requiring human intelligence.","multimodal_content":null,"reasoning_content":null,"tool_calls":null,"prompt_token_ids":null,"completion_token_ids":null,"text_after_process":null,"raw_prediction":null,"prompt_tokens":null,"completion_tokens":null},"logprobs":null,"finish_reason":"length"}],"usage":{"prompt_tokens":11,"total_tokens":35,"completion_tokens":24,"prompt_tokens_details":{"cached_tokens":0}}} +``` diff --git a/fastdeploy/config.py b/fastdeploy/config.py index e9df1c52e..c1953577d 100644 --- a/fastdeploy/config.py +++ b/fastdeploy/config.py @@ -1498,6 +1498,8 @@ class FDConfig: self.device_ids = os.getenv("CUDA_VISIBLE_DEVICES", self.device_ids) if current_platform.is_xpu(): self.device_ids = os.getenv("XPU_VISIBLE_DEVICES", self.device_ids) + if current_platform.is_intel_hpu(): + self.device_ids = os.getenv("HPU_VISIBLE_DEVICES", self.device_ids) self.read_from_config() self.postprocess() diff --git a/fastdeploy/distributed/communication.py b/fastdeploy/distributed/communication.py index bcc45f140..52f2fbddb 100644 --- a/fastdeploy/distributed/communication.py +++ b/fastdeploy/distributed/communication.py @@ -66,3 +66,26 @@ try: except: tensor_model_parallel_all_reduce = None + +from paddle.distributed.communication import stream +from paddle.distributed.communication.reduce import ReduceOp + + +def all_reduce( + tensor, + op, + group, + sync_op: bool = True, +): + return stream.all_reduce(tensor, op=op, group=group, sync_op=sync_op, use_calc_stream=True) + + +@paddle.jit.marker.unified +def tensor_model_parallel_all_reduce_custom(input_: paddle.Tensor) -> paddle.Tensor: + """All-reduce the input tensor across model parallel group on calc stream.""" + if paddle.in_dynamic_mode(): + hcg = dist.fleet.get_hybrid_communicate_group() + mp_group = hcg.get_model_parallel_group() + all_reduce(input_, op=ReduceOp.SUM, group=mp_group) + else: + dist.all_reduce(input_) diff --git a/fastdeploy/model_executor/forward_meta.py b/fastdeploy/model_executor/forward_meta.py index f0888302d..2d812b4e9 100644 --- a/fastdeploy/model_executor/forward_meta.py +++ b/fastdeploy/model_executor/forward_meta.py @@ -17,12 +17,14 @@ import logging from dataclasses import dataclass from enum import IntEnum, auto -from typing import Optional +from typing import TYPE_CHECKING, Dict, Optional import paddle from fastdeploy.model_executor.layers.attention import AttentionBackend +if TYPE_CHECKING: + from fastdeploy.model_executor.layers.attention import AttentionBackend_HPU logger = logging.getLogger(__name__) @@ -240,3 +242,116 @@ class DCUForwardMeta(ForwardMeta): # Accumulated offset cum_offsets: Optional[paddle.Tensor] = None + + +@dataclass +class HPUForwardMeta: + """ + HPUForwardMeta is used to store the global meta information of the forward on intel HPU. 
+ """ + + # + input_ids: paddle.Tensor + + # attention meta + forward_mode: ForwardMode = ForwardMode.MIXED + + # + ids_remove_padding: paddle.Tensor = None + + # + seq_lens_encoder: Optional[paddle.Tensor] = None + + # + seq_lens_decoder: Optional[paddle.Tensor] = None + + # + seq_lens_this_time: Optional[paddle.Tensor] = None + + # + cum_offsets: Optional[paddle.Tensor] = None + + # + block_tables: Optional[paddle.Tensor] = None + + # + block_groups: Optional[paddle.Tensor] = None + + # + block_list: Optional[paddle.Tensor] = None + + # + block_indices: Optional[paddle.Tensor] = None + + # + block_offsets: Optional[paddle.Tensor] = None + + # + block_mapping: Optional[paddle.Tensor] = None + + # + attention_mask: Optional[paddle.Tensor] = None + + # + block_size: Optional[paddle.Tensor] = None + + # + batch_ids: Optional[paddle.Tensor] = None + + # + total_batch: Optional[paddle.Tensor] = None + + # + is_prompt: Optional[paddle.Tensor] = None + + # + attn_backend: "AttentionBackend_HPU" = None + + # + rotary_embs: Optional[paddle.Tensor] = None + + # + caches: Optional[paddle.Tensor] = None + + # + attn_mask: Optional[paddle.Tensor] = None + + # + pre_caches_length: int = 0 + + @classmethod + def init_forward_meta(cls, share_inputs: Dict, attn_backend: "AttentionBackend_HPU"): + """init forward meta""" + # TODO(gongshaotian): delete this func + is_prompt = share_inputs["is_prompt"] + forward_mode = ForwardMode.DECODE + if is_prompt: + forward_mode = ForwardMode.EXTEND + ret = cls( + forward_mode=forward_mode, + input_ids=share_inputs["input_ids"], + ids_remove_padding=share_inputs["ids_remove_padding"], + seq_lens_encoder=share_inputs["seq_lens_encoder"], + seq_lens_decoder=share_inputs["seq_lens_decoder"], + seq_lens_this_time=share_inputs["seq_lens_this_time"], + block_tables=share_inputs["block_tables"], + block_groups=share_inputs["block_groups"], + block_list=share_inputs["block_list"], + block_indices=share_inputs["block_indices"], + block_offsets=share_inputs["block_offsets"], + block_mapping=share_inputs["block_mapping"], + attention_mask=share_inputs["block_bias"], + block_size=share_inputs["block_size"], + total_batch=share_inputs["total_batch"], + batch_ids=share_inputs["batch_ids"], + is_prompt=share_inputs["is_prompt"], + attn_backend=attn_backend, + rotary_embs=share_inputs["rotary_embs"], + caches=share_inputs["caches"], + ) + return ret + + def clear_caches(self): + """safe clear caches""" + if self.caches: + del self.caches diff --git a/fastdeploy/model_executor/layers/activation.py b/fastdeploy/model_executor/layers/activation.py index 2eb800de6..79fd3b24f 100644 --- a/fastdeploy/model_executor/layers/activation.py +++ b/fastdeploy/model_executor/layers/activation.py @@ -72,6 +72,8 @@ class SiluAndMul(nn.Layer): self.forward = self.forward_cuda elif current_platform.is_gcu(): self.forward = self.forward_gcu + elif current_platform.is_intel_hpu(): + self.forward = self.forward_intel_hpu else: raise NotImplementedError @@ -147,6 +149,16 @@ class SiluAndMul(nn.Layer): out = out + self.bias return out + def forward_intel_hpu(self, x): + """ + Forward propagation of the custom activation layer. + Args: + x (Tensor): Input tensor to the activation layer. + Returns: + Tensor: Output tensor. 
+ """ + return + def get_act_fn(act_fn_name: str) -> nn.Layer: """Get an activation function by name.""" diff --git a/fastdeploy/model_executor/layers/backends/__init__.py b/fastdeploy/model_executor/layers/backends/__init__.py index ddbe410d1..faa4e66f7 100644 --- a/fastdeploy/model_executor/layers/backends/__init__.py +++ b/fastdeploy/model_executor/layers/backends/__init__.py @@ -55,3 +55,10 @@ if current_platform.is_maca(): if hasattr(metax, "__all__"): globals().update({name: getattr(metax, name) for name in metax.__all__}) __all__.extend(metax.__all__) + +if current_platform.is_intel_hpu(): + from . import intel_hpu + + if hasattr(intel_hpu, "__all__"): + globals().update({name: getattr(intel_hpu, name) for name in intel_hpu.__all__}) + __all__.extend(intel_hpu.__all__) diff --git a/fastdeploy/model_executor/layers/backends/intel_hpu/__init__.py b/fastdeploy/model_executor/layers/backends/intel_hpu/__init__.py new file mode 100644 index 000000000..e3bd54727 --- /dev/null +++ b/fastdeploy/model_executor/layers/backends/intel_hpu/__init__.py @@ -0,0 +1,26 @@ +# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +intel_hpu backend methods +""" + +from .attention.hpu_attn_backend import HPUAttentionBackend +from .moe.fused_moe_hpu_backend import HpuMoEMethod, HpuTensorWiseFP8MoEMethod + +__all__ = [ + "HPUAttentionBackend", + "HpuMoEMethod", + "HpuTensorWiseFP8MoEMethod", +] diff --git a/fastdeploy/model_executor/layers/backends/intel_hpu/attention/__init__.py b/fastdeploy/model_executor/layers/backends/intel_hpu/attention/__init__.py new file mode 100644 index 000000000..01de18669 --- /dev/null +++ b/fastdeploy/model_executor/layers/backends/intel_hpu/attention/__init__.py @@ -0,0 +1,19 @@ +# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from .hpu_attn_backend import HPUAttentionBackend + +__all__ = [ + "HPUAttentionBackend", +] diff --git a/fastdeploy/model_executor/layers/backends/intel_hpu/attention/hpu_attn_backend.py b/fastdeploy/model_executor/layers/backends/intel_hpu/attention/hpu_attn_backend.py new file mode 100644 index 000000000..962b6e113 --- /dev/null +++ b/fastdeploy/model_executor/layers/backends/intel_hpu/attention/hpu_attn_backend.py @@ -0,0 +1,314 @@ +""" +# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" + +from __future__ import annotations + +import os +from abc import abstractmethod +from dataclasses import dataclass, field +from typing import TYPE_CHECKING, List, Optional + +import paddle + +if TYPE_CHECKING: + from paddle._typing.dtype_like import _DTypeLiteral + +from fastdeploy.config import FDConfig +from fastdeploy.model_executor.layers.attention.attention import Attention +from fastdeploy.model_executor.layers.attention.base_attention_backend import ( + AttentionBackend, + AttentionMetadata, +) + +if TYPE_CHECKING: + from fastdeploy.model_executor.forward_meta import HPUForwardMeta + +from fastdeploy.model_executor.layers.linear import QKVParallelLinear, RowParallelLinear + + +class AttentionBackend_HPU(AttentionBackend): + """The base class of attention backends""" + + @abstractmethod + def init_attention_metadata(self, forward_meta: HPUForwardMeta): + """Initialize the forward metadata.""" + raise NotImplementedError() + + def forward( + self, + src: paddle.Tensor, + qkv_proj: QKVParallelLinear, + o_proj: RowParallelLinear, + layer: paddle.nn.Layer, + forward_meta: HPUForwardMeta, + ): + """ + Run a forward. + args: + src: the hidden states tensor + residual_input: the residual tensor + layer: The layer that will be used for the forward. + forward_meta: The forward metadata. 
+ """ + if forward_meta.forward_mode.is_mixed(): + return self.forward_mixed( + src, + qkv_proj, + o_proj, + layer, + forward_meta, + ) + elif forward_meta.forward_mode.is_decode(): + return self.forward_decode( + src, + qkv_proj, + o_proj, + layer, + forward_meta, + ) + else: + return self.forward_extend( + src, + qkv_proj, + o_proj, + layer, + forward_meta, + ) + + def forward_mixed( + self, + src: paddle.Tensor, + qkv_proj: QKVParallelLinear, + o_proj: RowParallelLinear, + layer: paddle.nn.Layer, + forward_meta: HPUForwardMeta, + ): + """Run a forward for mix.""" + raise NotImplementedError() + + def forward_decode( + self, + src: paddle.Tensor, + qkv_proj: QKVParallelLinear, + o_proj: RowParallelLinear, + layer: paddle.nn.Layer, + forward_meta: HPUForwardMeta, + ): + """Run a forward for decode.""" + raise NotImplementedError() + + def forward_extend( + self, + src: paddle.Tensor, + qkv_proj: QKVParallelLinear, + o_proj: RowParallelLinear, + layer: paddle.nn.Layer, + forward_meta: HPUForwardMeta, + ): + """Run a forward for extend.""" + raise NotImplementedError() + + +@dataclass +class HPUAttentionMetadata(AttentionMetadata): + """ + HPUAttentionMetadata + """ + + max_len_kv: paddle.Tensor = None + set_max_lengths: int = -1 + encoder_batch_ids: paddle.Tensor = None + encoder_tile_ids_per_batch: paddle.Tensor = None + encoder_num_blocks: paddle.Tensor = None + kv_batch_ids: paddle.Tensor = None + kv_tile_ids_per_batch: paddle.Tensor = None + kv_num_blocks: paddle.Tensor = None + decoder_batch_ids: paddle.Tensor = None + decoder_tile_ids_per_batch: paddle.Tensor = None + decoder_num_blocks: paddle.Tensor = None + + _dtype: _DTypeLiteral = paddle.bfloat16 + encoder_max_partition_size: int = 32768 + max_partition_size: int = 32768 + block_tables: Optional[paddle.Tensor] = None + rotary_embs: Optional[paddle.Tensor] = None + attn_mask: Optional[paddle.Tensor] = None + encoder_block_shape_q: Optional[paddle.Tensor] = None + decoder_block_shape_q: Optional[paddle.Tensor] = None + _fuse_kernel_compute_dtype: str = "bf16" + + # pd_disaggregation + kv_signal_metadata: Optional[paddle.Tensor] = None + kv_signal_data_list: List[paddle.Tensor] = field(default_factory=list) + + +class HPUAttentionBackend(AttentionBackend_HPU): + """ + HPUAttentionBackend backend implementation. 
+    """
+
+    def __init__(self, llm_config: FDConfig, kv_num_heads: int, num_heads: int, head_dim: int):
+        """
+        HPUAttentionBackend __init__
+        """
+        super().__init__()
+        self.attention_metadata: HPUAttentionMetadata = None
+        # TODO(gongshaotian): Use llm_config parameters in the correct location
+        self.block_size = llm_config.parallel_config.block_size
+        self.max_seq_len = llm_config.parallel_config.max_model_len
+        self.rope_theta = 10000.0 if llm_config.model_config.rope_theta is None else llm_config.model_config.rope_theta
+        self.rope_3d = getattr(llm_config.model_config, "rope_3d", False)
+        self.causal = getattr(llm_config.model_config, "causal", True)
+        self.speculative_method: str = llm_config.speculative_config.method
+        self.use_speculate: bool = self.speculative_method is not None
+        self.speculate_max_draft_token_num: int = llm_config.speculative_config.num_speculative_tokens
+        self.keep_pd_step_flag: bool = llm_config.speculative_config.model_type == "mtp"
+        self.rank: int = llm_config.parallel_config.tensor_parallel_rank
+        self.nranks = llm_config.parallel_config.tensor_parallel_size
+
+        self.kv_num_heads = kv_num_heads
+        self.num_heads = num_heads
+        self.head_dim = head_dim
+        self.num_layers = llm_config.model_config.num_hidden_layers
+
+        # pd_disaggregation
+        self.use_pd_disaggregation = int(os.getenv("FLAGS_use_pd_disaggregation", 0))
+        self.start_layer_index = llm_config.model_config.start_layer_index
+
+    def init_attention_metadata(self, forward_meta):
+        """Initialize attention metadata so that all layers in the forward pass can reuse it."""
+        metadata = HPUAttentionMetadata()
+        metadata.encoder_block_shape_q = 64
+        metadata.decoder_block_shape_q = 16
+        metadata.max_partition_size = 32768
+        metadata.encoder_max_partition_size = 32768
+        metadata._dtype = paddle.get_default_dtype()
+        if metadata._dtype == "bfloat16":
+            metadata._fuse_kernel_compute_dtype = "bf16"
+        elif metadata._dtype == "float16":
+            metadata._fuse_kernel_compute_dtype = "fp16"
+        elif metadata._dtype == "float32":
+            metadata._fuse_kernel_compute_dtype = "fp32"
+        metadata.block_tables = forward_meta.block_tables
+        metadata.rotary_embs = forward_meta.rotary_embs
+        metadata.attn_mask = forward_meta.attn_mask
+
+        # pd_disaggregation
+        metadata.kv_signal_data_list = [None] * self.num_layers
+        self.attention_metadata = metadata
+
+    def get_kv_cache_shape(
+        self,
+        max_num_blocks: int,
+    ):
+        """
+        Calculate the KV cache shape
+        """
+        return (max_num_blocks, self.block_size, self.kv_num_heads, self.head_dim)
+
+    def forward_extend(
+        self, src, qkv_proj: QKVParallelLinear, o_proj: RowParallelLinear, layer: Attention, forward_meta
+    ):
+        """
+        forward_extend
+        """
+        # metadata = self.attention_metadata
+
+        from fastdeploy.model_executor.ops.intel_hpu import (
+            fused_qkv_rope,
+            fused_sdpa_proj_t,
+            index_copy_,
+        )
+
+        query_states, key_value_states = fused_qkv_rope(
+            src,
+            qkv_proj.weight,
+            qkv_proj.bias,
+            forward_meta.rotary_embs,
+            self.head_dim,
+            self.num_heads,
+            forward_meta.total_batch,
+            transpose=False,
+            use_neox_style=layer.use_neox_rotary_style,
+        )
+
+        kv, B, BP_BS, M, H = key_value_states.shape
+        key_value_states_reshape = key_value_states.reshape([kv, -1, forward_meta.block_size, M, H])
+        key_states = key_value_states_reshape[0]
+        value_states = key_value_states_reshape[1]
+        k_cache = forward_meta.caches[2 * layer.layer_id]
+        v_cache = forward_meta.caches[2 * layer.layer_id + 1]
+        index_copy_(k_cache, forward_meta.block_indices, key_states, 0)
+        index_copy_(v_cache, forward_meta.block_indices,
value_states, 0) + + out_linear_out = fused_sdpa_proj_t( + query_states, + key_value_states, + forward_meta.attn_mask, + None, + o_proj.weight, + scaling_factor=self.head_dim**-0.5, + causal=True, + softmax_mode=0, + ) + + if self.nranks > 1: + from fastdeploy.distributed.communication import ( + tensor_model_parallel_all_reduce_custom, + ) + + tensor_model_parallel_all_reduce_custom(out_linear_out) + + return out_linear_out + + def forward_decode( + self, src, qkv_proj: QKVParallelLinear, o_proj: RowParallelLinear, layer: Attention, forward_meta + ): + """ + forward_decode + """ + # metadata = self.attention_metadata + from fastdeploy.model_executor.ops.intel_hpu import fused_block_attention + + res = fused_block_attention( + src, + forward_meta.rotary_embs, + forward_meta.caches[2 * layer.layer_id], + forward_meta.caches[2 * layer.layer_id + 1], + forward_meta.block_groups, + forward_meta.block_list, + forward_meta.block_mapping, + forward_meta.attention_mask, + forward_meta.block_indices, + forward_meta.block_offsets, + qkv_proj.weight, + qkv_proj.bias, + o_proj.weight, + self.head_dim, + self.num_heads, + scaling_factor=self.head_dim**-0.5, + transpose=False, + use_neox_style=layer.use_neox_rotary_style, + ) + + # all_reduce + if self.nranks > 1: + from fastdeploy.distributed.communication import ( + tensor_model_parallel_all_reduce_custom, + ) + + tensor_model_parallel_all_reduce_custom(res) + return res diff --git a/fastdeploy/model_executor/layers/backends/intel_hpu/moe/__init__.py b/fastdeploy/model_executor/layers/backends/intel_hpu/moe/__init__.py new file mode 100644 index 000000000..3e50faf7a --- /dev/null +++ b/fastdeploy/model_executor/layers/backends/intel_hpu/moe/__init__.py @@ -0,0 +1,16 @@ +# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" " +intel_hpu moe +""" diff --git a/fastdeploy/model_executor/layers/backends/intel_hpu/moe/fused_moe_hpu_backend.py b/fastdeploy/model_executor/layers/backends/intel_hpu/moe/fused_moe_hpu_backend.py new file mode 100644 index 000000000..fc350c055 --- /dev/null +++ b/fastdeploy/model_executor/layers/backends/intel_hpu/moe/fused_moe_hpu_backend.py @@ -0,0 +1,249 @@ +""" +# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+""" + +import paddle +from paddle import nn + +from fastdeploy.distributed.communication import tensor_model_parallel_all_reduce_custom +from fastdeploy.model_executor.layers.moe.fused_moe_backend_base import MoEMethodBase + + +class HpuMoEMethod(MoEMethodBase): + """ + Use Cutlass Group Gemm to compute Fused MoE. + This method is the oldest way to compute MoE in Paddle. + """ + + def create_weights(self, layer: nn.Layer, **extra_weight_attrs): + # TODO: split create_parameter from process_loaded_weights + return NotImplemented + + def process_loaded_weights(self, layer: nn.Layer, state_dict): + """ + Paddle HPU load weight process. + """ + # bf16 + up_gate_proj_weights, down_proj_weights, _, _ = layer.extract_moe_ffn_weights(state_dict) + + for idx, weights_tensor in enumerate([up_gate_proj_weights, down_proj_weights]): + weights_list = [] + for i in range(layer.num_local_experts): + weight_tensor = weights_tensor[i] + weight = layer.create_parameter( + shape=weight_tensor.shape, + dtype=weight_tensor.dtype, + default_initializer=paddle.nn.initializer.Constant(0), + ) + weight.set_value(weight_tensor) + weights_list.append(weight) + weights_name = self.added_weight_attrs[idx] + setattr(layer, weights_name, weights_list) + + def apply_ep_prefill( + self, + layer: nn.Layer, + x: paddle.Tensor, + gate_out: paddle.Tensor, + ) -> paddle.Tensor: + """ + Apply the EP prefill method. + """ + raise NotImplementedError + + def apply_ep_decode( + self, + layer: nn.Layer, + x: paddle.Tensor, + gate_out: paddle.Tensor, + ) -> paddle.Tensor: + """ + Apply the EP decoder method. + """ + raise NotImplementedError + + def apply_tp( + self, + layer: nn.Layer, + x: paddle.Tensor, + gate: nn.Layer, + ) -> paddle.Tensor: + """ + Paddle hpu Fused MoE. + """ + if layer.topk_method == "noaux_tc": + raise NotImplementedError + + # norm_topk_prob = False if layer.topk_method == "noaux_tc" else True + """ + weights = paddle.nn.functional.softmax(gate_out, axis=-1) + if layer.moe_use_gate_correction_bias: + scores = weights + layer.gate_correction_bias + _, selected_experts = paddle.topk(scores, layer.top_k, axis=-1) + routing_weights = paddle.index_sample(weights, selected_experts) + else: + routing_weights, selected_experts = paddle.topk(weights, layer.top_k, axis=-1) + routing_weights /= paddle.sum(routing_weights, axis=-1, keepdim=True) + + common_inputs = (x, selected_experts, routing_weights.cast("bfloat16")) + + common_params = ( + False, #permuted_weights + "silu", #activation, + 0, + layer.num_experts - 1, + ) + + weights = ( + layer.moe_ffn1_weight, + layer.moe_ffn2_weight, + ) + + fused_moe_out, _ = mixture_of_experts( + *common_inputs, *weights, *common_params, False + ) + + # if norm_topk_prob: + # routing_weights_norm = paddle.sum(routing_weights, axis=-1, keepdim=True).cast("bfloat16") + # fused_moe_out = fused_moe_out / routing_weights_norm + """ + chunk_size = 64 + from fastdeploy.model_executor.ops.intel_hpu import fused_gate_moe + + # TODO: fuse matmul to gate_moe + gate_out = paddle.matmul(x.cast("float32"), gate.weight) + fused_moe_out = fused_gate_moe( + x, + gate_out, + layer.gate_correction_bias, + layer.up_gate_proj_weight, + layer.down_proj_weight, + layer.top_k, + layer.moe_use_gate_correction_bias, + norm_topk_prob=True, + permuted_weights=False, + activation="silu", + experts_min=layer.expert_id_offset, + experts_max=layer.expert_id_offset + layer.num_local_experts - 1, + chunk_size=chunk_size, + ) + if layer.reduce_results and layer.tp_size > 1: + 
tensor_model_parallel_all_reduce_custom(fused_moe_out)
+
+        return fused_moe_out
+
+
+class HpuTensorWiseFP8MoEMethod(HpuMoEMethod):
+    """
+    Compute fused MoE on Intel HPU with tensor-wise FP8 quantized weights.
+    Weights are quantized at load time via the fused_quant custom op.
+    """
+
+    def create_weights(self, layer: nn.Layer, **extra_weight_attrs):
+        # TODO: split create_parameter from process_loaded_weights
+        return NotImplemented
+
+    def process_loaded_weights(self, layer: nn.Layer, state_dict):
+        """
+        Paddle HPU load weight process.
+        """
+        # bf16
+        up_gate_proj_weights, down_proj_weights, _, _ = layer.extract_moe_ffn_weights(state_dict)
+
+        from fastdeploy.model_executor.ops.intel_hpu import fused_quant
+
+        self.quant_fn = fused_quant
+        self.moe_quant_type = "tensor_wise_fp8"
+
+        for idx, weights_tensor in enumerate([up_gate_proj_weights, down_proj_weights]):
+            weights_name = self.added_weight_attrs[idx]
+            scales_name = self.added_scale_attrs[idx]
+
+            weights_list = []
+            scales_list = []
+
+            for i in range(layer.num_local_experts):
+                # quantize loaded weights
+                quant_weight, scale = self.quant_fn(weights_tensor[i])
+                weights_list.append(quant_weight)
+                scales_list.append(scale)
+
+            setattr(layer, weights_name, weights_list)
+            setattr(layer, scales_name, scales_list)
+
+    def apply_ep_prefill(
+        self,
+        layer: nn.Layer,
+        x: paddle.Tensor,
+        gate_out: paddle.Tensor,
+    ) -> paddle.Tensor:
+        """
+        Apply the EP prefill method.
+        """
+        raise NotImplementedError
+
+    def apply_ep_decode(
+        self,
+        layer: nn.Layer,
+        x: paddle.Tensor,
+        gate_out: paddle.Tensor,
+    ) -> paddle.Tensor:
+        """
+        Apply the EP decoder method.
+        """
+        raise NotImplementedError
+
+    def apply_tp(
+        self,
+        layer: nn.Layer,
+        x: paddle.Tensor,
+        gate: nn.Layer,
+    ) -> paddle.Tensor:
+        """
+        Paddle HPU fused MoE.
+ """ + if layer.topk_method == "noaux_tc": + raise NotImplementedError + + # norm_topk_prob = False if layer.topk_method == "noaux_tc" else True + + chunk_size = 64 + from fastdeploy.model_executor.ops.intel_hpu import fused_gate_moe_fp8 + + # TODO: fuse matmul to gate_moe + gate_out = paddle.matmul(x.cast("float32"), gate.weight) + fused_moe_out = fused_gate_moe_fp8( + x, + gate_out, + layer.gate_correction_bias, + layer.up_gate_proj_weight, + layer.down_proj_weight, + None, # intermediate_hidden_states_scales + layer.up_gate_proj_weight_scale, + layer.down_proj_weight_scale, + layer.top_k, + layer.moe_use_gate_correction_bias, + norm_topk_prob=True, + permuted_weights=False, + activation="silu", + experts_min=layer.expert_id_offset, + experts_max=layer.expert_id_offset + layer.num_local_experts - 1, + chunk_size=chunk_size, + ) + + if layer.reduce_results and layer.tp_size > 1: + tensor_model_parallel_all_reduce_custom(fused_moe_out) + + return fused_moe_out diff --git a/fastdeploy/model_executor/layers/linear.py b/fastdeploy/model_executor/layers/linear.py index d17ab1be3..ff9c16a3e 100644 --- a/fastdeploy/model_executor/layers/linear.py +++ b/fastdeploy/model_executor/layers/linear.py @@ -116,6 +116,7 @@ class LinearBase(nn.Layer): or current_platform.is_gcu() or current_platform.is_dcu() or current_platform.is_maca() + or current_platform.is_intel_hpu() ): self.forward = self.forward_cuda else: diff --git a/fastdeploy/model_executor/layers/moe/moe.py b/fastdeploy/model_executor/layers/moe/moe.py index ddd7a4aea..58c87cf33 100644 --- a/fastdeploy/model_executor/layers/moe/moe.py +++ b/fastdeploy/model_executor/layers/moe/moe.py @@ -56,6 +56,11 @@ def get_moe_method(): ) return MetaxTritonWeightOnlyMoEMethod(None) + elif current_platform.is_intel_hpu(): + from fastdeploy.model_executor.layers.backends import HpuMoEMethod + + return HpuMoEMethod(None) + # return HpuTensorWiseFP8MoEMethod(None) raise NotImplementedError @@ -139,6 +144,7 @@ class FusedMoE(nn.Layer): self.hidden_size = fd_config.model_config.hidden_size self.num_experts = num_experts + self.num_local_experts = self.num_experts // self.ep_size self.moe_intermediate_size = moe_intermediate_size // self.tp_size diff --git a/fastdeploy/model_executor/layers/rotary_embedding.py b/fastdeploy/model_executor/layers/rotary_embedding.py index a51b53b0d..eb1d65695 100644 --- a/fastdeploy/model_executor/layers/rotary_embedding.py +++ b/fastdeploy/model_executor/layers/rotary_embedding.py @@ -69,6 +69,12 @@ class ErnieRotaryEmbedding: .transpose([0, 1, 2, 4, 3]) .reshape([2, bsz, max_seq_len, 1, self.rotary_dim]) ) + if paddle.is_compiled_with_custom_device("intel_hpu"): + return ( + paddle.concat([rot_emb, rot_emb], axis=3) + .transpose([0, 1, 2, 4, 3]) + .reshape([2, bsz, max_seq_len, 1, self.rotary_dim]) + ) else: return rot_emb diff --git a/fastdeploy/model_executor/layers/sample/meta_data.py b/fastdeploy/model_executor/layers/sample/meta_data.py index 03cdf24c2..41dc6e117 100644 --- a/fastdeploy/model_executor/layers/sample/meta_data.py +++ b/fastdeploy/model_executor/layers/sample/meta_data.py @@ -54,3 +54,6 @@ class SamplingMetadata: temp_scaled_logprobs: Optional[paddle.Tensor] = None top_p_normalized_logprobs: Optional[paddle.Tensor] = None share_inputs: Optional[Dict[str, paddle.Tensor]] = None + # Add for HPU post-processing + seq_lens_encoder: Optional[paddle.Tensor] = None + seq_lens_decoder: Optional[paddle.Tensor] = None diff --git a/fastdeploy/model_executor/layers/sample/ops/apply_penalty_multi_scores.py 
b/fastdeploy/model_executor/layers/sample/ops/apply_penalty_multi_scores.py index e66db93ba..04a8ab102 100644 --- a/fastdeploy/model_executor/layers/sample/ops/apply_penalty_multi_scores.py +++ b/fastdeploy/model_executor/layers/sample/ops/apply_penalty_multi_scores.py @@ -136,6 +136,23 @@ def apply_penalty_multi_scores( min_dec_lens, eos_token_ids, ) + elif current_platform.is_intel_hpu(): + from fastdeploy.model_executor.ops.intel_hpu import ( + get_token_penalty_multi_scores, + ) + + logits = get_token_penalty_multi_scores( + pre_token_ids, + logits, + repetition_penalties, + frequency_penalties, + presence_penalties, + temperature, + bad_words_token_ids, + step_idx, + min_dec_lens, + eos_token_ids, + ) else: raise NotImplementedError diff --git a/fastdeploy/model_executor/layers/sample/sampler.py b/fastdeploy/model_executor/layers/sample/sampler.py index 60d2e663b..334dcc80f 100644 --- a/fastdeploy/model_executor/layers/sample/sampler.py +++ b/fastdeploy/model_executor/layers/sample/sampler.py @@ -209,6 +209,8 @@ class Sampler(nn.Layer): or current_platform.is_maca() ): self.forward = self.forward_cuda + elif current_platform.is_intel_hpu(): + self.forward = self.forward_intel_hpu else: raise NotImplementedError @@ -377,6 +379,49 @@ class Sampler(nn.Layer): return sampler_output + def forward_intel_hpu( + self, + logits: paddle.Tensor, + sampling_metadata: SamplingMetadata, + batch_ids: paddle.Tensor, + max_batch: int, + rank: int, + local_rank: int, + ) -> paddle.Tensor: + if logits.dtype != paddle.float32: + logits = paddle.cast(logits, paddle.float32) + + from fastdeploy.model_executor.ops.intel_hpu import fused_sampler + + _, next_tokens = fused_sampler( + sampling_metadata.pre_token_ids, + sampling_metadata.prompt_ids, + sampling_metadata.seq_lens_encoder, + sampling_metadata.seq_lens_decoder, + sampling_metadata.step_idx, + sampling_metadata.stop_flags, + logits, + sampling_metadata.repetition_penalties, + sampling_metadata.frequency_penalties, + sampling_metadata.presence_penalties, + sampling_metadata.temperature, + sampling_metadata.bad_words_token_ids, + sampling_metadata.step_idx, + sampling_metadata.min_dec_lens, + sampling_metadata.eos_token_ids, + sampling_metadata.top_p, + rank, + local_rank, + ) + + if next_tokens.shape[0] != max_batch: + dim = next_tokens.shape[-1] + tmp_tokens = paddle.full((max_batch, dim), -1, dtype=next_tokens.dtype) + tmp_tokens = paddle.scatter(tmp_tokens, batch_ids, next_tokens[: batch_ids.shape[0], :]) + return tmp_tokens + + return next_tokens + class SpeculativeSampler(nn.Layer): """ diff --git a/fastdeploy/model_executor/ops/__init__.py b/fastdeploy/model_executor/ops/__init__.py index 5e30570c9..0519764f1 100644 --- a/fastdeploy/model_executor/ops/__init__.py +++ b/fastdeploy/model_executor/ops/__init__.py @@ -12,6 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. """fastdeploy module""" -from . import cpu, gcu, gpu, iluvatar, npu, xpu +from . import cpu, gcu, gpu, iluvatar, intel_hpu, npu, xpu -__all__ = ["gpu", "cpu", "xpu", "npu", "iluvatar", "gcu"] +__all__ = ["gpu", "cpu", "xpu", "npu", "iluvatar", "gcu", "intel_hpu"] diff --git a/fastdeploy/model_executor/ops/intel_hpu/__init__.py b/fastdeploy/model_executor/ops/intel_hpu/__init__.py new file mode 100644 index 000000000..0014d52ab --- /dev/null +++ b/fastdeploy/model_executor/ops/intel_hpu/__init__.py @@ -0,0 +1,21 @@ +# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""fastdeploy intel_hpu ops.""" + +from fastdeploy.import_ops import import_custom_ops + +# PACKAGE = "fastdeploy.model_executor.ops.intel_hpu" +PACKAGE = "paddlenlp_ops" + +import_custom_ops(PACKAGE, "paddlenlp_ops", globals()) diff --git a/fastdeploy/model_executor/pre_and_post_process.py b/fastdeploy/model_executor/pre_and_post_process.py index 01cc699cb..cde323336 100644 --- a/fastdeploy/model_executor/pre_and_post_process.py +++ b/fastdeploy/model_executor/pre_and_post_process.py @@ -56,6 +56,8 @@ elif current_platform.is_maca(): update_inputs, update_inputs_v1, ) +elif current_platform.is_intel_hpu(): + pass else: from fastdeploy.model_executor.ops.gpu import ( get_padding_offset, diff --git a/fastdeploy/output/token_processor.py b/fastdeploy/output/token_processor.py index b91086cd3..10c788bc6 100644 --- a/fastdeploy/output/token_processor.py +++ b/fastdeploy/output/token_processor.py @@ -284,6 +284,8 @@ class TokenProcessor: from fastdeploy.model_executor.ops.iluvatar import get_output elif current_platform.is_gcu(): from fastdeploy.model_executor.ops.gcu import get_output + elif current_platform.is_intel_hpu(): + from fastdeploy.model_executor.ops.intel_hpu import get_output else: from fastdeploy.model_executor.ops.gpu import ( get_output, diff --git a/fastdeploy/platforms/__init__.py b/fastdeploy/platforms/__init__.py index adf5a3ad7..8ba1c7b13 100644 --- a/fastdeploy/platforms/__init__.py +++ b/fastdeploy/platforms/__init__.py @@ -23,6 +23,7 @@ from .cuda import CUDAPlatform from .dcu import DCUPlatform from .gcu import GCUPlatform from .iluvatar import IluvatarPlatform +from .intel_hpu import INTEL_HPUPlatform from .maca import MACAPlatform from .npu import NPUPlatform from .xpu import XPUPlatform @@ -43,6 +44,8 @@ def __getattr__(name: str): _current_platform = XPUPlatform() elif paddle.is_compiled_with_custom_device("npu"): _current_platform = NPUPlatform() + elif paddle.is_compiled_with_custom_device("intel_hpu"): + _current_platform = INTEL_HPUPlatform() elif paddle.is_compiled_with_custom_device("iluvatar_gpu"): _current_platform = IluvatarPlatform() elif paddle.is_compiled_with_custom_device("gcu"): diff --git a/fastdeploy/platforms/base.py b/fastdeploy/platforms/base.py index 478bb7b62..16251c1c1 100644 --- a/fastdeploy/platforms/base.py +++ b/fastdeploy/platforms/base.py @@ -27,6 +27,7 @@ class _Backend(enum.Enum): FLASH_ATTN = enum.auto() BLOCK_ATTN = enum.auto() PLAS_ATTN = enum.auto() + HPU_ATTN = enum.auto() class Platform: @@ -54,6 +55,12 @@ class Platform: """ return paddle.is_compiled_with_xpu() + def is_intel_hpu(self) -> bool: + """ + whether platform is intel_hpu + """ + return paddle.is_compiled_with_custom_device("intel_hpu") + def is_cpu(self) -> bool: """ whether platform is cpu diff --git a/fastdeploy/platforms/intel_hpu.py b/fastdeploy/platforms/intel_hpu.py new file mode 100644 index 000000000..f63f18928 --- /dev/null +++ b/fastdeploy/platforms/intel_hpu.py @@ -0,0 +1,52 @@ +# Copyright (c) 2025 PaddlePaddle 
Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import paddle + +from fastdeploy.utils import console_logger as logger + +from .base import Platform, _Backend + + +class INTEL_HPUPlatform(Platform): + device_name = "intel_hpu" + + @classmethod + def available(self): + """ + Check whether Intel HPU is available. + """ + try: + assert paddle.base.core.get_custom_device_count("intel_hpu") > 0 + return True + except Exception as e: + logger.warning( + "You are using Intel HPU platform, but there is no Intel HPU " + "detected on your machine. Maybe Intel HPU devices is not set properly." + f"\n Original Error is {e}" + ) + return False + + @classmethod + def get_attention_backend_cls(cls, selected_backend): + """ + get_attention_backend_cls + """ + if selected_backend == _Backend.NATIVE_ATTN: + logger.info("Using NATIVE ATTN backend.") + return "fastdeploy.model_executor.layers.attention.PaddleNativeAttnBackend" + elif selected_backend == _Backend.HPU_ATTN: + logger.info("Using HPU ATTN backend.") + return "fastdeploy.model_executor.layers.backends.intel_hpu.attention.HPUAttentionBackend" + else: + logger.warning("Other backends are not supported for now.") diff --git a/fastdeploy/worker/hpu_model_runner.py b/fastdeploy/worker/hpu_model_runner.py new file mode 100644 index 000000000..c21006deb --- /dev/null +++ b/fastdeploy/worker/hpu_model_runner.py @@ -0,0 +1,1463 @@ +""" +# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License" +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+""" + +import os +import time +from typing import Dict, List, Optional + +import numpy as np +import paddle +import paddle.nn as nn +from paddleformers.utils.log import logger + +from fastdeploy.config import FDConfig +from fastdeploy.engine.request import Request + +# from fastdeploy.spec_decode import MTPProposer, NgramProposer +from fastdeploy.model_executor.forward_meta import HPUForwardMeta +from fastdeploy.model_executor.guided_decoding import get_guided_backend +from fastdeploy.model_executor.guided_decoding.base_guided_decoding import ( + LogitsProcessorBase, +) +from fastdeploy.model_executor.layers.attention import get_attention_backend +from fastdeploy.model_executor.layers.attention.base_attention_backend import ( + AttentionBackend, +) +from fastdeploy.model_executor.layers.rotary_embedding import get_rope +from fastdeploy.model_executor.layers.sample.meta_data import SamplingMetadata +from fastdeploy.model_executor.layers.sample.sampler import Sampler, SpeculativeSampler +from fastdeploy.model_executor.model_loader import get_model_loader +from fastdeploy.model_executor.ops.intel_hpu import ( + recover_block, + save_output, + step_paddle, + update_inputs_v3, +) +from fastdeploy.utils import get_logger +from fastdeploy.worker.model_runner_base import ModelRunnerBase +from fastdeploy.worker.output import ModelOutputData, ModelRunnerOutput + +hpu_model_runner_profile_logger = get_logger("hpu_model_runner_profile", "hpu_model_runner_profile.log") + + +def post_process_hpu(sampled_token_ids: paddle.Tensor, model_output: ModelOutputData, is_warmuping: bool) -> None: + """Post-processing steps after completing a single token generation.""" + start_time = time.time() + + not_need_stop_hpu = model_output.not_need_stop.to(sampled_token_ids.place) + is_block_step_hpu = model_output.is_block_step.to(sampled_token_ids.place) + + update_inputs_v3( + model_output.stop_flags, + model_output.step_idx, + not_need_stop_hpu, + model_output.seq_lens_this_time, + model_output.seq_lens_encoder, + model_output.seq_lens_decoder, + model_output.max_dec_len, + model_output.input_ids, + model_output.stop_nums, + sampled_token_ids, + is_block_step_hpu, + model_output.eos_token_id, + model_output.next_tokens, + ) + + model_output.not_need_stop[:] = not_need_stop_hpu.cpu() + model_output.is_block_step[:] = is_block_step_hpu.cpu() + + end_time = time.time() + execution_time = (end_time - start_time) * 1000 + hpu_model_runner_profile_logger.info(f"post_process_hpu::update_inputs_v3 execution time(ms): {execution_time}") + + if is_warmuping: + return + start_time = time.time() + save_output( + sampled_token_ids, + model_output.not_need_stop, + model_output.mp_rank, + ) + end_time = time.time() + execution_time = (end_time - start_time) * 1000 + hpu_model_runner_profile_logger.info(f"post_process_hpu::save_output execution time(ms): {execution_time}") + + +def recover_block_hpu( + recover_block_list, # cpu + recover_len, # cpu + stop_flags, # hpu + seq_lens_this_time, # hpu + ori_seq_lens_encoder, # cpu + seq_lens_encoder, # hpu + block_tables, # cpu + free_list, # cpu + free_list_len, # cpu + input_ids, # hpu + pre_ids, # hpu + step_idx, # hpu + encoder_block_lens, # cpu + used_list_len, # cpu + next_tokens, # hpu + first_token_ids, +): # hpu + + for bid in range(recover_len.item()): + recover_id = recover_block_list[bid].item() + ori_seq_len_encoder = ori_seq_lens_encoder[recover_id].item() + step_idx_now = step_idx[recover_id].item() + seq_len = ori_seq_len_encoder + step_idx_now + encoder_block_len = 
encoder_block_lens[recover_id].item() + decoder_used_len = used_list_len[recover_id].item() + + seq_lens_this_time[recover_id] = seq_len + seq_lens_encoder[recover_id] = seq_len + stop_flags[recover_id] = False + + ori_free_list_len = free_list_len[0] + free_list_len[0] -= decoder_used_len + + for i in range(decoder_used_len): + block_tables[recover_id, encoder_block_len + i] = free_list[ori_free_list_len - i - 1] + + recover_block(input_ids, first_token_ids, pre_ids, next_tokens, recover_id, ori_seq_len_encoder, step_idx_now) + + +def step_intel_hpu(share_inputs: Dict[str, paddle.Tensor], block_size: int, max_model_len: int) -> None: + """ + step cuda + """ + step_paddle( + share_inputs["stop_flags"], + share_inputs["seq_lens_this_time"], + share_inputs["seq_lens_encoder"], + share_inputs["seq_lens_decoder"], + share_inputs["block_tables"], + share_inputs["encoder_block_lens"], + share_inputs["is_block_step"], + share_inputs["step_block_list"], + share_inputs["step_lens"], + share_inputs["recover_block_list"], + share_inputs["recover_lens"], + share_inputs["need_block_list"], + share_inputs["need_block_len"], + share_inputs["used_list_len"], + share_inputs["free_list"], + share_inputs["free_list_len"], + share_inputs["first_token_ids"], + block_size, + max_model_len, + ) + if share_inputs["recover_lens"].item() > 0: + recover_block_hpu( + share_inputs["recover_block_list"], + share_inputs["recover_lens"], + share_inputs["stop_flags"], + share_inputs["seq_lens_this_time"], + share_inputs["ori_seq_lens_encoder"], + share_inputs["seq_lens_encoder"], + share_inputs["block_tables"], + share_inputs["free_list"], + share_inputs["free_list_len"], + share_inputs["input_ids"], + share_inputs["pre_ids"], + share_inputs["step_idx"], + share_inputs["encoder_block_lens"], + share_inputs["used_list_len"], + share_inputs["next_tokens"], + share_inputs["first_token_ids"], + ) + share_inputs["recover_lens"] = paddle.full([1], 0, dtype="int32").cpu() + + +# TODO: replace rebuild_padding_v3 in CustomDevice if we adopt this version pp optimization +def rebuild_padding_v3_1( + tmp_out, + batch_ids, + total_batch, + seq_lens_encoder, + is_prompt=None, +): + dim_emb = tmp_out.shape[-1] + output_data = paddle.zeros((total_batch, dim_emb)) + if is_prompt is True: # context + tmp_out = tmp_out.reshape([total_batch, -1, dim_emb]) + for i in range(batch_ids.shape[0]): + seq_len = seq_lens_encoder[batch_ids[i]].item() + output_data[i] = tmp_out[i, seq_len - 1] + elif is_prompt is False: + output_data[0 : batch_ids.shape[0], :] = tmp_out[: batch_ids.shape[0], :] + + return output_data + + +from fastdeploy.model_executor.layers.linear import QKVParallelLinear, RowParallelLinear +from fastdeploy.model_executor.ops.intel_hpu import fused_mlp + + +def fused_attention_forward( + self, + src: paddle.Tensor = None, + qkv_proj: QKVParallelLinear = None, + o_proj: RowParallelLinear = None, + forward_meta: HPUForwardMeta = None, +): + """ + The forward function of attention layer. 
+ args: + src: the hidden states tensor + residual_input: the residual tensor + forward_meta: the forward meta data + """ + return forward_meta.attn_backend.forward( + src, + qkv_proj, + o_proj, + self, + forward_meta, + ) + + +def fused_self_atten_forward( + self, + forward_meta: HPUForwardMeta, + hidden_states: paddle.Tensor, +): + """ """ + atten_out = self.attn( + src=hidden_states, + qkv_proj=self.qkv_proj, + o_proj=self.o_proj, + forward_meta=forward_meta, + ) + + return atten_out + + +def fused_mlp_forward(self, x): + """ """ + out = fused_mlp( + x, + self.up_gate_proj.weight, + None, + self.down_proj.weight, + ) + + # all_reduce + if self.nranks > 1: + from fastdeploy.distributed.communication import ( + tensor_model_parallel_all_reduce_custom, + ) + + tensor_model_parallel_all_reduce_custom(out) + + return out + + +import types + +from fastdeploy.model_executor.layers.attention.attention import Attention +from fastdeploy.model_executor.models.ernie4_5_moe import ( + Ernie4_5_Attention, + Ernie4_5_MLP, +) +from fastdeploy.model_executor.models.qwen2 import Qwen2Attention, Qwen2MLP + + +def convert_model(model): + """ """ + for name, module in model.named_children(): + if len(list(module.named_children())) > 0: + # print(f"********** model {model.__class__.__name__} has submodule: name={name}, module={module.__class__.__name__}") + if isinstance(module, Ernie4_5_Attention): + module.forward = types.MethodType(fused_self_atten_forward, module) + if isinstance(module, Qwen2Attention): + module.forward = types.MethodType(fused_self_atten_forward, module) + if isinstance(module, Ernie4_5_MLP): + module.forward = types.MethodType(fused_mlp_forward, module) + if isinstance(module, Qwen2MLP): + module.forward = types.MethodType(fused_mlp_forward, module) + convert_model(module) + else: + # print(f"*********[ Leaf node] Loading submodule: name={name} -- module: {module.__class__.__name__}") + if isinstance(module, Attention): + module.forward = types.MethodType(fused_attention_forward, module) + + return model + + +class HPUModelRunner(ModelRunnerBase): + """ """ + + def __init__( + self, + fd_config: FDConfig, + device: str, # logic device + device_id: int, # physical device id + rank: int, + local_rank: int, + ): + super().__init__(fd_config=fd_config, device=device) + self.rank = rank + self.local_rank = local_rank + self.device_id = device_id + self.speculative_method = self.fd_config.speculative_config.method + self.speculative_decoding = self.speculative_method is not None + + self.guided_backend = None + if self.fd_config.parallel_config.guided_decoding_backend != "off": + self.guided_backend = get_guided_backend(fd_config=self.fd_config) + + # Sampler + if not self.speculative_decoding: + self.sampler = Sampler() + else: + self.sampler = SpeculativeSampler(fd_config) + + # Lazy initialize kv cache after model loading + # self.kv_caches: list[paddle.Tensor] = [] + + # Cuda Graph + self.use_cudagraph = self.graph_opt_config.use_cudagraph + self.cudagraph_capture_sizes = list(reversed(self.graph_opt_config.cudagraph_capture_sizes)) + self.cudagraph_num_of_warmups = self.graph_opt_config.cudagraph_num_of_warmups + self.input_ids = paddle.zeros(self.scheduler_config.max_num_seqs, dtype="int32") + + # Initialize share inputs + self._init_share_inputs(self.scheduler_config.max_num_seqs) + self.infer_seed_increment = paddle.full( + shape=[self.scheduler_config.max_num_seqs, 1], fill_value=4, dtype="int64" + ).cpu() + self.restore_chunked_prefill_request = dict() + + # Initialize attention 
Backend + # Note(gonshaotian): Currently, all attention layers share one attention backend instance. + # In the future, we will expand it as a list. + self.attn_backends: list[AttentionBackend] = [] + # self.attn_metadatas: list[AttentionMetadata] = [] + self.initialize_attn_backend() + + # Forward meta store the global meta information of the forward + self.forward_meta: HPUForwardMeta = None + self.is_warmuping = False + self.is_hpu_perf_breakdown_sync_mode = int(os.environ.get("HPU_PERF_BREAKDOWN_SYNC_MODE", 1)) == 1 + # Postprocess Env params + os.environ["INFERENCE_MSG_QUEUE_ID"] = str( + self.local_rank + int(self.parallel_config.engine_worker_queue_port) + ) + + if int(os.environ.get("HABANA_PROFILE", 0)) == 1: + step_start = int(os.environ.get("PROFILE_START", 0)) + step_end = int(os.environ.get("PROFILE_END", 4)) + import paddle.profiler as profiler + + self.prof = profiler.Profiler( + targets=[profiler.ProfilerTarget.CPU, profiler.ProfilerTarget.CUSTOM_DEVICE], + scheduler=(step_start, step_end), + on_trace_ready=profiler.export_chrome_tracing("./profile"), + ) + self.prof.start() + + def exist_prefill(self): + """ + check whether prefill stage finished + """ + if int(paddle.max(self.share_inputs["seq_lens_encoder"])) != 0: + return 1 + else: + return 0 + + def init_speculative_proposer(self): + """ + Init speculative proposer + """ + # if self.speculative_method == "ngram": + # self.proposer = NgramProposer(self.fd_config) + # elif self.speculative_method == "mtp": + # self.proposer = MTPProposer(self.fd_config, self.get_model(), + # self.local_rank, self.device_id, + # self.share_inputs) + # else: + # self.proposer = None + pass + + def _init_logits_processor(self, request): + """ + init logits processor for guided decoding + """ + assert self.guided_backend is not None, ( + "guided_backend is None, use " "--guided-decoding-backend to specify the backend at server startup." 
+ ) + + if request.guided_json is not None: + schemata_key = ("json", request.guided_json) + elif request.guided_regex is not None: + schemata_key = ("regex", request.guided_regex) + elif request.guided_grammar is not None: + schemata_key = ("grammar", request.guided_grammar) + elif request.structural_tag is not None: + schemata_key = ("structural_tag", request.structural_tag) + + return self.guided_backend.get_logits_processor(schemata_key=schemata_key), schemata_key + + def insert_prefill_inputs(self, req_dicts: List[Request], num_running_requests: int = None): + """ + Process inputs for prefill tasks and insert it to share_inputs buffer + req_dict: A list of Request dict + num_running_requests: batch_size + """ + # NOTE(luotingdan): Lazy initialize kv cache + if "caches" not in self.share_inputs: + self.initialize_kv_cache() + + # NOTE(luotingdan): Set environment variable of prefill node + if req_dicts[-1].disaggregate_info is not None and req_dicts[-1].disaggregate_info["role"] == "prefill": + os.environ["PREFILL_NODE_ONE_STEP_STOP"] = "1" + + req_len = len(req_dicts) + for i in range(req_len): + request = req_dicts[i] + idx = request.idx + length = len(request.prompt_token_ids) + + prefill_tokens = [] + if ( + request.guided_json is not None + or request.guided_regex is not None + or request.structural_tag is not None + or request.guided_grammar is not None + ): + logits_info, schemata_key = self._init_logits_processor(request) + request.logits_processor, request.logits_cached = logits_info + request.schemata_key = schemata_key + + # Is Decode Node + if req_dicts[i].disaggregate_info is not None and req_dicts[i].disaggregate_info["role"] == "decode": + prefill_tokens.append(request.prompt_token_ids[0]) + self.share_inputs["pre_ids"][idx : idx + 1] = request.prompt_token_ids[-1] + self.share_inputs["input_ids"][idx : idx + 1, 0] = request.prompt_token_ids[0] + self.share_inputs["seq_lens_encoder"][idx : idx + 1] = 0 + self.share_inputs["seq_lens_decoder"][idx : idx + 1] = length + self.share_inputs["seq_lens_this_time"][idx : idx + 1] = 1 + self.share_inputs["step_seq_lens_encoder"][idx : idx + 1] = 0 + self.share_inputs["step_seq_lens_decoder"][idx : idx + 1] = length + self.share_inputs["step_idx"][idx : idx + 1] = 1 + + if self.speculative_decoding: + num_prefill_send_token = self.speculative_config.num_speculative_tokens + 1 + self.share_inputs["draft_tokens"][idx : idx + 1, 0:num_prefill_send_token] = paddle.to_tensor( + request.draft_token_ids[0:num_prefill_send_token], dtype="int64" + ) + self.share_inputs["seq_lens_this_time"][idx : idx + 1] = num_prefill_send_token + else: + self.share_inputs["pre_ids"][idx : idx + 1] = -1 + self.share_inputs["step_idx"][idx : idx + 1] = 0 + self.share_inputs["input_ids"][idx : idx + 1, :length] = np.array(request.prompt_token_ids) + + # Use chunked prefill + if self.cache_config.enable_chunked_prefill: + request.set("chunk_idx", 1) + logger.info(f"prefill_chunk_info: {request.prefill_chunk_info}") + token_chunk_size = request.prefill_chunk_info[0] + self.share_inputs["seq_lens_this_time"][idx : idx + 1] = token_chunk_size + self.share_inputs["input_ids"][idx, :token_chunk_size] = np.array( + request.prompt_token_ids[:token_chunk_size] + ) + self.share_inputs["step_seq_lens_encoder"][idx : idx + 1] = token_chunk_size + self.share_inputs["seq_lens_encoder"][idx : idx + 1] = token_chunk_size + self.share_inputs["seq_lens_decoder"][idx : idx + 1] = request.get("seq_lens_decoder", 0) + self.share_inputs["step_seq_lens_decoder"][idx : idx + 1] = 
request.get("seq_lens_decoder", 0) + else: + self.share_inputs["seq_lens_decoder"][idx : idx + 1] = request.get("seq_lens_decoder", 0) + self.share_inputs["step_seq_lens_decoder"][idx : idx + 1] = request.get("seq_lens_decoder", 0) + self.share_inputs["seq_lens_this_time"][idx : idx + 1] = length + self.share_inputs["step_seq_lens_encoder"][idx : idx + 1] = length + self.share_inputs["seq_lens_encoder"][idx : idx + 1] = length + + if len(request.eos_token_ids) < self.model_config.eos_tokens_lens: + request.eos_token_ids.append(request.eos_token_ids[0]) + self.share_inputs["eos_token_id"][:] = np.array(request.eos_token_ids, dtype="int64").reshape(-1, 1) + + self.share_inputs["top_p"][idx : idx + 1] = request.get("top_p", 0.7) + self.share_inputs["temperature"][idx : idx + 1] = request.get("temperature", 0.95) + self.share_inputs["penalty_score"][idx : idx + 1] = request.get("repetition_penalty", 1.0) + self.share_inputs["frequency_score"][idx : idx + 1] = request.get("frequency_penalty", 0.0) + self.share_inputs["presence_score"][idx : idx + 1] = request.get("presence_penalty", 0.0) + + self.share_inputs["min_dec_len"][idx : idx + 1] = request.get("min_tokens", 1) + self.share_inputs["max_dec_len"][idx : idx + 1] = request.get( + "max_tokens", self.model_config.max_model_len + ) + self.share_inputs["stop_flags"][idx : idx + 1] = False + + self.share_inputs["first_token_ids"][idx : idx + 1] = self.share_inputs["input_ids"][idx : idx + 1, :1] + self.share_inputs["ori_seq_lens_encoder"][idx : idx + 1] = length + + if request.get("seed") is not None: + self.share_inputs["infer_seed"][idx : idx + 1] = request.get("seed") + encoder_block_num = len(request.get("block_tables")) + self.share_inputs["encoder_block_lens"][idx : idx + 1] = encoder_block_num + self.share_inputs["block_tables"][idx : idx + 1, :] = -1 + self.share_inputs["block_tables"][idx : idx + 1, :encoder_block_num] = np.array( + request.block_tables, dtype="int32" + ) + + if request.get("stop_token_ids") is not None and request.get("stop_seqs_len") is not None: + stop_seqs_num = len(request.get("stop_seqs_len")) + for i in range(stop_seqs_num, self.model_config.max_stop_seqs_num): + request.stop_seqs_len.append(0) + self.share_inputs["stop_seqs_len"][:] = np.array(request.stop_seqs_len, dtype="int32") + self.share_inputs["stop_seqs"][:stop_seqs_num, : len(request.get("stop_token_ids")[0])] = np.array( + request.get("stop_token_ids"), dtype="int64" + ) + + self.sampler.apply_logits_processor(idx, request.get("logits_processor"), prefill_tokens) + + self.share_inputs["not_need_stop"][0] = True + + if self.speculative_method in ["mtp"]: + self.proposer.insert_prefill_inputs(req_dicts, num_running_requests) + + def _dummy_prefill_inputs(self, num_tokens: int, batch_size: int, expected_decode_len: int): + """Set dummy prefill inputs to share_inputs""" + # NOTE(gongshaotian): The maximum decoding length is equal to the expected decoded tokens plus the eos token + max_dec_len = expected_decode_len + 1 + full_length = min(num_tokens // batch_size, self.parallel_config.max_model_len - max_dec_len) + input_length = int(full_length * self.cache_config.kv_cache_ratio) + block_num = ( + input_length + self.cache_config.block_size - 1 + ) // self.cache_config.block_size + self.cache_config.enc_dec_block_num + + for i in range(batch_size): + idx = i + self.share_inputs["input_ids"][idx : idx + 1, :input_length] = np.array([5] * input_length) + self.share_inputs["eos_token_id"][:] = np.array([2], dtype="int64").reshape(-1, 1) + 
self.share_inputs["seq_lens_this_time"][idx : idx + 1] = input_length + self.share_inputs["step_seq_lens_encoder"][idx : idx + 1] = input_length + self.share_inputs["seq_lens_encoder"][idx : idx + 1] = input_length + self.share_inputs["seq_lens_decoder"][idx : idx + 1] = 0 + self.share_inputs["step_idx"][idx : idx + 1] = 0 + self.share_inputs["max_dec_len"][idx : idx + 1] = max_dec_len + self.share_inputs["stop_flags"][idx : idx + 1] = False + + self.share_inputs["first_token_ids"][idx : idx + 1] = self.share_inputs["input_ids"][idx : idx + 1, :1] + self.share_inputs["ori_seq_lens_encoder"][idx : idx + 1] = input_length + + self.share_inputs["encoder_block_lens"][idx : idx + 1] = block_num + self.share_inputs["block_tables"][idx : idx + 1, :block_num] = np.arange( + idx * block_num, (idx + 1) * block_num, 1 + ) + + def _init_share_inputs(self, max_num_seqs: int): + """Initialize all share buffers for model inputs. + Note: In the future, we may abandon share buffers. + """ + self.MAX_INFER_SEED = 9223372036854775806 + self.share_inputs = {} + + self.share_inputs["pre_ids"] = paddle.full( + [max_num_seqs, self.parallel_config.max_model_len], -1, dtype="int64" + ) + self.share_inputs["input_ids"] = paddle.full( + [max_num_seqs, self.parallel_config.max_model_len], self.model_config.pad_token_id, dtype="int64" + ) + self.share_inputs["eos_token_id"] = paddle.full([self.model_config.eos_tokens_lens, 1], 0, dtype="int64") + self.share_inputs["top_p"] = paddle.full([max_num_seqs, 1], self.model_config.top_p, dtype="float32") + self.share_inputs["temperature"] = paddle.full( + [max_num_seqs, 1], self.model_config.temperature, dtype="float32" + ) + self.share_inputs["penalty_score"] = paddle.full( + [max_num_seqs, 1], self.model_config.penalty_score, dtype="float32" + ) + self.share_inputs["frequency_score"] = paddle.full( + [max_num_seqs, 1], self.model_config.frequency_score, dtype="float32" + ) + self.share_inputs["presence_score"] = paddle.full( + [max_num_seqs, 1], self.model_config.presence_score, dtype="float32" + ) + + self.share_inputs["min_dec_len"] = paddle.full([max_num_seqs, 1], self.model_config.min_length, dtype="int64") + self.share_inputs["max_dec_len"] = paddle.full( + [max_num_seqs, 1], self.model_config.max_model_len, dtype="int64" + ) + self.share_inputs["min_length"] = paddle.full([max_num_seqs, 1], self.model_config.min_length, dtype="int64") + self.share_inputs["max_length"] = paddle.full( + [max_num_seqs, 1], self.model_config.max_model_len, dtype="int64" + ) + self.share_inputs["seq_lens_this_time"] = paddle.full(max_num_seqs, 0, dtype="int32") + self.share_inputs["seq_lens_encoder"] = paddle.full([max_num_seqs, 1], 0, dtype="int32") + self.share_inputs["seq_lens_decoder"] = paddle.full([max_num_seqs, 1], 0, dtype="int32") + self.share_inputs["step_seq_lens_encoder"] = paddle.full([max_num_seqs, 1], 0, dtype="int32") + self.share_inputs["step_seq_lens_decoder"] = paddle.full([max_num_seqs, 1], 0, dtype="int32") + self.share_inputs["step_idx"] = paddle.full([max_num_seqs, 1], 0, dtype="int64") + self.share_inputs["not_need_stop"] = paddle.full( + [1], False, dtype="bool" + ).cpu() # TODO(gongshaotian): move to pinnd memory + self.share_inputs["stop_flags"] = paddle.full([max_num_seqs, 1], True, dtype="bool") + self.share_inputs["stop_nums"] = paddle.full([1], max_num_seqs, dtype="int64") + + self.share_inputs["bad_tokens"] = paddle.full([1], -1, dtype="int64") + self.share_inputs["next_tokens"] = paddle.full([max_num_seqs, 1], -1, dtype="int64") + 
self.share_inputs["is_block_step"] = paddle.full([max_num_seqs], False, dtype="bool").cpu() + self.share_inputs["encoder_block_lens"] = paddle.full([max_num_seqs], 0, dtype="int32").cpu() + self.share_inputs["step_block_list"] = paddle.full([max_num_seqs], -1, dtype="int32").cpu() + self.share_inputs["step_lens"] = paddle.full([1], 0, dtype="int32").cpu() + self.share_inputs["recover_block_list"] = paddle.full([max_num_seqs], -1, dtype="int32").cpu() + self.share_inputs["recover_lens"] = paddle.full([1], 0, dtype="int32").cpu() + self.share_inputs["need_block_list"] = paddle.full([max_num_seqs], -1, dtype="int32").cpu() + self.share_inputs["need_block_len"] = paddle.full([1], 0, dtype="int32").cpu() + self.share_inputs["used_list_len"] = paddle.full([max_num_seqs], 0, dtype="int32").cpu() + self.share_inputs["infer_seed"] = paddle.full([max_num_seqs, 1], 0, dtype="int64").cpu() + self.share_inputs["first_token_ids"] = paddle.full([max_num_seqs, 1], -1, dtype="int64") + self.share_inputs["ori_seq_lens_encoder"] = paddle.full([max_num_seqs, 1], 0, dtype="int32").cpu() + self.share_inputs["system_lens"] = paddle.full([max_num_seqs, 1], 0, dtype="int32") + self.share_inputs["system_ids"] = paddle.full([max_num_seqs, 1], -1, dtype="int32") + + self.share_inputs["ids_remove_padding"] = paddle.full( + [max_num_seqs * self.parallel_config.max_model_len], 0, dtype="int64" + ) + self.share_inputs["cum_offsets"] = paddle.full([max_num_seqs, 1], 0, dtype="int32") + self.share_inputs["padding_offset"] = paddle.full([max_num_seqs, 1], 0, dtype="int32") + self.share_inputs["cu_seqlens_q"] = paddle.full([max_num_seqs, 1], 0, dtype="int32") + self.share_inputs["cu_seqlens_k"] = paddle.full([max_num_seqs, 1], 0, dtype="int32") + # AttentionBackend buffers + self.share_inputs["decoder_batch_ids"] = paddle.full([max_num_seqs, 1], 0, dtype="int32") + self.share_inputs["decoder_tile_ids_per_batch"] = paddle.full([max_num_seqs, 1], 0, dtype="int32") + + # Initialize rotary position embedding + tmp_position_ids = paddle.arange(self.parallel_config.max_model_len).reshape((1, -1)) + # TODO(gongshaotian): move to models + self.share_inputs["rope_emb"] = get_rope( + rotary_dim=self.model_config.head_dim, + position_ids=tmp_position_ids, + base=self.model_config.rope_theta, + model_config=self.model_config, + ) + + # Set block tables + pre_max_block_num = ( + self.parallel_config.max_model_len + self.cache_config.block_size - 1 + ) // self.cache_config.block_size + self.cache_config.enc_dec_block_num + self.share_inputs["block_tables"] = paddle.full([max_num_seqs, pre_max_block_num], -1, dtype="int32").cpu() + + # Initialize free list + free_list = list( + range( + self.parallel_config.total_block_num - 2, + int(self.parallel_config.total_block_num * self.cache_config.kv_cache_ratio) - 1, + -1, + ) + ) + self.free_list_len = len(free_list) + self.share_inputs["free_list"] = paddle.to_tensor(free_list, dtype="int32").cpu() + self.share_inputs["free_list_len"] = paddle.full([1], self.free_list_len, dtype="int32").cpu() + + # Initialize stop seqs + self.share_inputs["stop_seqs_len"] = paddle.full([self.model_config.max_stop_seqs_num], 0, dtype="int32") + self.share_inputs["stop_seqs"] = paddle.full( + [self.model_config.max_stop_seqs_num, self.model_config.stop_seqs_max_len], -1, dtype="int32" + ) + if self.speculative_decoding: + max_draft_token_num = self.speculative_config.num_speculative_tokens + self.share_inputs["input_ids_cpu"] = paddle.full( + shape=[max_num_seqs, self.parallel_config.max_model_len], 
fill_value=1, dtype="int64" + ).cpu() + self.share_inputs["accept_tokens"] = paddle.full( + shape=[max_num_seqs, max_draft_token_num + 1], fill_value=0, dtype="int64" + ) + self.share_inputs["accept_num"] = paddle.full(shape=[max_num_seqs], fill_value=0, dtype="int32") + self.share_inputs["draft_tokens"] = paddle.full( + shape=[max_num_seqs, max_draft_token_num + 1], fill_value=0, dtype="int64" + ) + + self.share_inputs["actual_draft_token_num"] = paddle.full( + shape=[max_num_seqs], fill_value=max_draft_token_num, dtype="int32" + ) + self.share_inputs["output_cum_offsets"] = paddle.full(shape=[max_num_seqs, 1], fill_value=0, dtype="int32") + self.share_inputs["output_padding_offset"] = paddle.full( + shape=[max_num_seqs * (max_draft_token_num + 1)], fill_value=0, dtype="int32" + ) + + def _prepare_inputs(self) -> None: + """prepare the model inputs""" + from fastdeploy.model_executor.ops.intel_hpu import prepare_block_metadata + + ( + ids_remove_padding, + rotary_embs, + block_groups, + block_list, + block_indices, + block_offsets, + block_mapping, + attention_mask, + batch_ids, + total_batch, + is_prompt, + ) = prepare_block_metadata( + self.share_inputs["input_ids"], + self.share_inputs["rope_emb"], + self.share_inputs["block_tables"], + self.share_inputs["seq_lens_encoder"], + self.share_inputs["seq_lens_decoder"], + self.cache_config.block_size, + self.parallel_config.dtype, + ) + is_prompt = is_prompt.item() == 1 if is_prompt.item() > 0 else None + if is_prompt is True: + attention_mask = None + # cum_offsets = None + self.share_inputs["ids_remove_padding"] = ids_remove_padding + self.share_inputs["rotary_embs"] = rotary_embs + self.share_inputs["block_groups"] = block_groups + self.share_inputs["block_list"] = block_list + self.share_inputs["block_indices"] = block_indices + self.share_inputs["block_offsets"] = block_offsets + self.share_inputs["block_mapping"] = block_mapping + self.share_inputs["block_bias"] = attention_mask + self.share_inputs["block_size"] = self.cache_config.block_size + self.share_inputs["batch_ids"] = batch_ids + self.share_inputs["total_batch"] = total_batch.item() + self.share_inputs["is_prompt"] = is_prompt + self.initialize_forward_meta() + + def _prepare_sampler_inputs(self, sampled_ids) -> None: + if self.forward_meta.total_batch == self.share_inputs["temperature"].shape[0]: + self.sampling_metadata = SamplingMetadata( + temperature=self.share_inputs["temperature"], + top_p=self.share_inputs["top_p"], + step_idx=self.share_inputs["step_idx"], + prompt_ids=self.share_inputs["input_ids"], + pre_token_ids=self.share_inputs["pre_ids"], + stop_flags=self.share_inputs["stop_flags"], + seq_lens_encoder=self.share_inputs["seq_lens_encoder"], + seq_lens_decoder=self.share_inputs["seq_lens_decoder"], + frequency_penalties=self.share_inputs["frequency_score"], + presence_penalties=self.share_inputs["presence_score"], + repetition_penalties=self.share_inputs["penalty_score"], + min_dec_lens=self.share_inputs["min_dec_len"], + bad_words_token_ids=self.share_inputs["bad_tokens"], + eos_token_ids=self.share_inputs["eos_token_id"], + ) + else: + from fastdeploy.model_executor.ops.intel_hpu import fused_index_select + + ( + temperature, + top_p, + step_idx, + prompt_token_ids, + pre_token_ids, + stop_flags, + seq_lens_encoder, + seq_lens_decoder, + frequency_penalties, + presence_penalties, + repetition_penalties, + min_dec_lens, + ) = fused_index_select( + self.share_inputs["temperature"], + self.share_inputs["top_p"], + self.share_inputs["step_idx"], + 
self.share_inputs["input_ids"], + self.share_inputs["pre_ids"], + self.share_inputs["stop_flags"], + self.share_inputs["seq_lens_encoder"], + self.share_inputs["seq_lens_decoder"], + self.share_inputs["frequency_score"], + self.share_inputs["presence_score"], + self.share_inputs["penalty_score"], + self.share_inputs["min_dec_len"], + sampled_ids, + self.forward_meta.total_batch, + ) + + self.sampling_metadata = SamplingMetadata( + temperature=temperature, + top_p=top_p, + step_idx=step_idx, + prompt_ids=prompt_token_ids, + pre_token_ids=pre_token_ids, + stop_flags=stop_flags, + seq_lens_encoder=seq_lens_encoder, + seq_lens_decoder=seq_lens_decoder, + frequency_penalties=frequency_penalties, + presence_penalties=presence_penalties, + repetition_penalties=repetition_penalties, + min_dec_lens=min_dec_lens, + bad_words_token_ids=self.share_inputs["bad_tokens"], + eos_token_ids=self.share_inputs["eos_token_id"], + ) + + def load_model(self) -> None: + """load or download model""" + logger.info(f"Starting to load model {self.model_config.architectures[0]}") + time_before_load = time.perf_counter() + # 1. Load original model + model_loader = get_model_loader(load_config=self.fd_config.load_config) + self.model = model_loader.load_model(fd_config=self.fd_config) + # 1.1 Load RL dynamic model + if self.fd_config.load_config.dynamic_load_weight: + from fastdeploy.rl.dynamic_weight_manager import DynamicWeightManager + + self.dynamic_weight_manager = DynamicWeightManager(self.fd_config, self.model) + + # 2. Load lora model + + # 3. Load drafter model(for speculative decoding) + + # 4. Convert model to HPU format + self.model = convert_model(self.model) + + time_after_load = time.perf_counter() + logger.info(f"Model loading took {time_after_load - time_before_load} seconds") + + # 4. 
Init proposer for speculative method + self.init_speculative_proposer() + + def get_model(self) -> nn.Layer: + """get current model""" + return self.model + + def initialize_forward_meta(self): + """ + Initialize forward meta and attention meta data + """ + # Initialize forward meta + self.forward_meta = HPUForwardMeta.init_forward_meta(self.share_inputs, self.attn_backends[0]) + + # Initialzie attention meta data + for attn_backend in self.attn_backends: + attn_backend.init_attention_metadata(self.forward_meta) + + def clear_cache(self): + """Clear cached data from shared inputs and forward metadata.""" + self.share_inputs.pop("caches", None) + if self.forward_meta is not None: + self.forward_meta.clear_caches() + + def initialize_kv_cache(self) -> None: + """ + Initialize kv cache + """ + cache_kvs = {} + max_block_num = self.num_gpu_blocks + + kv_cache_shape = self.attn_backends[0].get_kv_cache_shape(max_num_blocks=max_block_num) + + for i in range(self.model_config.num_hidden_layers): + cache_type = self.parallel_config.dtype + cache_kvs["key_caches_{}".format(i)] = paddle.full( + shape=kv_cache_shape, + fill_value=0, + dtype=cache_type, + ) + cache_kvs["value_caches_{}".format(i)] = paddle.full( + shape=kv_cache_shape, + fill_value=0, + dtype=cache_type, + ) + self.share_inputs["caches"] = list(cache_kvs.values()) + for value in cache_kvs.values(): + del value + + def initialize_attn_backend(self) -> None: + """ + Initialize attention backends and forward metadata + """ + assert len(self.attn_backends) == 0 + + # TODO(gongshaotian): Get rank from config + num_heads = self.model_config.num_attention_heads // self.parallel_config.tensor_parallel_size + self.model_config.kv_num_heads = ( + int(self.model_config.num_key_value_heads) // self.parallel_config.tensor_parallel_size + ) + head_dim = self.model_config.head_dim + + # Get the attention backend + attn_cls = get_attention_backend() + attn_backend = attn_cls( + self.fd_config, kv_num_heads=self.model_config.kv_num_heads, num_heads=num_heads, head_dim=head_dim + ) + if attn_backend is None: + raise NotImplementedError( + "Attention backend which you specified is not supported, please set FD_ATTENTION_BACKEND correctly." + ) + self.attn_backends.append(attn_backend) + + def _dummy_run( + self, + num_tokens: paddle.Tensor, + batch_size: paddle.Tensor, + expected_decode_len: int = 1, + in_capturing: bool = False, + ) -> paddle.Tensor: + """ + Use dummy inputs to run before formal execution. + Args: + num_tokens: + expected_decode_len: Expected number of tokens generated + """ + self._dummy_prefill_inputs( + num_tokens=num_tokens, batch_size=batch_size, expected_decode_len=expected_decode_len + ) + if self.speculative_method in ["mtp"]: + raise NotImplementedError("speculative sampling is not supported on Intel HPU.") + while True: + + # 1. Compute real num_tokens + self._prepare_inputs() + + # 2. Initialize attention backend and forward meta data + model_output = self.model(self.share_inputs["ids_remove_padding"], self.forward_meta) + + hiddden_states = rebuild_padding_v3_1( + model_output, + self.forward_meta.batch_ids, + self.forward_meta.total_batch, + self.forward_meta.seq_lens_encoder, + self.forward_meta.is_prompt, + ) + # 5. 
Execute spec decode + logits = self.model.compute_logits(hiddden_states) + + self._prepare_sampler_inputs(self.forward_meta.batch_ids) + sampled_token_ids = self.sampler( + logits, + self.sampling_metadata, + self.forward_meta.batch_ids, + self.forward_meta.seq_lens_encoder.shape[0], + self.rank, + self.local_rank, + ) + if self.parallel_config.tensor_parallel_size > 1: + dtype = sampled_token_ids.dtype + sampled_token_ids = sampled_token_ids.to("float32") + paddle.distributed.broadcast(sampled_token_ids, 0) + sampled_token_ids = sampled_token_ids.to(dtype) + + # 6. post process + model_output_data = ModelOutputData( + next_tokens=self.share_inputs["next_tokens"], + stop_flags=self.share_inputs["stop_flags"], + step_idx=self.share_inputs["step_idx"], + max_dec_len=self.share_inputs["max_dec_len"], + pre_ids=self.share_inputs["pre_ids"], + seq_lens_this_time=self.share_inputs["seq_lens_this_time"], + eos_token_id=self.share_inputs["eos_token_id"], + not_need_stop=self.share_inputs["not_need_stop"], + input_ids=self.share_inputs["input_ids"], + stop_nums=self.share_inputs["stop_nums"], + seq_lens_encoder=self.share_inputs["seq_lens_encoder"], + seq_lens_decoder=self.share_inputs["seq_lens_decoder"], + is_block_step=self.share_inputs["is_block_step"], + full_hidden_states=model_output, + msg_queue_id=self.parallel_config.msg_queue_id, + mp_rank=self.local_rank, + use_ep=self.parallel_config.use_ep, + draft_tokens=self.share_inputs["draft_tokens"] if self.speculative_decoding else None, + actual_draft_token_num=( + self.share_inputs["actual_draft_token_num"] if self.speculative_decoding else None + ), + accept_tokens=self.share_inputs["accept_tokens"] if self.speculative_decoding else None, + accept_num=self.share_inputs["accept_num"] if self.speculative_decoding else None, + ) + + post_process_hpu( + sampled_token_ids=sampled_token_ids, model_output=model_output_data, is_warmuping=self.is_warmuping + ) + + # 7. 
Update 'infer_seed' and call step_intel_hpu()
+            self.share_inputs["infer_seed"].add_(self.infer_seed_increment)
+            self.share_inputs["infer_seed"][:] %= self.MAX_INFER_SEED
+            step_intel_hpu(self.share_inputs, self.cache_config.block_size, self.parallel_config.max_model_len)
+
+            if int((self.share_inputs["seq_lens_this_time"] > 0).sum()) == 0:
+                break
+
+    def _update_chunked_prefill(self, tasks):
+        """
+        Update chunked-prefill related parameters
+        """
+        if not self.cache_config.enable_chunked_prefill:
+            return
+
+        for task in tasks:
+            if task.get("prefill_chunk_info", None) is None:
+                continue
+
+            if task.chunk_idx > len(task.prefill_chunk_info):
+                continue
+            self.restore_chunked_prefill_request[task.request_id] = task
+
+        for id, task in list(self.restore_chunked_prefill_request.items()):
+            idx = task.idx
+            logger.debug(f"{task.request_id} chunked prefill {task.chunk_idx}/{len(task.prefill_chunk_info)}")
+            start_idx = sum(task.prefill_chunk_info[: task.chunk_idx])
+            if task.chunk_idx == len(task.prefill_chunk_info):
+                self.share_inputs["seq_lens_this_time"][idx : idx + 1] = 1
+                self.share_inputs["seq_lens_encoder"][idx : idx + 1] = 0
+                self.share_inputs["step_idx"][idx : idx + 1] = 1
+                self.share_inputs["seq_lens_decoder"][idx : idx + 1] = start_idx + task.get("seq_lens_decoder", 0)
+                del self.restore_chunked_prefill_request[task.request_id]
+            else:
+                token_chunk_size = task.prefill_chunk_info[task.chunk_idx]
+
+                self.share_inputs["seq_lens_this_time"][idx : idx + 1] = token_chunk_size
+                self.share_inputs["input_ids"][idx, :token_chunk_size] = np.array(
+                    task.prompt_token_ids[start_idx : start_idx + token_chunk_size]
+                )
+                self.share_inputs["seq_lens_encoder"][idx : idx + 1] = token_chunk_size
+                self.share_inputs["step_idx"][idx : idx + 1] = 0
+                self.share_inputs["seq_lens_decoder"][idx : idx + 1] = start_idx + task.get("seq_lens_decoder", 0)
+            if self.speculative_decoding and self.proposer.is_chunk_prefill_enabled():
+                self.proposer.update_task_chunk_prefill(task)
+            task.chunk_idx += 1
+
+    def _dummy_sampler_run(self) -> paddle.Tensor:
+        """ """
+        pass
+
+    def update_warmup_inputs(self, requests, is_decode=False):
+        for i in range(len(requests)):
+            request = requests[i]
+            idx = request["idx"]
+            length = len(request["input_ids"])
+            self.share_inputs["input_ids"][idx : idx + 1, :length] = np.array(request["input_ids"])
+            if is_decode:
+                self.share_inputs["seq_lens_encoder"][idx : idx + 1] = 0
+                self.share_inputs["seq_lens_decoder"][idx : idx + 1] = length
+                self.share_inputs["seq_lens_this_time"][idx : idx + 1] = 1
+                self.share_inputs["step_seq_lens_encoder"][idx : idx + 1] = 0
+                self.share_inputs["step_seq_lens_decoder"][idx : idx + 1] = length
+                self.share_inputs["step_idx"][idx : idx + 1] = 1
+            else:
+                self.share_inputs["seq_lens_encoder"][idx : idx + 1] = length
+                self.share_inputs["seq_lens_decoder"][idx : idx + 1] = 0
+                self.share_inputs["seq_lens_this_time"][idx : idx + 1] = length
+                self.share_inputs["step_seq_lens_encoder"][idx : idx + 1] = length
+                self.share_inputs["step_seq_lens_decoder"][idx : idx + 1] = 0
+                self.share_inputs["step_idx"][idx : idx + 1] = 0
+
+            if len(request["eos_token_ids"]) < self.model_config.eos_tokens_lens:
+                request["eos_token_ids"].append(request["eos_token_ids"][0])
+            self.share_inputs["eos_token_id"][:] = np.array(request["eos_token_ids"], dtype="int64").reshape(-1, 1)
+
+            self.share_inputs["top_p"][idx : idx + 1] = request.get("top_p", 0.7)
+            self.share_inputs["temperature"][idx : idx + 1] = request.get("temperature", 0.95)
+            self.share_inputs["penalty_score"][idx : idx + 1] = 
request.get("repetition_penalty", 1.0) + self.share_inputs["frequency_score"][idx : idx + 1] = request.get("frequency_penalty", 0.0) + self.share_inputs["presence_score"][idx : idx + 1] = request.get("presence_penalty", 0.0) + + self.share_inputs["min_dec_len"][idx : idx + 1] = request.get("min_tokens", 1) + self.share_inputs["max_dec_len"][idx : idx + 1] = request.get("max_tokens", 1) + self.share_inputs["stop_flags"][idx : idx + 1] = False + + self.share_inputs["first_token_ids"][idx : idx + 1] = self.share_inputs["input_ids"][idx : idx + 1, :1] + self.share_inputs["ori_seq_lens_encoder"][idx : idx + 1] = length + + if request.get("seed") is not None: + self.share_inputs["infer_seed"][idx : idx + 1] = request.get("seed") + encoder_block_num = len(request["block_tables"]) + self.share_inputs["encoder_block_lens"][idx : idx + 1] = encoder_block_num + self.share_inputs["block_tables"][idx : idx + 1, :] = -1 + self.share_inputs["block_tables"][idx : idx + 1, :encoder_block_num] = np.array( + request["block_tables"], dtype="int32" + ) + + self.share_inputs["not_need_stop"][0] = True + + def warm_up_bucket(self) -> None: + max_prefill_batch = 3 # Hard-Code in FastDeploy/fastdeploy/engine/config.py + warmup_max_model_len = min( + int(os.environ.get("HPU_WARMUP_MODEL_LEN", 4096)), self.parallel_config.max_model_len + ) + prefill_batchs = [] + prefill_batch_step = int(os.environ.get("BATCH_STEP_PREFILL", 1)) + current_prefill_batch = prefill_batch_step + while current_prefill_batch <= max_prefill_batch: + prefill_batchs.append(int(current_prefill_batch)) + current_prefill_batch += prefill_batch_step + + max_prefill_length = self.cache_config.block_size + warmup_max_model_len + for prefill_batch in prefill_batchs: + for prefill_length in range( + self.cache_config.block_size, max_prefill_length, self.cache_config.block_size + ): + if prefill_length * prefill_batch > self.scheduler_config.max_num_batched_tokens: + continue + logger.info(f"Warmup prefill_batch: {prefill_batch}, prefill_length: {prefill_length} start") + requests = [ + { + "idx": i, + "input_ids": [5] * (prefill_length - 1), + "block_tables": list(range(prefill_length // self.cache_config.block_size)), + "eos_token_ids": [2], + } + for i in range(prefill_batch) + ] + self.update_warmup_inputs(requests, is_decode=False) + self.execute_model() + logger.info(f"warmup prefill_batch: {prefill_batch}, prefill_length: {prefill_length} done") + + decode_batchs = [] + decode_batch_step = int(os.environ.get("BATCH_STEP_DECODE", 4)) + current_decode_batch = decode_batch_step + while current_decode_batch <= self.scheduler_config.max_num_seqs: + decode_batchs.append(int(current_decode_batch)) + current_decode_batch += decode_batch_step + + decode_block_nums = [] + decode_block_num_step = int(os.environ.get("BLOCK_STEP_DECODE", 16)) + current_decode_block_num = decode_block_num_step + pre_max_block_num = ( + warmup_max_model_len + self.cache_config.block_size - 1 + ) // self.cache_config.block_size + self.cache_config.enc_dec_block_num + while current_decode_block_num <= min( + self.num_gpu_blocks, pre_max_block_num * self.scheduler_config.max_num_seqs + ): + decode_block_nums.append(int(current_decode_block_num)) + current_decode_block_num += decode_block_num_step + + logger.info(f"warmup decode_batchs: {decode_batchs}, decode_block_nums: {decode_block_nums} start") + for decode_batch in decode_batchs: + for decode_block_num in decode_block_nums: + if decode_block_num < decode_batch: + continue + if decode_block_num // decode_batch * 
self.cache_config.block_size > warmup_max_model_len: + continue + blocks = [decode_block_num // decode_batch for _ in range(decode_batch)] + remain_block_num = decode_block_num % decode_batch + b = 0 + while remain_block_num > 0: + blocks[b] += 1 + remain_block_num -= 1 + b += 1 + if blocks[0] * self.cache_config.block_size > warmup_max_model_len: + continue + logger.info(f"warmup decode_batch: {decode_batch}, decode_block_num: {decode_block_num} start") + requests = [ + { + "idx": i, + "input_ids": [5] * (blocks[i] * self.cache_config.block_size - 1), + "block_tables": list(range(blocks[i])), + "eos_token_ids": [2], + } + for i in range(decode_batch) + ] + self.update_warmup_inputs(requests, is_decode=True) + self.execute_model() + logger.info(f"Warmup decode_batch: {decode_batch}, decode_block_num: {decode_block_num} done") + self.share_inputs["not_need_stop"][0] = False + logger.info("Warmup bucket done") + + def capture_model(self) -> None: + """ + Trigger CUDA Graph capture for all shapes in 'CudaGraphConfig.cudagraph_capture_sizes' + """ + if not self.use_cudagraph: + logger.info("Skipping CUDA graph capture. Please check GraphOptimizationConfig") + return + time_before_capture = time.perf_counter() + expected_decode_len = 1 + capture_sizes = self.cudagraph_capture_sizes.copy() + for batch_size in sorted(capture_sizes, reverse=True): + self._dummy_run( + num_tokens=self.parallel_config.max_model_len, + batch_size=batch_size, + in_capturing=True, + expected_decode_len=expected_decode_len, + ) + logger.info(f"Warm up the model with the batch size:{batch_size}, num tokens:{expected_decode_len}") + + time_after_capture = time.perf_counter() + logger.info(f"Cuda Graph capturing took {time_after_capture - time_before_capture} seconds") + + def _get_skip_idx(self, model_forward_batch): + """ + Get the index of the request that needs to be skipped during execution. + Args: + model_forward_batch: A list of requests to be executed by this runner. + Returns: + A list of indices corresponding to the requests that need to be skipped. + """ + skip_idx_list = [] + if not self.parallel_config.enable_chunked_prefill or self.guided_backend is None: + return skip_idx_list + + for task in model_forward_batch: + if task.get("prefill_chunk_info", None) is None or task.chunk_idx >= len(task.prefill_chunk_info): + continue + skip_idx_list.append(task.idx) + + for task in self.restore_chunked_prefill_request.values(): + if task.idx in skip_idx_list or task.chunk_idx >= len(task.prefill_chunk_info): + continue + skip_idx_list.append(task.idx) + + return skip_idx_list + + def execute_model( + self, + model_forward_batch: Optional[List[Request]] = None, + ) -> Optional[ModelRunnerOutput]: + """ + The Entrance of model execute. + Args: + model_forward_batch: 'Request' contains information related to prompt and is an abstract + class at the server level, which is too granular for ModelRunner. + We plan to replace it with 'ModelForwardBatch'. + intermediate_tensors: + """ + # # 1. Prepare inputs of model and decoder. + start_time = time.time() + self._prepare_inputs() + # self.share_inputs["ids_remove_padding"].cpu() + # # 2. Padding inputs for cuda grph + end_time = time.time() + execution_time = (end_time - start_time) * 1000 + real_bs = self.share_inputs["ids_remove_padding"].shape[0] + hpu_model_runner_profile_logger.info(f"_prepare_inputs time(ms): {execution_time}, BT={real_bs}") + start_time = time.time() + # # 3. 
Execute model + model_output = self.model(self.share_inputs["ids_remove_padding"], self.forward_meta) + if self.is_hpu_perf_breakdown_sync_mode: + model_output.cpu() + end_time = time.time() + execution_time = (end_time - start_time) * 1000 + hpu_model_runner_profile_logger.info( + f"Model execution time(ms): {execution_time}, BT={real_bs}, block_list_shape={self.share_inputs['block_list'].shape}, block_indices_shape={self.share_inputs['block_indices'].shape}" + ) + + start_time = time.time() + start_time0 = time.time() + hiddden_states = rebuild_padding_v3_1( + model_output, + self.forward_meta.batch_ids, + self.forward_meta.total_batch, + self.forward_meta.seq_lens_encoder, + self.forward_meta.is_prompt, + ) + end_time0 = time.time() + execution_time0 = (end_time0 - start_time0) * 1000 + hpu_model_runner_profile_logger.info(f"RebuildPadding execution time(ms): {execution_time0}, BT={real_bs}") + # # 4. Compute logits, Sample + start_time1 = time.time() + logits = self.model.compute_logits(hiddden_states) + end_time1 = time.time() + execution_time1 = (end_time1 - start_time1) * 1000 + hpu_model_runner_profile_logger.info(f"ComputeLogits execution time(ms): {execution_time1}, BT={real_bs}") + + # data = np.random.rand(self.scheduler_config.max_num_seqs, self.model_config.vocab_size).astype(np.float32) + # logits = paddle.to_tensor(data, dtype='bfloat16') + start_time2 = time.time() + self._prepare_sampler_inputs(self.forward_meta.batch_ids) + sampled_token_ids = self.sampler( + logits, + self.sampling_metadata, + self.forward_meta.batch_ids, + self.forward_meta.seq_lens_encoder.shape[0], + self.rank, + self.local_rank, + ) + if self.parallel_config.tensor_parallel_size > 1: + dtype = sampled_token_ids.dtype + sampled_token_ids = sampled_token_ids.to("float32") + paddle.distributed.broadcast(sampled_token_ids, 0) + sampled_token_ids = sampled_token_ids.to(dtype) + if self.is_hpu_perf_breakdown_sync_mode: + sampled_token_ids.cpu() + end_time2 = time.time() + execution_time2 = (end_time2 - start_time2) * 1000 + hpu_model_runner_profile_logger.info(f"Sampler execution time(ms): {execution_time2}, BT={real_bs}") + # 5. 
Post Process + start_time3 = time.time() + model_output_data = ModelOutputData( + next_tokens=self.share_inputs["next_tokens"], + stop_flags=self.share_inputs["stop_flags"], + step_idx=self.share_inputs["step_idx"], + max_dec_len=self.share_inputs["max_dec_len"], + pre_ids=self.share_inputs["pre_ids"], + seq_lens_this_time=self.share_inputs["seq_lens_this_time"], + eos_token_id=self.share_inputs["eos_token_id"], + not_need_stop=self.share_inputs["not_need_stop"], + input_ids=self.share_inputs["input_ids"], + stop_nums=self.share_inputs["stop_nums"], + seq_lens_encoder=self.share_inputs["seq_lens_encoder"], + seq_lens_decoder=self.share_inputs["seq_lens_decoder"], + is_block_step=self.share_inputs["is_block_step"], + full_hidden_states=model_output, + msg_queue_id=self.parallel_config.msg_queue_id, + mp_rank=self.local_rank, + use_ep=self.parallel_config.use_ep, + draft_tokens=self.share_inputs["draft_tokens"] if self.speculative_decoding else None, + actual_draft_token_num=self.share_inputs["actual_draft_token_num"] if self.speculative_decoding else None, + accept_tokens=self.share_inputs["accept_tokens"] if self.speculative_decoding else None, + accept_num=self.share_inputs["accept_num"] if self.speculative_decoding else None, + ) + + # if self.speculative_config.method in ["mtp"] and self.parallel_config.splitwise_role == "prefill": + # skip_save_output = True + # else: + # skip_save_output = False + post_process_hpu( + sampled_token_ids=sampled_token_ids, model_output=model_output_data, is_warmuping=self.is_warmuping + ) + end_time3 = time.time() + execution_time3 = (end_time3 - start_time3) * 1000 + hpu_model_runner_profile_logger.info(f"PostProcessHpu execution time(ms): {execution_time3}, BT={real_bs}") + end_time = time.time() + execution_time = (end_time - start_time) * 1000 + hpu_model_runner_profile_logger.info(f"PostProcessing execution time(ms): {execution_time}, BT={real_bs}") + + # 6. Speculative decode + if self.speculative_decoding: + if self.speculative_method == "mtp": + self.proposer.run(full_hidden_states=hiddden_states) + else: + self.proposer.run(share_inputs=self.share_inputs) + + # 7. Updata 'infer_seed' and step_cuda() + self.share_inputs["infer_seed"].add_(self.infer_seed_increment) + self.share_inputs["infer_seed"][:] %= self.MAX_INFER_SEED + start_time = time.time() + step_intel_hpu(self.share_inputs, self.cache_config.block_size, self.parallel_config.max_model_len) + end_time = time.time() + execution_time = (end_time - start_time) * 1000 + hpu_model_runner_profile_logger.info(f"StepPaddle execution time(ms): {execution_time}, BT={real_bs}") + self._update_chunked_prefill(model_forward_batch) + self._add_cache(model_forward_batch) + + if int(os.environ.get("HABANA_PROFILE", 0)) == 1: + self.prof.step() + return None + + def _add_cache(self, model_forward_batch) -> None: + """ + Add cache for guided decoding. + """ + if self.guided_backend is None: + return + + for request in model_forward_batch: + logits_cached = request.get("logits_cached", None) + if logits_cached is None or logits_cached: + continue + + request.logits_cached = True + if isinstance(request.logits_processor, LogitsProcessorBase): + self.guided_backend.add_cache(request.schemata_key, request.logits_processor) + else: + self.guided_backend.add_cache(request.schemata_key, request.logits_processor.result()) + + def _execute_empty_input(self) -> None: + """ + In certain scenarios, such as during EP, + the runner needs to execute partial modules of the model without input data. 
+ This requires the model to implement the `empty_input_forward` method. + """ + if hasattr(self.model, "empty_input_forward"): + self.model.empty_input_forward() + else: + raise ValueError(f"{type(self.model)} has no attribute 'empty_input_forward") + + def profile_run(self) -> None: + """Execute a forward pass with dummy inputs to profile the memory usage of the model.""" + + # Initialize kv cache for profile run. After profile run kv cache will be reset. + # TODO(gongshaotian): Optimize the management logic of kvcache + self.num_gpu_blocks = self.parallel_config.total_block_num + self.initialize_kv_cache() + + # 1. Profile with multimodal encoder & encoder cache + + # 2. Dummy run + self._dummy_run( + num_tokens=self.scheduler_config.max_num_batched_tokens, + batch_size=min(self.scheduler_config.max_num_seqs, 3), + ) + + # 3. gc + self.clear_cache() + + if self.speculative_method in ["mtp"]: + self.proposer.clear_dummy_input() + + def update_share_input_block_num(self, num_gpu_blocks: int) -> None: + """ + Set a globally unified block number and update the model's shared input. + Args: + num_gpu_blocks: + """ + self.num_gpu_blocks = num_gpu_blocks + + # Reset block table and kv cache with global block num + self.initialize_kv_cache() + + # Reset free list + free_list = list( + range(self.num_gpu_blocks - 2, int(self.num_gpu_blocks * self.cache_config.kv_cache_ratio) - 1, -1) + ) + self.free_list_len = len(free_list) + self.share_inputs.update( + { + "free_list": paddle.to_tensor(free_list, dtype="int32").cpu(), + "free_list_len": paddle.full([1], self.free_list_len, dtype="int32").cpu(), + } + ) + + self.parallel_config.do_profile = False + + if self.speculative_method in ["mtp"]: + self.proposer.update_block_num(num_gpu_blocks) + + def cal_theortical_kvcache(self): + """ + Calculate the total block memory required at the model level + TODO(gongshaotian): Move to Attention Backend + """ + """ + Byte of dtype: + - default(bf16): 2 + - cache_int8: 1 + - cache_int4: + """ + cache_quant_dtype = None + if ( + self.quant_config + and hasattr(self.quant_config, "kv_cache_quant_type") + and self.quant_config.kv_cache_quant_type is not None + ): + cache_quant_dtype = self.quant_config.kv_cache_quant_type + + if cache_quant_dtype is not None: # int8, int8_zp, fp8, fp8_zp + byte_of_dtype = 1 + else: # default + byte_of_dtype = 2 + + hidden_dim = self.model_config.head_dim * self.model_config.kv_num_heads + # NOTE(liuzichang): Implement multi-layer MTP architecture in the future + num_layers = ( + self.model_config.num_hidden_layers + self.speculative_config.num_gpu_block_expand_ratio + if self.speculative_method in ["mtp"] + else self.model_config.num_hidden_layers + ) + required_memory = byte_of_dtype * 2 * (self.cache_config.block_size * hidden_dim) * num_layers # k + v + return required_memory + + def not_need_stop(self) -> bool: + """ """ + return self.share_inputs["not_need_stop"][0] diff --git a/fastdeploy/worker/hpu_worker.py b/fastdeploy/worker/hpu_worker.py new file mode 100644 index 000000000..af908c8e5 --- /dev/null +++ b/fastdeploy/worker/hpu_worker.py @@ -0,0 +1,213 @@ +""" +# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License" +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" + +import gc +import os +import time +from typing import List, Optional + +import paddle +import paddle.nn as nn +from paddle.base import core + +from fastdeploy.config import FDConfig +from fastdeploy.engine.request import Request +from fastdeploy.utils import get_logger, set_random_seed +from fastdeploy.worker.hpu_model_runner import HPUModelRunner +from fastdeploy.worker.output import ModelRunnerOutput +from fastdeploy.worker.worker_base import WorkerBase + +logger = get_logger("hpu_worker", "hpu_worker.log") + + +def max_memory_allocated(device_id: int) -> int: + return core.device_memory_stat_peak_value("Allocated", device_id) + + +def max_memory_reserved(device_id: int) -> int: + return core.device_memory_stat_peak_value("Reserved", device_id) + + +def reset_max_memory_allocated(device_id: int) -> None: + core.device_memory_stat_reset_peak_value("Allocated", device_id) + + +def reset_max_memory_reserved(device_id: int) -> None: + core.device_memory_stat_reset_peak_value("Reserved", device_id) + + +class HpuWorker(WorkerBase): + def __init__( + self, + fd_config: FDConfig, + local_rank: int, + rank: int, + ): + super().__init__( + fd_config=fd_config, + local_rank=local_rank, + rank=rank, + ) + pass + + def init_device(self): + """ + Initialize device and construct model runner + """ + if paddle.is_compiled_with_custom_device("intel_hpu"): + # Set environment variable + self.device_ids = self.parallel_config.device_ids.split(",") + logger.info( + f"Using Intel HPU device with local rank => device id: {int(self.device_ids[self.local_rank])} as module id" + ) + intel_hpus_module_id = int(self.device_ids[self.local_rank]) + self.device = f"intel_hpu:{intel_hpus_module_id}" + paddle.device.set_device(self.device) + paddle.set_default_dtype(self.parallel_config.dtype) + + gc.collect() + paddle.device.cuda.empty_cache() + else: + raise RuntimeError(f"Not support device type: {self.device_config.device}") + + set_random_seed(self.fd_config.model_config.seed) + # Construct model runner + self.model_runner: HPUModelRunner = HPUModelRunner( + fd_config=self.fd_config, + device=self.device, + device_id=self.device_ids[self.local_rank], + rank=self.rank, + local_rank=self.local_rank, + ) + + def exist_prefill(self): + """ + check whether prefill stage exist + """ + return self.model_runner.exist_prefill() + + def determine_available_memory(self) -> int: + """ + Profiles the peak memory usage of the model to determine how much + memory can be used for KV cache without OOMs. + + The engine will first conduct a profiling of the existing memory usage. + Then, it calculate the maximum possible number of GPU and CPU blocks + that can be allocated with the remaining free memory. + + Tip: + You may limit the usage of GPU memory + by adjusting the `gpu_memory_utilization` parameter. + """ + # 1. 
Record memory state before profile run
+        start_time = time.perf_counter()
+        module_id = int(self.device_ids[self.local_rank])
+        reset_max_memory_allocated(module_id)
+        reset_max_memory_reserved(module_id)
+        paddle_reserved_mem_before_run = max_memory_reserved(module_id)
+        paddle_allocated_mem_before_run = max_memory_allocated(module_id)  # not reserved
+
+        logger.info(
+            (
+                "Before running the profile, the memory usage info is as follows:",
+                f"\nPaddle reserved memory: {paddle_reserved_mem_before_run}",
+                f"\nPaddle allocated memory: {paddle_allocated_mem_before_run}",
+            )
+        )
+
+        # 2. Profile run
+        self.model_runner.profile_run()
+
+        # 3. Collect memory statistics
+        paddle_reserved_mem_after_run = max_memory_reserved(module_id)
+        paddle_allocated_mem_after_run = max_memory_allocated(module_id)
+
+        one_mb = 1024 * 1024
+        one_gb = 1024 * one_mb
+        hpu_reserved_memory = 768 * one_mb  # 768MB reserved for memory not managed by Paddle
+        hpu_total_memory = 96 * one_gb  # 96GB HPU memory
+        peak_memory = paddle_allocated_mem_after_run + hpu_reserved_memory
+        available_kv_cache_memory = hpu_total_memory * self.cache_config.gpu_memory_utilization - peak_memory
+
+        end_time = time.perf_counter()
+        logger.info(
+            (
+                "After running the profile, the memory usage info is as follows:",
+                f"\nPaddle reserved memory: {paddle_reserved_mem_after_run}",
+                f"\nPaddle allocated memory: {paddle_allocated_mem_after_run}",
+                f"\nAvailable KV Cache memory: {available_kv_cache_memory}",
+                f"Profile time: {end_time - start_time}",
+            )
+        )
+
+        return available_kv_cache_memory  # returned to calculate the block num on this device
+
+    def load_model(self) -> None:
+        """Load model"""
+        self.model_runner.load_model()
+
+    def get_model(self) -> nn.Layer:
+        """Get current model"""
+        return self.model_runner.get_model()
+
+    def initialize_cache(self, num_gpu_blocks: int) -> None:
+        """Initialize the KV Cache with accurate num_gpu_blocks"""
+        # accurate cache size
+        self.model_runner.update_share_input_block_num(num_gpu_blocks=num_gpu_blocks)
+
+    def execute_model(
+        self,
+        model_forward_batch: Optional[List[Request]] = None,
+        num_running_request: int = None,
+    ) -> Optional[ModelRunnerOutput]:
+        """Run the model on the given batch via the model runner."""
+        output = self.model_runner.execute_model(model_forward_batch)
+        return output
+
+    def preprocess_new_task(self, req_dicts: List[Request], num_running_requests: int) -> None:
+        """Process new requests and then start the decode loop
+        TODO(gongshaotian): The scheduler should schedule the handling of prefill,
+        and workers and modelrunners should not perceive it.
+        """
+        self.model_runner.insert_prefill_inputs(req_dicts=req_dicts, num_running_requests=num_running_requests)
+
+    def graph_optimize_and_warm_up_model(self) -> None:
+        """
+        Perform the warm-up and the graph optimization
+        """
+        # Wait for all cards to finish loading the model.
+        if self.rank > 1:
+            paddle.distributed.barrier()
+        # 1. Warm up model
+        # NOTE(gongshaotian): warm-up may not be needed at this place
+        if int(os.environ.get("HPU_WARMUP_BUCKET", 0)) == 1:
+            logger.info("Warmup bucket is enabled, start warmup bucket")
+            self.model_runner.is_warmuping = True
+            self.model_runner.warm_up_bucket()
+            self.model_runner.is_warmuping = False
+        else:
+            logger.info("Skipping warmup bucket, please set HPU_WARMUP_BUCKET=1 to enable it.")
+
+        # 2. 
Trigger CUDA graph capture
+        self.model_runner.capture_model()
+
+    def check_health(self) -> bool:
+        """Check the health of the worker."""
+        return True
+
+    def cal_theortical_kvcache(self) -> int:
+        """Calculate the block memory required"""
+        return self.model_runner.cal_theortical_kvcache()
diff --git a/fastdeploy/worker/worker_process.py b/fastdeploy/worker/worker_process.py
index 193970d0f..a830ef7ed 100644
--- a/fastdeploy/worker/worker_process.py
+++ b/fastdeploy/worker/worker_process.py
@@ -82,6 +82,10 @@ def get_worker(fd_config: FDConfig, local_rank: int, rank: int) -> WorkerBase:
         from fastdeploy.worker.metax_worker import MetaxWorker
 
         return MetaxWorker(fd_config=fd_config, local_rank=local_rank, rank=rank)
+    if current_platform.is_intel_hpu():
+        from fastdeploy.worker.hpu_worker import HpuWorker
+
+        return HpuWorker(fd_config=fd_config, local_rank=local_rank, rank=rank)
 
 
 def init_distributed_environment(seed: int = 20) -> Tuple[int, int]:
@@ -89,21 +93,22 @@ def init_distributed_environment(seed: int = 20) -> Tuple[int, int]:
     # Global rank
     ranks = dist.get_world_size()
     dist_strategy = fleet.DistributedStrategy()
+    if ranks > 0:
+        dist_strategy.hybrid_configs = {
+            "dp_degree": 1,
+            "mp_degree": ranks,
+            "pp_degree": 1,
+            "sharding_degree": 1,
+        }
 
-    dist_strategy.hybrid_configs = {
-        "dp_degree": 1,
-        "mp_degree": ranks,
-        "pp_degree": 1,
-        "sharding_degree": 1,
-    }
-
-    # Set control in tensor parallel
-    dist_strategy.tensor_parallel_configs = {"tensor_init_seed": seed}
-    fleet.init(is_collective=True, strategy=dist_strategy)
-
-    # Local rank
-    local_rank = fleet.worker_index()
+        # Set control in tensor parallel
+        dist_strategy.tensor_parallel_configs = {"tensor_init_seed": seed}
+        fleet.init(is_collective=True, strategy=dist_strategy)
+        # Local rank
+        local_rank = fleet.worker_index()
+    else:
+        local_rank = 0
 
     return ranks, local_rank
 
diff --git a/setup.py b/setup.py
index 1e9878936..6c79b6826 100644
--- a/setup.py
+++ b/setup.py
@@ -174,6 +174,8 @@ def get_device_type():
         return "gcu"
     elif paddle.device.is_compiled_with_custom_device("metax_gpu"):
         return "metax-gpu"
+    elif paddle.is_compiled_with_custom_device("intel_hpu"):
+        return "intel-hpu"
     else:
         return "cpu"