Mirror of https://github.com/PaddlePaddle/FastDeploy.git
【Inference Optimize】DeepSeek-v3 model inference performance optimization (#3455)
* DSK_OPT_01
* update FA3
@@ -24,6 +24,11 @@ from typing import TYPE_CHECKING, List, Optional, Tuple
 import paddle
 from paddle.nn.functional.flash_attention import flash_attn_unpadded
 
+try:
+    from paddle.nn.functional.flash_attention import flash_attention_v3_varlen
+except:
+    flash_attention_v3_varlen = None
+
 from fastdeploy.model_executor.layers.attention.ops import (
     get_block_shape_and_split_kv_block,
     init_kv_signal_per_query,
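For readers outside the diff: the guarded import above is the standard optional-kernel pattern: probe for the FA3 varlen entry point and fall back to None so the backend can decide at runtime. A minimal standalone sketch of the same idea follows; the `have_fa3` helper is illustrative, not part of FastDeploy.

# Sketch of the optional-import guard used above (illustrative, not FastDeploy code).
try:
    # Only recent Paddle builds ship the Flash Attention V3 varlen kernel.
    from paddle.nn.functional.flash_attention import flash_attention_v3_varlen
except ImportError:
    flash_attention_v3_varlen = None


def have_fa3() -> bool:
    """Return True when the FA3 varlen kernel could be imported."""
    return flash_attention_v3_varlen is not None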
@@ -92,6 +97,7 @@ class MLAAttentionBackend(AttentionBackend):
 
     __infer_dynamic_dims_fields__ = ["attention_metadata"]
     attention_metadata: MLAAttentionMetadata
+    flash_attn_func: callable = None
 
     def __init__(
         self,
@@ -148,6 +154,22 @@ class MLAAttentionBackend(AttentionBackend):
 
         self.rank, self.device_id = init_rank_and_device_id(fd_config)
 
+        if self.flash_attn_func is None:
+            prop = paddle.device.cuda.get_device_properties()
+            cc = prop.major * 10 + prop.minor
+            is_current_sm_supported = cc >= 90
+            is_paddle_supported = any(num >= 90 for num in paddle.version.cuda_archs())
+            if is_current_sm_supported and is_paddle_supported:
+                self.flash_attn_func = flash_attention_v3_varlen
+                print("The current platform supports Flash Attention V3.")
+                self.flash_attn_kwargs = {"softmax_scale": self.attn_softmax_scale}
+            else:
+                self.flash_attn_func = flash_attn_unpadded
+                self.flash_attn_kwargs = {"scale": self.attn_softmax_scale, "training": False}
+                print(
+                    "The current platform does not support Flash Attention V3, so Flash Attention V2 will be used instead."
+                )
+
     def init_attention_metadata(self, forward_meta: ForwardMeta):
         """Initialize attention metadata hence all layers in the forward pass can reuse it."""
         metadata = MLAAttentionMetadata()
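A condensed sketch of the selection logic added above, pulled out into a standalone helper for clarity. The `pick_flash_attention` name and the (callable, kwargs) return convention are assumptions for illustration, not FastDeploy API; it assumes a CUDA build of Paddle.

import paddle
from paddle.nn.functional.flash_attention import flash_attn_unpadded

try:
    from paddle.nn.functional.flash_attention import flash_attention_v3_varlen
except ImportError:
    flash_attention_v3_varlen = None


def pick_flash_attention(softmax_scale: float):
    """Pick FA3 on SM90+ GPUs when both the device and the Paddle build support
    it, otherwise fall back to the FA2 varlen kernel (illustrative helper)."""
    prop = paddle.device.cuda.get_device_properties()
    cc = prop.major * 10 + prop.minor
    device_ok = cc >= 90
    build_ok = any(arch >= 90 for arch in paddle.version.cuda_archs())
    if flash_attention_v3_varlen is not None and device_ok and build_ok:
        # FA3 expects the scale as `softmax_scale`.
        return flash_attention_v3_varlen, {"softmax_scale": softmax_scale}
    # FA2 expects `scale` plus an explicit `training` flag.
    return flash_attn_unpadded, {"scale": softmax_scale, "training": False}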
@@ -269,7 +291,7 @@ class MLAAttentionBackend(AttentionBackend):
         )
 
         # Flash attention computation
-        fmha_out = flash_attn_unpadded(
+        fmha_out = self.flash_attn_func(
             q,
             k,
             v,
@@ -277,9 +299,8 @@ class MLAAttentionBackend(AttentionBackend):
             forward_meta.cu_seqlens_k,
             metadata.max_enc_len_this_time,
             metadata.max_enc_len_this_time,
-            self.attn_softmax_scale,
-            causal=True,
-            training=False,
+            causal=self.causal,
+            **self.flash_attn_kwargs,
         )[0]
 
         return fmha_out
@@ -418,7 +439,7 @@ class MLAAttentionBackend(AttentionBackend):
         )
 
         # FA
-        fmha_out = flash_attn_unpadded(
+        fmha_out = self.flash_attn_func(
             q,
             k,
             v,
@@ -426,9 +447,8 @@ class MLAAttentionBackend(AttentionBackend):
             forward_meta.cu_seqlens_k,
             metadata.max_enc_len_this_time,
             metadata.max_enc_len_this_time,
-            self.attn_softmax_scale,
-            causal=True,
-            training=False,
+            causal=self.causal,
+            **self.flash_attn_kwargs,
         )[0]
 
         return fmha_out
@@ -316,30 +316,23 @@ class DeepseekV3MLAAttention(nn.Layer):
         mask_encoder_batch: paddle.Tensor,
     ):
         """ """
-        layernorm_out = hidden_states
-        fmha_out = paddle.zeros(
-            shape=[
-                layernorm_out.shape[0],
-                self.num_attention_heads_tp * self.v_head_dim,
-            ],
-            dtype=layernorm_out.dtype,
-        )
-
-        if forward_meta.max_len_tensor_cpu[1]:  # max_enc_len_this_time
-            query = self.q_a_proj(layernorm_out)
+        # NOTE: (changwenbin) Bring out the public calculation in PD MIX to avoid repeated calculation.
+        fmha_out = None
+        query = self.q_a_proj(hidden_states)
         query = self.q_a_layernorm(query)
         query = self.q_b_proj(query)
 
         query = query.reshape([-1, self.num_attention_heads_tp, self.qk_head_dim])
         query_nope, query_pe = query.split([self.qk_nope_head_dim, self.qk_rope_head_dim], axis=-1)
 
-        compressed_kv = self.kv_a_proj_with_mqa(layernorm_out)
+        compressed_kv = self.kv_a_proj_with_mqa(hidden_states)
         compressed_kv, key_pe = compressed_kv.split([self.kv_lora_rank, self.qk_rope_head_dim], axis=-1)
         key_pe = key_pe.reshape([-1, 1, self.qk_rope_head_dim])
         compressed_kv = self.kv_a_layernorm(compressed_kv)
 
         query_pe, key_pe = self.rotary_emb(position_ids, query_pe, key_pe)
 
+        if forward_meta.max_len_tensor_cpu[1]:  # max_enc_len_this_time
             key_value = self.kv_b_proj(compressed_kv)
             key_value = key_value.reshape(
                 [
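The reshuffle above hoists the query/compressed-KV projections and rotary embedding out of the prefill branch so that, in a mixed prefill/decode (PD MIX) batch, the decode branch further below reuses them instead of recomputing. A schematic of the new control flow with stand-in stubs; `shared_projections`, `prefill_attention`, and `decode_attention` are illustrative names, not FastDeploy functions.

# Schematic of the refactored MLA forward (illustrative stubs, not FastDeploy code).
def shared_projections(hidden_states):
    # Stand-in for q_a_proj / q_a_layernorm / q_b_proj, kv_a_proj_with_mqa,
    # kv_a_layernorm and rotary_emb, which now run exactly once per step.
    return hidden_states, hidden_states, hidden_states


def prefill_attention(query, compressed_kv, key_pe):
    # Stand-in for the FA2/FA3 varlen prefill path.
    return query


def decode_attention(query, compressed_kv, key_pe):
    # Stand-in for the MLA absorbed-weight decode path.
    return query


def mla_forward(hidden_states, has_prefill: bool, has_decode: bool):
    query, compressed_kv, key_pe = shared_projections(hidden_states)  # shared work, done once
    fmha_out = None
    if has_prefill:  # max_enc_len_this_time > 0
        fmha_out = prefill_attention(query, compressed_kv, key_pe)
    if has_decode:  # max_dec_len_this_time > 0
        out_decode = decode_attention(query, compressed_kv, key_pe)
        fmha_out = out_decode if fmha_out is None else fmha_out + out_decode
    return fmha_out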
@@ -371,23 +364,9 @@ class DeepseekV3MLAAttention(nn.Layer):
             fmha_out_prefill = fmha_out_prefill.reshape([-1, self.num_attention_heads_tp * self.v_head_dim])
             fmha_out_prefill = fmha_out_prefill * mask_encoder_batch.cast(fmha_out_prefill.dtype)
 
-            fmha_out = fmha_out + fmha_out_prefill
+            fmha_out = fmha_out_prefill
 
         if forward_meta.max_len_tensor_cpu[2]:  # max_dec_len_this_time
-            query = self.q_a_proj(layernorm_out)
-            query = self.q_a_layernorm(query)
-            ln_out_or_q_c = query
-
-            compressed_kv = self.kv_a_proj_with_mqa(layernorm_out)
-            compressed_kv, key_pe = compressed_kv.split([self.kv_lora_rank, self.qk_rope_head_dim], axis=-1)
-            key_pe = key_pe.reshape([-1, 1, self.qk_rope_head_dim])
-            compressed_kv = self.kv_a_layernorm(compressed_kv)
-
-            query = self.q_b_proj(ln_out_or_q_c)
-            query = query.reshape([-1, self.num_attention_heads_tp, self.qk_head_dim])
-
-            query_nope, query_pe = query.split([self.qk_nope_head_dim, self.qk_rope_head_dim], axis=-1)
-            query_pe, key_pe = self.rotary_emb(position_ids, query_pe, key_pe)
-
             q_nope_out = self.kv_b_proj_bmm(query_nope.transpose([1, 0, 2]), proj_type="k").transpose([1, 0, 2])
 
             q_input = paddle.concat([q_nope_out, query_pe], axis=-1)
@@ -416,6 +395,9 @@ class DeepseekV3MLAAttention(nn.Layer):
                 .transpose([1, 0, 2])
                 .reshape([-1, self.num_attention_heads_tp * self.v_head_dim])
             )
-            fmha_out = fmha_out + fmha_out_decode
+            if fmha_out is None:
+                fmha_out = fmha_out_decode
+            else:
+                fmha_out = fmha_out + fmha_out_decode
 
         output = self.o_proj(fmha_out)
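Compared with the old code, which pre-allocated fmha_out = paddle.zeros(...) and added each branch's result into it, the new path starts from None and only materializes what a branch actually produces. A toy sketch of that accumulation pattern, with made-up tensors purely for illustration:

import paddle

def accumulate(partial_outputs):
    """Sum partial attention outputs lazily, mirroring the
    `fmha_out is None` / add pattern above (illustrative helper)."""
    total = None
    for out in partial_outputs:
        total = out if total is None else total + out
    return total

# Example with two stand-in partial outputs (e.g. prefill and decode shares).
parts = [paddle.ones([4, 8]), paddle.full([4, 8], 2.0)]
print(accumulate(parts))  # every entry equals 3.0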