# FastDeploy/fastdeploy/model_executor/layers/attention/base.py
"""
# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
import os

import paddle
from paddle import nn

import fastdeploy


class Attention(nn.Layer):
    """
    Attention Layer
    """

    def __init__(
        self,
        inference_args,
        prefix,
        out_scale=-1,
        use_neox_rotary_style=False,
        rope_theta=10000.0,
        rope_3d=False,
        qkv_scale=None,
        qkv_bias=None,
        linear_shift=None,
        linear_smooth=None,
    ):
"""
Initialize the attention layer with various parameters.
Args:
inference_args (dict or object): Contains arguments for inference, including
number of key-value heads, weight data type, activation data type, etc.
prefix (str): The name of the attention layer for identification purposes.
out_scale (float, optional): Output scale factor. Defaults to -1.
use_neox_rotary_style (bool, optional): Whether to use the NeoX rotary position
encoding style. Defaults to False.
rope_theta (float, optional): Theta value for the rope position encoding. Defaults to 10000.0.
qkv_scale (float or None, optional): Quantization scale for QKV weights.
Used only for certain quantization configurations. Defaults to None.
qkv_bias (Tensor or None, optional): Bias for QKV linear layer. Defaults to None.
linear_shift (float or None, optional): Linear shift factor used in
quantization. Used only for certain quantization configurations.
Defaults to None.
linear_smooth (float or None, optional): Linear smooth factor used in
quantization. Used only for certain quantization configurations.
Defaults to None.
"""
        super().__init__()
        self.inference_args = inference_args
        self.nranks = inference_args.mp_size
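        # KV heads are sharded across tensor-parallel ranks (mp_size), so each
        # rank only holds num_key_value_heads // nranks heads locally.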
        self.kv_num_heads = inference_args.num_key_value_heads // self.nranks
        self.head_dim = self.inference_args.head_dim
        self.prefix = prefix
        self.cache_k_scale_name = prefix + ".cachek_matmul.activation_quanter"
        self.cache_v_scale_name = prefix + ".cachev_matmul.activation_quanter"
        self.out_scale = out_scale
        self.cache_k_zp_name = self.cache_k_scale_name + ".zero_point"
        self.cache_v_zp_name = self.cache_v_scale_name + ".zero_point"
        self.use_neox_rotary_style = use_neox_rotary_style
        self.rope_theta = rope_theta
        self.rope_3d = rope_3d

        self._dtype = self._helper.get_default_dtype()
        if self._dtype == "bfloat16":
            self._fuse_kernel_compute_dtype = "bf16"
        elif self._dtype == "float16":
            self._fuse_kernel_compute_dtype = "fp16"
        elif self._dtype == "float32":
            self._fuse_kernel_compute_dtype = "fp32"
        else:
            raise ValueError(
                f"Only float32, float16 and bfloat16 are supported as the "
                f"default dtype, but received {self._dtype}")
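
        # Cache-KV scales are kept in the model compute dtype for the
        # append-attention path and in float32 otherwise.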
        self.cache_scale_dtype = (
            self._dtype if self.inference_args.use_append_attn else "float32")

        self.qkv_bias = qkv_bias
        if inference_args.weight_dtype == "int8" and inference_args.act_dtype == "int8":
            self.qkv_scale = qkv_scale
            self.linear_shift = linear_shift
            self.linear_smooth = linear_smooth

        if (inference_args.cachekv_dtype == "int8"
                or inference_args.cachekv_dtype == "int4"
                or inference_args.cachekv_dtype == "float8_e4m3fn"):
            self.set_cachekv_scale()

        # qkv_bias fused with attention only when W8A8
        if not (inference_args.weight_dtype == "int8"
                and inference_args.act_dtype == "int8"):
            self.qkv_bias = None

    def set_cachekv_scale(self):
        """
        Set cache key (K) and value (V) scaling factors.

        This method initializes and sets the scaling factors for the cache key (K) and
        value (V) tensors, which are used in the attention mechanism to adjust the scale
        of the cached representations. It also computes and sets the inverses of these
        scaling factors for the output cache K and V tensors.

        Args:
            None. This method relies on instance variables such as `self.kv_num_heads`,
            `self.cache_k_scale_name`, `self.cache_v_scale_name`, and
            `self.inference_args.cachekv_scale_dict`.

        Returns:
            None. This method modifies instance variables directly and does not return
            any values.
        """
        self.cache_k_scale = self.create_parameter(
            shape=([self.kv_num_heads * self.head_dim]
                   if self.inference_args.is_channel_wise else
                   [self.kv_num_heads]),
            dtype=self.cache_scale_dtype,
            is_bias=False,
        )
        self.cache_v_scale = self.create_parameter(
            shape=([self.kv_num_heads * self.head_dim]
                   if self.inference_args.is_channel_wise else
                   [self.kv_num_heads]),
            dtype=self.cache_scale_dtype,
            is_bias=False,
        )
        self.cache_k_out_scale = self.create_parameter(
            shape=([self.kv_num_heads * self.head_dim]
                   if self.inference_args.is_channel_wise else
                   [self.kv_num_heads]),
            attr=None,
            dtype=self.cache_scale_dtype,
            is_bias=False,
        )
        self.cache_v_out_scale = self.create_parameter(
            shape=([self.kv_num_heads * self.head_dim]
                   if self.inference_args.is_channel_wise else
                   [self.kv_num_heads]),
            attr=None,
            dtype=self.cache_scale_dtype,
            is_bias=False,
        )
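
        # Calibrated scales are looked up by name in cachekv_scale_dict; when the
        # EP_DECODER_PERF_TEST environment variable is set to "True", missing
        # scales are replaced with zero placeholders instead of raising an error.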
        if self.cache_k_scale_name in self.inference_args.cachekv_scale_dict:
            cache_k_scale = paddle.cast(
                paddle.to_tensor(
                    self.inference_args.cachekv_scale_dict[self.cache_k_scale_name]),
                self.cache_scale_dtype,
            )
            cache_k_out_scale = 1.0 / cache_k_scale
        else:
            if os.getenv("EP_DECODER_PERF_TEST", "False") == "True":
                cache_k_scale = paddle.zeros(self.cache_k_scale.shape,
                                             self.cache_k_scale.dtype)
                cache_k_out_scale = paddle.zeros(self.cache_k_out_scale.shape,
                                                 self.cache_k_out_scale.dtype)
            else:
                raise KeyError(
                    f"{self.cache_k_scale_name} not found in scale dict")

        if self.cache_v_scale_name in self.inference_args.cachekv_scale_dict:
            cache_v_scale = paddle.cast(
                paddle.to_tensor(
                    self.inference_args.cachekv_scale_dict[self.cache_v_scale_name]),
                self.cache_scale_dtype,
            )
            cache_v_out_scale = 1.0 / cache_v_scale
        else:
            if os.getenv("EP_DECODER_PERF_TEST", "False") == "True":
                cache_v_scale = paddle.zeros(self.cache_v_scale.shape,
                                             self.cache_v_scale.dtype)
                cache_v_out_scale = paddle.zeros(self.cache_v_out_scale.shape,
                                                 self.cache_v_out_scale.dtype)
            else:
                raise KeyError(
                    f"{self.cache_v_scale_name} not found in scale dict")

        self.cache_k_scale.set_value(cache_k_scale)
        self.cache_v_scale.set_value(cache_v_scale)
        self.cache_k_out_scale.set_value(cache_k_out_scale)
        self.cache_v_out_scale.set_value(cache_v_out_scale)
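
        # Zero points are only materialized for asymmetric cache quantization
        # (has_zero_point); entries missing from the scale dict default to zeros.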
        if self.inference_args.has_zero_point:
            self.cache_k_zp = self.create_parameter(
                shape=([self.kv_num_heads * self.head_dim]
                       if self.inference_args.is_channel_wise else
                       [self.kv_num_heads]),
                dtype=self.cache_scale_dtype,
                is_bias=False,
            )
            self.cache_v_zp = self.create_parameter(
                shape=([self.kv_num_heads * self.head_dim]
                       if self.inference_args.is_channel_wise else
                       [self.kv_num_heads]),
                dtype=self.cache_scale_dtype,
                is_bias=False,
            )
            if self.cache_k_zp_name in self.inference_args.cachekv_scale_dict:
                cache_k_zp = paddle.cast(
                    paddle.to_tensor(
                        self.inference_args.cachekv_scale_dict[self.cache_k_zp_name]),
                    self.cache_scale_dtype,
                )
            else:
                cache_k_zp = paddle.zeros(
                    ([self.kv_num_heads * self.head_dim]
                     if self.inference_args.is_channel_wise else
                     [self.kv_num_heads]),
                    dtype=self.cache_scale_dtype,
                )
            if self.cache_v_zp_name in self.inference_args.cachekv_scale_dict:
                cache_v_zp = paddle.cast(
                    paddle.to_tensor(
                        self.inference_args.cachekv_scale_dict[self.cache_v_zp_name]),
                    self.cache_scale_dtype,
                )
            else:
                cache_v_zp = paddle.zeros(
                    ([self.kv_num_heads * self.head_dim]
                     if self.inference_args.is_channel_wise else
                     [self.kv_num_heads]),
                    dtype=self.cache_scale_dtype,
                )
            self.cache_k_zp.set_value(cache_k_zp)
            self.cache_v_zp.set_value(cache_v_zp)

    def forward(
        self,
        qkv,
        input_ids,
        rotary_embs,
        rotary_emb_dims,
        key_cache,
        value_cache,
        pre_key_cache,
        pre_value_cache,
        pre_caches_length,
        attn_mask,
        kv_signal_data,
        **kwargs,
    ):
"""
Compute the attention for a single time step.
Args:
qkv (Tensor): The output of the linear transformation of query, key and value.
Shape: [batch_size, num_heads, seq_len, embed_dim // num_heads].
padding_offset (Tensor): The offset to be added to the sequence length when computing
the attention mask. Shape: [batch_size, 1].
input_ids (Tensor, optional): The input ids of the batch. Used for computing the
attention mask. Default: None. Shape: [batch_size, max_sequence_length].
rotary_embs (Tensor, optional): The rotary position embeddings. Default: None.
Shape: [num_heads, rotary_emb_dims].
rotary_emb_dims (int, optional): The dimension of the rotary position embeddings.
Default: None.
caches (List[Tensor], optional): The cache tensors used in the computation of the
attention. Default: None.
pre_caches (List[Tensor], optional): The pre-computed cache tensors used in the
computation of the attention. Default: None.
pre_caches_length (int, optional): The length of the pre-computed cache tensors.
Default: None.
attn_mask (Tensor, optional): The attention mask. Default: None.
Shape: [batch_size, max_sequence_length].
**kwargs (dict, optional): Additional keyword arguments passed along.
Returns:
Tensor: The output of the linear transformation after applying the attention.
Shape: [batch_size, embed_dim // num_heads].
Raises:
None.
"""
        k_quant_scale = kwargs.get("k_quant_scale", None)
        v_quant_scale = kwargs.get("v_quant_scale", None)
        k_dequant_scale = kwargs.get("k_dequant_scale", None)
        v_dequant_scale = kwargs.get("v_dequant_scale", None)

        if not self.inference_args.use_dynamic_cachekv_quant:
            k_quant_scale = getattr(self, "cache_k_scale", None)
            v_quant_scale = getattr(self, "cache_v_scale", None)
            k_dequant_scale = getattr(self, "cache_k_out_scale", None)
            v_dequant_scale = getattr(self, "cache_v_out_scale", None)
            cache_quant_type_str = self.inference_args.cache_quant_type
        else:
            cache_quant_type_str = "none"
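
        # Dispatch to the custom fused append-attention op when enabled,
        # otherwise fall back to Paddle's block_multihead_attention.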
        if self.inference_args.use_append_attn:
            out = fastdeploy.model_executor.ops.gpu.append_attention(
                qkv,
                key_cache,
                value_cache,
                kwargs.get("seq_lens_encoder", None),
                kwargs.get("seq_lens_decoder", None),
                kwargs.get("seq_lens_this_time", None),
                kwargs.get("padding_offsets", None),
                kwargs.get("cum_offsets", None),
                kwargs.get("block_tables", None),
                kwargs.get("encoder_batch_ids", None),
                kwargs.get("encoder_tile_ids_per_batch", None),
                kwargs.get("encoder_num_blocks", None),
                kwargs.get("kv_batch_ids", None),
                kwargs.get("kv_tile_ids_per_batch", None),
                kwargs.get("kv_num_blocks", None),
                kwargs.get("decoder_batch_ids", None),
                kwargs.get("decoder_tile_ids_per_batch", None),
                kwargs.get("decoder_num_blocks", None),
                kwargs.get("set_max_lengths", None),
                kwargs.get("max_len_kv", None),
                rotary_embs,
                attn_mask,
                getattr(self, "qkv_bias", None),
                getattr(self, "qkv_scale", None),
                k_quant_scale,
                v_quant_scale,
                k_dequant_scale,
                v_dequant_scale,
                getattr(self, "cache_k_zp", None),  # cache_k_zp
                getattr(self, "cache_v_zp", None),  # cache_v_zp
                getattr(self, "linear_shift", None),  # out_shifts
                getattr(self, "linear_smooth", None),  # out_smooths
                kv_signal_data,
                self._fuse_kernel_compute_dtype,
                cache_quant_type_str,  # cache_quant_type
                self.use_neox_rotary_style,
                self.rope_3d,
                kwargs.get("max_input_length", -1),
                self.inference_args.quant_max_bound,
                self.inference_args.quant_min_bound,
                self.out_scale,  # out_linear_in_scale
                kwargs.get("encoder_block_shape_q", 64),
                kwargs.get("decoder_block_shape_q", 16),
                kwargs.get("max_partition_size", 32768),
                kwargs.get("encoder_max_partition_size", 32768),
                self.inference_args.speculate_max_draft_token_num + 1,  # speculate_max_draft_token_num
                True,  # causal
                self.inference_args.speculate_method is not None,  # speculate_decoder
            )[0]
        else:
            out = paddle.incubate.nn.functional.block_multihead_attention(
                qkv,
                key_cache,
                value_cache,
                kwargs.get("seq_lens_encoder", None),
                kwargs.get("seq_lens_decoder", None),
                kwargs.get("seq_lens_this_time", None),
                kwargs.get("padding_offsets", None),
                kwargs.get("cum_offsets", None),
                kwargs.get("cu_seqlens_q", None),
                kwargs.get("cu_seqlens_k", None),
                kwargs.get("block_tables", None),
                pre_key_cache,
                pre_value_cache,
                k_quant_scale,
                v_quant_scale,
                k_dequant_scale,
                v_dequant_scale,
                getattr(self, "qkv_scale", None),
                getattr(self, "qkv_bias", None),
                getattr(self, "linear_shift", None),
                getattr(self, "linear_smooth", None),
                kwargs.get("max_enc_len_this_time", None),
                kwargs.get("max_dec_len_this_time", None),
                rotary_embs,
                attn_mask,
                None,  # tgt_mask
                kwargs.get("max_input_length", -1),
                kwargs.get("block_size", 64),
                self.use_neox_rotary_style,
                self.inference_args.use_dynamic_cachekv_quant,
                quant_round_type=self.inference_args.quant_round_type,
                quant_max_bound=self.inference_args.quant_max_bound,
                quant_min_bound=self.inference_args.quant_min_bound,
                out_scale=self.out_scale,
                compute_dtype=self._fuse_kernel_compute_dtype,
                rope_theta=self.rope_theta,
            )[0]

        return out
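

# A minimal usage sketch (hypothetical, not part of this module): the
# `inference_args` object and the runtime tensors below are assumptions based
# on the attributes and kwargs accessed above, not a verified configuration.
#
#     attn = Attention(inference_args, prefix="decoder.layers.0.self_attn")
#     out = attn(
#         qkv,
#         input_ids,
#         rotary_embs,
#         rotary_emb_dims,
#         key_cache,
#         value_cache,
#         pre_key_cache=None,
#         pre_value_cache=None,
#         pre_caches_length=0,
#         attn_mask=None,
#         kv_signal_data=None,
#         seq_lens_encoder=seq_lens_encoder,
#         seq_lens_decoder=seq_lens_decoder,
#         seq_lens_this_time=seq_lens_this_time,
#         block_tables=block_tables,
#     )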