[Intel HPU] Support intel hpu platform (#4161)
* [Intel HPU] Support intel hpu platform
* fix some issues
* apply precommit and move AttentionBackend_HPU
* fix format issue
* correct ops import
* fix ci issue
* update code in layers
* fix code style issue
* remove dense tp moe ep mode
* fix enc_dec_block_num
* fix rebase issue
* rename hpu to gaudi in readme
* rename ForwardMeta_HPU to HPUForwardMeta
@@ -17,12 +17,14 @@
 import logging
 from dataclasses import dataclass
 from enum import IntEnum, auto
-from typing import Optional
+from typing import TYPE_CHECKING, Dict, Optional

 import paddle

 from fastdeploy.model_executor.layers.attention import AttentionBackend

+if TYPE_CHECKING:
+    from fastdeploy.model_executor.layers.attention import AttentionBackend_HPU
+
 logger = logging.getLogger(__name__)
@@ -240,3 +242,116 @@ class DCUForwardMeta(ForwardMeta):

     # Accumulated offset
     cum_offsets: Optional[paddle.Tensor] = None
+
+
+@dataclass
+class HPUForwardMeta:
+    """
+    HPUForwardMeta is used to store the global meta information of the forward pass on Intel HPU.
+    """
+
+    # Input token ids
+    input_ids: paddle.Tensor
+
+    # Attention meta: forward mode for this step
+    forward_mode: ForwardMode = ForwardMode.MIXED
+
+    # Input token ids with padding removed
+    ids_remove_padding: paddle.Tensor = None
+
+    # Per-request sequence length of the encoder (prefill) part
+    seq_lens_encoder: Optional[paddle.Tensor] = None
+
+    # Per-request sequence length of the decoder part
+    seq_lens_decoder: Optional[paddle.Tensor] = None
+
+    # Number of tokens processed per request in this step
+    seq_lens_this_time: Optional[paddle.Tensor] = None
+
+    # Accumulated offset
+    cum_offsets: Optional[paddle.Tensor] = None
+
+    # Block tables mapping each request to its KV cache blocks
+    block_tables: Optional[paddle.Tensor] = None
+
+    # Block groups for HPU paged attention
+    block_groups: Optional[paddle.Tensor] = None
+
+    # Flattened list of active KV cache blocks
+    block_list: Optional[paddle.Tensor] = None
+
+    # Block indices for scattering new KV entries
+    block_indices: Optional[paddle.Tensor] = None
+
+    # In-block offsets for new KV entries
+    block_offsets: Optional[paddle.Tensor] = None
+
+    # Mapping from blocks to batch entries
+    block_mapping: Optional[paddle.Tensor] = None
+
+    # Attention mask (taken from share_inputs["block_bias"])
+    attention_mask: Optional[paddle.Tensor] = None
+
+    # Number of tokens per KV cache block
+    block_size: Optional[paddle.Tensor] = None
+
+    # Ids of the requests in the current batch
+    batch_ids: Optional[paddle.Tensor] = None
+
+    # Total number of requests in the batch
+    total_batch: Optional[paddle.Tensor] = None
+
+    # Whether this step is a prompt (prefill) step
+    is_prompt: Optional[paddle.Tensor] = None
+
+    # Attention backend in use on HPU
+    attn_backend: "AttentionBackend_HPU" = None
+
+    # Rotary position embeddings
+    rotary_embs: Optional[paddle.Tensor] = None
+
+    # KV caches
+    caches: Optional[paddle.Tensor] = None
+
+    # Attention mask used by the backend
+    attn_mask: Optional[paddle.Tensor] = None
+
+    # Length of the pre-cached prefix
+    pre_caches_length: int = 0
+
+    @classmethod
+    def init_forward_meta(cls, share_inputs: Dict, attn_backend: "AttentionBackend_HPU"):
+        """init forward meta"""
+        # TODO(gongshaotian): delete this func
+        is_prompt = share_inputs["is_prompt"]
+        forward_mode = ForwardMode.DECODE
+        if is_prompt:
+            forward_mode = ForwardMode.EXTEND
+        ret = cls(
+            forward_mode=forward_mode,
+            input_ids=share_inputs["input_ids"],
+            ids_remove_padding=share_inputs["ids_remove_padding"],
+            seq_lens_encoder=share_inputs["seq_lens_encoder"],
+            seq_lens_decoder=share_inputs["seq_lens_decoder"],
+            seq_lens_this_time=share_inputs["seq_lens_this_time"],
+            block_tables=share_inputs["block_tables"],
+            block_groups=share_inputs["block_groups"],
+            block_list=share_inputs["block_list"],
+            block_indices=share_inputs["block_indices"],
+            block_offsets=share_inputs["block_offsets"],
+            block_mapping=share_inputs["block_mapping"],
+            attention_mask=share_inputs["block_bias"],
+            block_size=share_inputs["block_size"],
+            total_batch=share_inputs["total_batch"],
+            batch_ids=share_inputs["batch_ids"],
+            is_prompt=share_inputs["is_prompt"],
+            attn_backend=attn_backend,
+            rotary_embs=share_inputs["rotary_embs"],
+            caches=share_inputs["caches"],
+        )
+        return ret
+
+    def clear_caches(self):
+        """safe clear caches"""
+        if self.caches:
+            del self.caches
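For orientation, here is a minimal sketch of how a model runner might drive `HPUForwardMeta.init_forward_meta`. The import path, tensor shapes, and placeholder values are assumptions for illustration; in FastDeploy the `share_inputs` dict is populated by the HPU model runner before each step, and the paged-attention block fields are recomputed per batch.

    import paddle

    # Assumed import path: HPUForwardMeta is added next to ForwardMode
    # and DCUForwardMeta in the module shown in the diff.
    from fastdeploy.worker.forward_meta import HPUForwardMeta

    max_num_seqs, max_len = 2, 16

    # Placeholder buffers only; the real runner fills these in.
    share_inputs = {
        "is_prompt": True,  # prefill step, so forward_mode becomes EXTEND
        "input_ids": paddle.zeros([max_num_seqs, max_len], dtype="int64"),
        "ids_remove_padding": paddle.zeros([max_num_seqs * max_len], dtype="int64"),
        "seq_lens_encoder": paddle.zeros([max_num_seqs, 1], dtype="int32"),
        "seq_lens_decoder": paddle.zeros([max_num_seqs, 1], dtype="int32"),
        "seq_lens_this_time": paddle.zeros([max_num_seqs, 1], dtype="int32"),
        "block_tables": paddle.full([max_num_seqs, 4], -1, dtype="int32"),
        # Paged-attention bookkeeping; left empty in this sketch.
        "block_groups": None,
        "block_list": None,
        "block_indices": None,
        "block_offsets": None,
        "block_mapping": None,
        "block_bias": None,  # stored on the meta object as attention_mask
        "block_size": 128,
        "total_batch": max_num_seqs,
        "batch_ids": None,
        "rotary_embs": None,
        "caches": None,
    }

    meta = HPUForwardMeta.init_forward_meta(share_inputs, attn_backend=None)
    assert meta.forward_mode.name == "EXTEND"
    meta.clear_caches()  # no-op here, since no caches were attached

Passing `attn_backend=None` keeps the sketch self-contained; a real run would pass the `AttentionBackend_HPU` instance created by the worker.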