mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2025-10-05 16:48:03 +08:00
polish code with new pre-commit rule (#2923)
This commit is contained in:
@@ -30,6 +30,7 @@ class ForwardMode(IntEnum):
|
||||
"""
|
||||
Forward mode used during attention.
|
||||
"""
|
||||
|
||||
# Prefill and Extend mode
|
||||
EXTEND = auto()
|
||||
# Decode mode
|
||||
@@ -38,23 +39,24 @@ class ForwardMode(IntEnum):
|
||||
MIXED = auto()
|
||||
|
||||
def is_prefill(self):
|
||||
""" Is Extend mode """
|
||||
"""Is Extend mode"""
|
||||
return self == ForwardMode.EXTEND
|
||||
|
||||
def is_decode(self):
|
||||
""" Is Decode mode """
|
||||
"""Is Decode mode"""
|
||||
return self == ForwardMode.DECODE
|
||||
|
||||
def is_mixed(self):
|
||||
""" Is Mixed mode """
|
||||
"""Is Mixed mode"""
|
||||
return self == ForwardMode.MIXED
|
||||
|
||||
|
||||
@dataclass
|
||||
class ForwardMeta():
|
||||
class ForwardMeta:
|
||||
"""
|
||||
ForwardMeta is used to store the global meta information of the model forward.
|
||||
"""
|
||||
|
||||
# Input tokens IDs
|
||||
input_ids: paddle.Tensor
|
||||
# Input tokens IDs of removed padding
|
||||
@@ -100,7 +102,7 @@ class ForwardMeta():
|
||||
caches: Optional[list[paddle.Tensor]] = None
|
||||
|
||||
def clear_caches(self):
|
||||
""" Safely clean up the caches """
|
||||
"""Safely clean up the caches"""
|
||||
if self.caches:
|
||||
del self.caches
|
||||
|
||||
@@ -110,6 +112,7 @@ class XPUForwardMeta(ForwardMeta):
|
||||
"""
|
||||
XPUForwardMeta is used to store the global meta information of the forward, and some XPU specific meta info.
|
||||
"""
|
||||
|
||||
# TODO(wanghaitao): Supplementary notes
|
||||
#
|
||||
encoder_batch_map: Optional[paddle.Tensor] = None
|
||||
|
Reference in New Issue
Block a user