[Intel HPU] Support intel hpu platform (#4161)

* [Intel HPU] Support intel hpu platform

* fix some issues

* apply precommit and move AttentionBackend_HPU

* fix format issue

* correct ops import

* fix ci issue

* update code in layers

* fix code style issue

* remove dense tp moe ep mode

* fix enc_dec_block_num

* fix rebase issue

* rename hpu to gaudi in readme

* rename ForwardMeta_HPU to HPUForwardMeta
This commit is contained in:
fmiao2372
2025-09-24 12:27:50 +08:00
committed by GitHub
parent a1c5d930bb
commit f1b5392e20
35 changed files with 2814 additions and 19 deletions

View File

@@ -17,12 +17,14 @@
import logging
from dataclasses import dataclass
from enum import IntEnum, auto
from typing import Optional
from typing import TYPE_CHECKING, Dict, Optional
import paddle
from fastdeploy.model_executor.layers.attention import AttentionBackend
if TYPE_CHECKING:
from fastdeploy.model_executor.layers.attention import AttentionBackend_HPU
logger = logging.getLogger(__name__)
@@ -240,3 +242,116 @@ class DCUForwardMeta(ForwardMeta):
# Accumulated offset
cum_offsets: Optional[paddle.Tensor] = None
@dataclass
class HPUForwardMeta:
    """
    HPUForwardMeta is used to store the global meta information of the forward on intel HPU.

    Instances are built from the runner's ``share_inputs`` dict via
    :meth:`init_forward_meta`; most fields are therefore tensors owned by the
    model runner, and this object only holds references to them.
    """

    # Token ids of the batch as provided by the runner (required field).
    input_ids: paddle.Tensor

    # Forward mode of this step (EXTEND for prompt/prefill, DECODE otherwise).
    forward_mode: ForwardMode = ForwardMode.MIXED

    # Token ids with padding removed — assumes runner pre-flattens the batch; TODO confirm.
    ids_remove_padding: Optional[paddle.Tensor] = None

    # Per-request sequence lengths on the encoder side.
    seq_lens_encoder: Optional[paddle.Tensor] = None

    # Per-request sequence lengths on the decoder side.
    seq_lens_decoder: Optional[paddle.Tensor] = None

    # Number of tokens processed for each request in this step.
    seq_lens_this_time: Optional[paddle.Tensor] = None

    # Accumulated offsets (not populated by init_forward_meta).
    cum_offsets: Optional[paddle.Tensor] = None

    # Per-request table mapping logical blocks to physical KV-cache blocks.
    block_tables: Optional[paddle.Tensor] = None

    # HPU block-scheduling metadata taken verbatim from share_inputs.
    block_groups: Optional[paddle.Tensor] = None

    # HPU block-scheduling metadata taken verbatim from share_inputs.
    block_list: Optional[paddle.Tensor] = None

    # HPU block-scheduling metadata taken verbatim from share_inputs.
    block_indices: Optional[paddle.Tensor] = None

    # HPU block-scheduling metadata taken verbatim from share_inputs.
    block_offsets: Optional[paddle.Tensor] = None

    # HPU block-scheduling metadata taken verbatim from share_inputs.
    block_mapping: Optional[paddle.Tensor] = None

    # Attention bias/mask; sourced from share_inputs["block_bias"].
    attention_mask: Optional[paddle.Tensor] = None

    # KV-cache block size.
    block_size: Optional[paddle.Tensor] = None

    # Ids of the batches scheduled in this step.
    batch_ids: Optional[paddle.Tensor] = None

    # Total number of batches.
    total_batch: Optional[paddle.Tensor] = None

    # Whether this step is a prompt (prefill) step; drives forward_mode selection.
    is_prompt: Optional[paddle.Tensor] = None

    # Attention backend instance used to run attention on HPU.
    attn_backend: Optional["AttentionBackend_HPU"] = None

    # Rotary position embeddings.
    rotary_embs: Optional[paddle.Tensor] = None

    # KV caches for all layers.
    caches: Optional[paddle.Tensor] = None

    # Additional attention mask slot (distinct from attention_mask above).
    attn_mask: Optional[paddle.Tensor] = None

    # Length of pre-existing caches (prefix), 0 when none.
    pre_caches_length: int = 0

    @classmethod
    def init_forward_meta(cls, share_inputs: Dict, attn_backend: "AttentionBackend_HPU"):
        """Build an HPUForwardMeta from the runner's share_inputs dict.

        Args:
            share_inputs: Dict of runner-owned tensors keyed by name; must
                contain every key referenced below.
            attn_backend: The HPU attention backend to attach.

        Returns:
            A populated HPUForwardMeta with forward_mode set to EXTEND when
            share_inputs["is_prompt"] is truthy, DECODE otherwise.
        """
        # TODO(gongshaotian): delete this func
        # NOTE(review): is_prompt comes straight from share_inputs — if it is a
        # multi-element tensor this truthiness test would be ambiguous; confirm
        # the runner stores a scalar/bool here.
        is_prompt = share_inputs["is_prompt"]
        forward_mode = ForwardMode.EXTEND if is_prompt else ForwardMode.DECODE
        return cls(
            forward_mode=forward_mode,
            input_ids=share_inputs["input_ids"],
            ids_remove_padding=share_inputs["ids_remove_padding"],
            seq_lens_encoder=share_inputs["seq_lens_encoder"],
            seq_lens_decoder=share_inputs["seq_lens_decoder"],
            seq_lens_this_time=share_inputs["seq_lens_this_time"],
            block_tables=share_inputs["block_tables"],
            block_groups=share_inputs["block_groups"],
            block_list=share_inputs["block_list"],
            block_indices=share_inputs["block_indices"],
            block_offsets=share_inputs["block_offsets"],
            block_mapping=share_inputs["block_mapping"],
            # Intentional key mismatch: the mask is stored as "block_bias".
            attention_mask=share_inputs["block_bias"],
            block_size=share_inputs["block_size"],
            total_batch=share_inputs["total_batch"],
            batch_ids=share_inputs["batch_ids"],
            is_prompt=share_inputs["is_prompt"],
            attn_backend=attn_backend,
            rotary_embs=share_inputs["rotary_embs"],
            caches=share_inputs["caches"],
        )

    def clear_caches(self):
        """Safely release the reference to the KV caches."""
        # Fix: the previous `if self.caches:` relied on truthiness, which is
        # ambiguous (and can raise) if caches is a multi-element tensor.
        # Identity comparison against None is the safe emptiness check here.
        if self.caches is not None:
            del self.caches