Files
FastDeploy/fastdeploy/model_executor/layers/attention/base_attention_backend.py
Jiang-Jia-Jun 05c670e593 [Sync] Update to latest code (#2679)
* [Sync] Update to latest code

* Add new code files

* Add new code files

* update code

* Try to fix build.sh

* Try to fix build.sh

* Update code

* Update requirements.txt

* Update code

---------

Co-authored-by: Jiang-Jia-Jun <jiangjiajun@baidu.com>
2025-07-03 15:43:53 +08:00

140 lines
3.8 KiB
Python

"""
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Adapted from
# https://github.com/sgl-project/sglang/blob/main/python/sglang/srt/layers/attention/base_attn_backend.py
"""
from __future__ import annotations
from abc import ABC, abstractmethod
from dataclasses import dataclass
import paddle
from fastdeploy.worker.forward_meta import ForwardMeta
@dataclass
class AttentionMetadata(ABC):
    """Field-less dataclass base for attention metadata.

    The class declares no fields and no abstract methods of its own; it
    exists so that concrete metadata classes can share one base type.
    """
class AttentionBackend(ABC):
    """The base class of attention backends.

    Subclasses must implement ``init_attention_metadata``. The three
    mode-specific hooks (``forward_mixed``, ``forward_decode`` and
    ``forward_extend``) raise ``NotImplementedError`` in this base class;
    ``forward`` selects one of them from ``forward_meta.forward_mode``.
    """

    @abstractmethod
    def init_attention_metadata(self, forward_meta: ForwardMeta):
        """Initialize the forward metadata.

        Args:
            forward_meta: The forward metadata.

        Raises:
            NotImplementedError: Always, when the base implementation is
                invoked directly (e.g. through ``super()``).
        """
        raise NotImplementedError()

    def forward(
        self,
        q: paddle.Tensor,
        k: paddle.Tensor,
        v: paddle.Tensor,
        qkv: paddle.Tensor,
        compressed_kv: paddle.Tensor,
        k_pe: paddle.Tensor,
        layer: paddle.nn.Layer,
        forward_meta: ForwardMeta,
    ) -> paddle.Tensor:
        """Run a forward pass through the handler matching the forward mode.

        All arguments are handed, unchanged and in the same positional
        order, to exactly one of ``forward_mixed``, ``forward_decode`` or
        ``forward_extend``.

        Args:
            q: The query tensor.
            k: The key tensor.
            v: The value tensor.
            qkv: Forwarded as-is to the selected handler; presumably the
                fused query/key/value tensor -- TODO confirm with callers.
            compressed_kv: Optional compressed key-value cache (for MLA).
            k_pe: Optional key positional encoding (for MLA).
            layer: The layer that will be used for the forward.
            forward_meta: The forward metadata; its ``forward_mode`` decides
                which handler runs.

        Returns:
            Whatever the selected handler returns.

        Raises:
            NotImplementedError: If the selected handler has not been
                overridden by the subclass.
        """
        mode = forward_meta.forward_mode

        # Order matters: the mixed check takes precedence over decode, and
        # extend is the fallback for every other mode.
        if mode.is_mixed():
            handler = self.forward_mixed
        elif mode.is_decode():
            handler = self.forward_decode
        else:
            handler = self.forward_extend

        return handler(
            q,
            k,
            v,
            qkv,
            compressed_kv,
            k_pe,
            layer,
            forward_meta,
        )

    def forward_mixed(
        self,
        q: paddle.Tensor,
        k: paddle.Tensor,
        v: paddle.Tensor,
        qkv: paddle.Tensor,
        compressed_kv: paddle.Tensor,
        k_pe: paddle.Tensor,
        layer: paddle.nn.Layer,
        forward_meta: ForwardMeta,
    ) -> paddle.Tensor:
        """Run a forward for mix.

        Takes the same arguments as ``forward``. The base implementation
        always raises ``NotImplementedError``.
        """
        raise NotImplementedError()

    def forward_decode(
        self,
        q: paddle.Tensor,
        k: paddle.Tensor,
        v: paddle.Tensor,
        qkv: paddle.Tensor,
        compressed_kv: paddle.Tensor,
        k_pe: paddle.Tensor,
        layer: paddle.nn.Layer,
        forward_meta: ForwardMeta,
    ) -> paddle.Tensor:
        """Run a forward for decode.

        Takes the same arguments as ``forward``. The base implementation
        always raises ``NotImplementedError``.
        """
        raise NotImplementedError()

    def forward_extend(
        self,
        q: paddle.Tensor,
        k: paddle.Tensor,
        v: paddle.Tensor,
        qkv: paddle.Tensor,
        compressed_kv: paddle.Tensor,
        k_pe: paddle.Tensor,
        layer: paddle.nn.Layer,
        forward_meta: ForwardMeta,
    ) -> paddle.Tensor:
        """Run a forward for extend.

        Takes the same arguments as ``forward``. The base implementation
        always raises ``NotImplementedError``.
        """
        raise NotImplementedError()