"""
|
|
# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
"""
|
|
|
|
import os
|
|
|
|
import paddle
|
|
from paddle import nn
|
|
|
|
import fastdeploy
|
|
|
|
|
|
class Attention(nn.Layer):
    """
    Attention Layer
    """

    def __init__(
        self,
        inference_args,
        prefix,
        out_scale=-1,
        use_neox_rotary_style=False,
        rope_theta=10000.0,
        rope_3d=False,
        qkv_scale=None,
        qkv_bias=None,
        linear_shift=None,
        linear_smooth=None,
    ):
"""
|
|
Initialize the attention layer with various parameters.
|
|
|
|
Args:
|
|
inference_args (dict or object): Contains arguments for inference, including
|
|
number of key-value heads, weight data type, activation data type, etc.
|
|
prefix (str): The name of the attention layer for identification purposes.
|
|
out_scale (float, optional): Output scale factor. Defaults to -1.
|
|
use_neox_rotary_style (bool, optional): Whether to use the NeoX rotary position
|
|
encoding style. Defaults to False.
|
|
rope_theta (float, optional): Theta value for the rope position encoding. Defaults to 10000.0.
|
|
qkv_scale (float or None, optional): Quantization scale for QKV weights.
|
|
Used only for certain quantization configurations. Defaults to None.
|
|
qkv_bias (Tensor or None, optional): Bias for QKV linear layer. Defaults to None.
|
|
linear_shift (float or None, optional): Linear shift factor used in
|
|
quantization. Used only for certain quantization configurations.
|
|
Defaults to None.
|
|
linear_smooth (float or None, optional): Linear smooth factor used in
|
|
quantization. Used only for certain quantization configurations.
|
|
Defaults to None.
|
|
"""
|
|
        super().__init__()
        self.inference_args = inference_args
        self.nranks = inference_args.mp_size
        self.kv_num_heads = inference_args.num_key_value_heads // self.nranks
        self.head_dim = self.inference_args.head_dim
        self.prefix = prefix
        self.cache_k_scale_name = prefix + ".cachek_matmul.activation_quanter"
        self.cache_v_scale_name = prefix + ".cachev_matmul.activation_quanter"
        self.out_scale = out_scale

        self.cache_k_zp_name = self.cache_k_scale_name + ".zero_point"
        self.cache_v_zp_name = self.cache_v_scale_name + ".zero_point"

        self.use_neox_rotary_style = use_neox_rotary_style
        self.rope_theta = rope_theta
        self.rope_3d = rope_3d

        self._dtype = self._helper.get_default_dtype()
        if self._dtype == "bfloat16":
            self._fuse_kernel_compute_dtype = "bf16"
        elif self._dtype == "float16":
            self._fuse_kernel_compute_dtype = "fp16"
        elif self._dtype == "float32":
            self._fuse_kernel_compute_dtype = "fp32"
        else:
            raise ValueError(
                f"Only float32, float16 and bfloat16 are supported as the "
                f"default dtype, but received {self._dtype}.")

        self.cache_scale_dtype = (
            self._dtype if self.inference_args.use_append_attn else "float32")

        self.qkv_bias = qkv_bias
        if inference_args.weight_dtype == "int8" and inference_args.act_dtype == "int8":
            self.qkv_scale = qkv_scale
            self.linear_shift = linear_shift
            self.linear_smooth = linear_smooth
        if inference_args.cachekv_dtype in ("int8", "int4", "float8_e4m3fn"):
            self.set_cachekv_scale()
        # qkv_bias is fused into the attention kernel only when running W8A8.
        if not (inference_args.weight_dtype == "int8"
                and inference_args.act_dtype == "int8"):
            self.qkv_bias = None

    def set_cachekv_scale(self):
        """
        Set the cache key (K) and value (V) scaling factors.

        Creates the per-layer scale parameters used to quantize the KV cache,
        loads their values from `self.inference_args.cachekv_scale_dict`, and
        sets the inverse (dequantization) scales for reading the cached K and V
        tensors back. When the model uses zero points, the corresponding
        zero-point parameters are created and loaded as well.

        This method takes no arguments and returns nothing; it only creates
        and fills instance parameters such as `self.cache_k_scale` and
        `self.cache_v_scale`.
        """
        # Channel-wise quantization keeps one scale per (head, channel) pair;
        # otherwise there is a single scale per KV head.
        scale_shape = ([self.kv_num_heads * self.head_dim]
                       if self.inference_args.is_channel_wise else
                       [self.kv_num_heads])

        self.cache_k_scale = self.create_parameter(
            shape=scale_shape,
            dtype=self.cache_scale_dtype,
            is_bias=False,
        )
        self.cache_v_scale = self.create_parameter(
            shape=scale_shape,
            dtype=self.cache_scale_dtype,
            is_bias=False,
        )
        self.cache_k_out_scale = self.create_parameter(
            shape=scale_shape,
            attr=None,
            dtype=self.cache_scale_dtype,
            is_bias=False,
        )
        self.cache_v_out_scale = self.create_parameter(
            shape=scale_shape,
            attr=None,
            dtype=self.cache_scale_dtype,
            is_bias=False,
        )

        if self.cache_k_scale_name in self.inference_args.cachekv_scale_dict:
            cache_k_scale = paddle.cast(
                paddle.to_tensor(self.inference_args.cachekv_scale_dict[
                    self.cache_k_scale_name]),
                self.cache_scale_dtype,
            )
            cache_k_out_scale = 1.0 / cache_k_scale
        elif os.getenv("EP_DECODER_PERF_TEST", "False") == "True":
            # Perf-test mode: use zero placeholders instead of real scales.
            cache_k_scale = paddle.zeros(self.cache_k_scale.shape,
                                         self.cache_k_scale.dtype)
            cache_k_out_scale = paddle.zeros(self.cache_k_out_scale.shape,
                                             self.cache_k_out_scale.dtype)
        else:
            raise KeyError(
                f"{self.cache_k_scale_name} not found in scale dict")

        if self.cache_v_scale_name in self.inference_args.cachekv_scale_dict:
            cache_v_scale = paddle.cast(
                paddle.to_tensor(self.inference_args.cachekv_scale_dict[
                    self.cache_v_scale_name]),
                self.cache_scale_dtype,
            )
            cache_v_out_scale = 1.0 / cache_v_scale
        elif os.getenv("EP_DECODER_PERF_TEST", "False") == "True":
            # Perf-test mode: use zero placeholders instead of real scales.
            cache_v_scale = paddle.zeros(self.cache_v_scale.shape,
                                         self.cache_v_scale.dtype)
            cache_v_out_scale = paddle.zeros(self.cache_v_out_scale.shape,
                                             self.cache_v_out_scale.dtype)
        else:
            raise KeyError(
                f"{self.cache_v_scale_name} not found in scale dict")

        self.cache_k_scale.set_value(cache_k_scale)
        self.cache_v_scale.set_value(cache_v_scale)
        self.cache_k_out_scale.set_value(cache_k_out_scale)
        self.cache_v_out_scale.set_value(cache_v_out_scale)

        if self.inference_args.has_zero_point:
            self.cache_k_zp = self.create_parameter(
                shape=scale_shape,
                dtype=self.cache_scale_dtype,
                is_bias=False,
            )
            self.cache_v_zp = self.create_parameter(
                shape=scale_shape,
                dtype=self.cache_scale_dtype,
                is_bias=False,
            )
            if self.cache_k_zp_name in self.inference_args.cachekv_scale_dict:
                cache_k_zp = paddle.cast(
                    paddle.to_tensor(self.inference_args.cachekv_scale_dict[
                        self.cache_k_zp_name]),
                    self.cache_scale_dtype,
                )
            else:
                cache_k_zp = paddle.zeros(scale_shape,
                                          dtype=self.cache_scale_dtype)
            if self.cache_v_zp_name in self.inference_args.cachekv_scale_dict:
                cache_v_zp = paddle.cast(
                    paddle.to_tensor(self.inference_args.cachekv_scale_dict[
                        self.cache_v_zp_name]),
                    self.cache_scale_dtype,
                )
            else:
                cache_v_zp = paddle.zeros(scale_shape,
                                          dtype=self.cache_scale_dtype)
            self.cache_k_zp.set_value(cache_k_zp)
            self.cache_v_zp.set_value(cache_v_zp)

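    # A rough illustration of the static cache-KV int8 scheme assumed above
    # (a sketch of a common symmetric scheme, not the kernel's exact
    # rounding/clipping behavior): writes multiply by the quantization scale
    # and reads multiply by its inverse, which is why cache_*_out_scale is
    # set to 1.0 / cache_*_scale.
    #
    #     scale = 127.0 / max_abs_value       # e.g. max_abs_value = 4.0 -> 31.75
    #     stored = round(3.0 * 31.75)         # -> 95 (fits in int8)
    #     restored = 95 * (1.0 / 31.75)       # -> ~2.99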
    def forward(
        self,
        qkv,
        input_ids,
        rotary_embs,
        rotary_emb_dims,
        key_cache,
        value_cache,
        pre_key_cache,
        pre_value_cache,
        pre_caches_length,
        attn_mask,
        kv_signal_data,
        **kwargs,
    ):
"""
|
|
Compute the attention for a single time step.
|
|
|
|
Args:
|
|
qkv (Tensor): The output of the linear transformation of query, key and value.
|
|
Shape: [batch_size, num_heads, seq_len, embed_dim // num_heads].
|
|
padding_offset (Tensor): The offset to be added to the sequence length when computing
|
|
the attention mask. Shape: [batch_size, 1].
|
|
input_ids (Tensor, optional): The input ids of the batch. Used for computing the
|
|
attention mask. Default: None. Shape: [batch_size, max_sequence_length].
|
|
rotary_embs (Tensor, optional): The rotary position embeddings. Default: None.
|
|
Shape: [num_heads, rotary_emb_dims].
|
|
rotary_emb_dims (int, optional): The dimension of the rotary position embeddings.
|
|
Default: None.
|
|
caches (List[Tensor], optional): The cache tensors used in the computation of the
|
|
attention. Default: None.
|
|
pre_caches (List[Tensor], optional): The pre-computed cache tensors used in the
|
|
computation of the attention. Default: None.
|
|
pre_caches_length (int, optional): The length of the pre-computed cache tensors.
|
|
Default: None.
|
|
attn_mask (Tensor, optional): The attention mask. Default: None.
|
|
Shape: [batch_size, max_sequence_length].
|
|
**kwargs (dict, optional): Additional keyword arguments passed along.
|
|
|
|
Returns:
|
|
Tensor: The output of the linear transformation after applying the attention.
|
|
Shape: [batch_size, embed_dim // num_heads].
|
|
|
|
Raises:
|
|
None.
|
|
"""
|
|
        # Dynamic cache-KV quantization passes per-batch scales via kwargs;
        # static quantization uses the per-layer scales created in
        # set_cachekv_scale().
        k_quant_scale = kwargs.get("k_quant_scale", None)
        v_quant_scale = kwargs.get("v_quant_scale", None)
        k_dequant_scale = kwargs.get("k_dequant_scale", None)
        v_dequant_scale = kwargs.get("v_dequant_scale", None)

        if not self.inference_args.use_dynamic_cachekv_quant:
            k_quant_scale = getattr(self, "cache_k_scale", None)
            v_quant_scale = getattr(self, "cache_v_scale", None)
            k_dequant_scale = getattr(self, "cache_k_out_scale", None)
            v_dequant_scale = getattr(self, "cache_v_out_scale", None)
            cache_quant_type_str = self.inference_args.cache_quant_type
        else:
            cache_quant_type_str = "none"

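        # Two kernel paths below: FastDeploy's fused append_attention custom
        # op when use_append_attn is set, otherwise Paddle's built-in
        # block_multihead_attention.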
        if self.inference_args.use_append_attn:
            out = fastdeploy.model_executor.ops.gpu.append_attention(
                qkv,
                key_cache,
                value_cache,
                kwargs.get("seq_lens_encoder", None),
                kwargs.get("seq_lens_decoder", None),
                kwargs.get("seq_lens_this_time", None),
                kwargs.get("padding_offsets", None),
                kwargs.get("cum_offsets", None),
                kwargs.get("block_tables", None),
                kwargs.get("encoder_batch_ids", None),
                kwargs.get("encoder_tile_ids_per_batch", None),
                kwargs.get("encoder_num_blocks", None),
                kwargs.get("kv_batch_ids", None),
                kwargs.get("kv_tile_ids_per_batch", None),
                kwargs.get("kv_num_blocks", None),
                kwargs.get("decoder_batch_ids", None),
                kwargs.get("decoder_tile_ids_per_batch", None),
                kwargs.get("decoder_num_blocks", None),
                kwargs.get("set_max_lengths", None),
                kwargs.get("max_len_kv", None),
                rotary_embs,
                attn_mask,
                getattr(self, "qkv_bias", None),
                getattr(self, "qkv_scale", None),
                k_quant_scale,
                v_quant_scale,
                k_dequant_scale,
                v_dequant_scale,
                getattr(self, "cache_k_zp", None),  # cache_k_zp
                getattr(self, "cache_v_zp", None),  # cache_v_zp
                getattr(self, "linear_shift", None),  # out_shifts
                getattr(self, "linear_smooth", None),  # out_smooths
                kv_signal_data,
                self._fuse_kernel_compute_dtype,
                cache_quant_type_str,  # cache_quant_type
                self.use_neox_rotary_style,
                self.rope_3d,
                kwargs.get("max_input_length", -1),
                self.inference_args.quant_max_bound,
                self.inference_args.quant_min_bound,
                self.out_scale,  # out_linear_in_scale
                kwargs.get("encoder_block_shape_q", 64),
                kwargs.get("decoder_block_shape_q", 16),
                kwargs.get("max_partition_size", 32768),
                kwargs.get("encoder_max_partition_size", 32768),
                self.inference_args.speculate_max_draft_token_num + 1,  # speculate_max_draft_token_num
                True,  # causal
                self.inference_args.speculate_method is not None,  # speculate_decoder
            )[0]
        else:
            out = paddle.incubate.nn.functional.block_multihead_attention(
                qkv,
                key_cache,
                value_cache,
                kwargs.get("seq_lens_encoder", None),
                kwargs.get("seq_lens_decoder", None),
                kwargs.get("seq_lens_this_time", None),
                kwargs.get("padding_offsets", None),
                kwargs.get("cum_offsets", None),
                kwargs.get("cu_seqlens_q", None),
                kwargs.get("cu_seqlens_k", None),
                kwargs.get("block_tables", None),
                pre_key_cache,
                pre_value_cache,
                k_quant_scale,
                v_quant_scale,
                k_dequant_scale,
                v_dequant_scale,
                getattr(self, "qkv_scale", None),
                getattr(self, "qkv_bias", None),
                getattr(self, "linear_shift", None),
                getattr(self, "linear_smooth", None),
                kwargs.get("max_enc_len_this_time", None),
                kwargs.get("max_dec_len_this_time", None),
                rotary_embs,
                attn_mask,
                None,  # tgt_mask
                kwargs.get("max_input_length", -1),
                kwargs.get("block_size", 64),
                self.use_neox_rotary_style,
                self.inference_args.use_dynamic_cachekv_quant,
                quant_round_type=self.inference_args.quant_round_type,
                quant_max_bound=self.inference_args.quant_max_bound,
                quant_min_bound=self.inference_args.quant_min_bound,
                out_scale=self.out_scale,
                compute_dtype=self._fuse_kernel_compute_dtype,
                rope_theta=self.rope_theta,
            )[0]

        return out
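
# ---------------------------------------------------------------------------
# Usage sketch (illustrative only, not part of the library). `SimpleNamespace`
# stands in for FastDeploy's real inference-argument object; only the fields
# that Attention.__init__ reads for an unquantized model are stubbed, and the
# field values are assumptions for illustration.
#
#     from types import SimpleNamespace
#
#     args = SimpleNamespace(
#         mp_size=1,                 # tensor-parallel degree
#         num_key_value_heads=8,
#         head_dim=128,
#         use_append_attn=True,
#         weight_dtype="bfloat16",   # not W8A8, so qkv_bias is not fused
#         act_dtype="bfloat16",
#         cachekv_dtype="bfloat16",  # no cache-KV quantization scales needed
#     )
#     attn = Attention(args, prefix="transformer.layers.0.attn")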