[BugFix] Fix low prediction accuracy of deepseekv3 (#2798)

This commit replaces the batch-level prefill/decode stage flags with per-call checks (k is None in the MLA backend; max_enc_len_this_time / max_dec_len_this_time in the model layer) and moves position-id pre-processing out of DeepSeekV3Model into DeepseekV3ForCausalLM, whose forward now passes position_ids and mask_encoder_batch into the model explicitly.
@@ -41,7 +41,8 @@ from fastdeploy.config import FDConfig
 from fastdeploy.model_executor.layers.attention.attention import Attention
 from fastdeploy.model_executor.layers.attention.base_attention_backend import (
     AttentionBackend, AttentionMetadata)
-from fastdeploy.model_executor.layers.attention.utils import init_rank_and_device_id
+from fastdeploy.model_executor.layers.attention.utils import \
+    init_rank_and_device_id
 from fastdeploy.worker.forward_meta import ForwardMeta


@@ -185,6 +186,8 @@ class MLAAttentionBackend(AttentionBackend):
         # MLA
         metadata.max_enc_len_this_time = metadata.set_max_lengths[1]
         metadata.max_dec_len_this_time = metadata.set_max_lengths[2]
+        forward_meta.max_enc_len_this_time = metadata.set_max_lengths[1]
+        forward_meta.max_dec_len_this_time = metadata.set_max_lengths[2]

         # pd_disaggregation
         metadata.kv_signal_data_list = [None] * self.num_layers
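Note: the maximum encoder/decoder lengths are now mirrored from the backend metadata onto forward_meta, so model-level code (the DeepseekV3MLAAttention hunks further down) can branch on them without reaching into backend internals. A minimal sketch of the consumer pattern this enables, assuming only the two attributes set above; the surrounding function and helpers are illustrative, not from this repo:

    # Sketch: gating prefill/decode work on the mirrored lengths.
    # run_prefill_path / run_decode_path are hypothetical helpers.
    def attention_step(layer, hidden, forward_meta):
        out = None
        if forward_meta.max_enc_len_this_time:    # > 0: batch has prefill tokens
            out = layer.run_prefill_path(hidden, forward_meta)
        if forward_meta.max_dec_len_this_time:    # > 0: batch has decode tokens
            dec = layer.run_decode_path(hidden, forward_meta)
            out = dec if out is None else out + dec
        return out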
@@ -375,9 +378,6 @@ class MLAAttentionBackend(AttentionBackend):
         speculate_decoder = self.speculative_method is not None
         speculate_max_tokens = self.speculate_max_draft_token_num

-        decode_stage = forward_meta.is_decode_batch
-        prefill_stage = not (forward_meta.is_decode_batch)
-
         if self.use_pd_disaggregation:
             metadata.kv_signal_data_list[
                 layer.layer_id] = init_signal_layerwise(
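The stage flags derived once per batch from forward_meta.is_decode_batch are removed; the following hunks decide per call from the actual inputs instead. A hedged illustration (all values made up) of why one boolean per batch is fragile once prefill and decode requests are scheduled together:

    # Illustration only: a single flag cannot describe a mixed batch.
    seq_lens_encoder = [128, 0, 0]    # request 0 is still prefilling
    seq_lens_decoder = [0, 512, 77]   # requests 1 and 2 are decoding
    is_decode_batch = all(n == 0 for n in seq_lens_encoder)  # False
    # Routing every request on this one value would send requests 1 and 2
    # down the prefill path too; testing the per-call tensors avoids that.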
@@ -387,8 +387,7 @@ class MLAAttentionBackend(AttentionBackend):
         latent_cache = forward_meta.caches[layer.layer_id] if hasattr(
             forward_meta, 'caches') else None

-        if prefill_stage:
-            # Write to cache
+        if k is not None:
             prefill_mla_write_cache(
                 compressed_kv,
                 k_pe,
@@ -419,8 +418,7 @@ class MLAAttentionBackend(AttentionBackend):
             return fmha_out

         # Decode
-        if decode_stage:
-            # MLA write to cache
+        if k is None:
             decode_mla_write_cache(
                 compressed_kv,
                 k_pe,
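With the flags gone, the backend infers the stage from its arguments: in prefill the caller supplies a full key tensor, while in decode only the compressed latent pair (compressed_kv, k_pe) exists and k is None. A condensed sketch of the dispatch implied by the two hunks above; signatures are simplified (the real custom ops also take caches, block tables, and strides):

    # Simplified dispatch: choose the cache-write op by whether full keys exist.
    def write_mla_cache(k, compressed_kv, k_pe, latent_cache):
        if k is not None:
            # Prefill: full keys are present for the new prompt tokens.
            prefill_mla_write_cache(compressed_kv, k_pe, latent_cache)
        else:
            # Decode: only the latent KV of the newly generated token exists.
            decode_mla_write_cache(compressed_kv, k_pe, latent_cache)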
@@ -317,10 +317,7 @@ class DeepseekV3MLAAttention(nn.Layer):
             ],
             dtype=layernorm_out.dtype)

-        decode_stage = forward_meta.is_decode_batch
-        prefill_stage = not (forward_meta.is_decode_batch)
-
-        if prefill_stage:
+        if forward_meta.max_enc_len_this_time:
             query = self.q_a_proj(layernorm_out)
             query = self.q_a_layernorm(query)
             query = self.q_b_proj(query)
@@ -370,8 +367,7 @@ class DeepseekV3MLAAttention(nn.Layer):
                 fmha_out_prefill.dtype)

             fmha_out = fmha_out + fmha_out_prefill
-
-        if decode_stage:
+        if forward_meta.max_dec_len_this_time:
             query = self.q_a_proj(layernorm_out)
             query = self.q_a_layernorm(query)
             ln_out_or_q_c = query
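At the model level the same decision uses the mirrored lengths, and because the two conditions are independent ifs rather than an if/else, a mixed batch executes both branches. Each branch is expected to contribute only its own token rows (mask_encoder_batch exists for exactly this kind of separation), so fmha_out = fmha_out + fmha_out_prefill merges disjoint outputs. An abbreviated sketch of that control flow, with projections and attention calls elided:

    # Abbreviated control flow of the attention forward after this change.
    fmha_out = None
    if forward_meta.max_enc_len_this_time:    # prefill tokens present
        fmha_out_prefill = ...                # attention over full prompt keys
        fmha_out = fmha_out_prefill
    if forward_meta.max_dec_len_this_time:    # decode tokens present
        fmha_out_decode = ...                 # MLA attention over the latent cache
        fmha_out = fmha_out_decode if fmha_out is None \
            else fmha_out + fmha_out_decode   # disjoint rows, so addition merges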
@@ -554,28 +550,6 @@ class DeepSeekV3Model(nn.Layer):
             prefix="deepseek_v3.norm",
         )

-    def pre_process(self, forward_meta):
-        """
-        """
-        seq_lens_encoder = forward_meta.seq_lens_encoder
-        seq_lens_decoder = forward_meta.seq_lens_decoder
-        seq_lens_this_time = forward_meta.seq_lens_this_time
-        position_ids_shape = paddle.sum(seq_lens_this_time)
-
-        position_ids = paddle.empty(shape=position_ids_shape,
-                                    dtype=seq_lens_encoder.dtype)
-        mask_encoder_batch = paddle.empty(
-            shape=position_ids_shape,
-            dtype=seq_lens_encoder.dtype).unsqueeze(1)
-
-        get_position_ids_and_mask_encoder_batch(seq_lens_encoder,
-                                                seq_lens_decoder,
-                                                seq_lens_this_time,
-                                                position_ids,
-                                                mask_encoder_batch)
-
-        return position_ids, mask_encoder_batch
-
     def load_state_dict(self, state_dict):
         """
         Load model parameters from a given state dictionary.
@@ -590,13 +564,13 @@ class DeepSeekV3Model(nn.Layer):
         self,
         ids_remove_padding: paddle.Tensor,
         forward_meta: ForwardMeta,
+        position_ids: paddle.Tensor,
+        mask_encoder_batch: paddle.Tensor,
     ):
         """
         """
         hidden_states = self.embeddings(ids_remove_padding=ids_remove_padding)

-        position_ids, mask_encoder_batch = self.pre_process(forward_meta)
-
         residual = None
         for i in range(self.num_layers):
             hidden_states, residual = self.decoder_layers[i](
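position_ids and mask_encoder_batch become explicit forward arguments rather than being recomputed inside the model. For readers unfamiliar with the helper, here is a rough pure-Python reading of what get_position_ids_and_mask_encoder_batch produces; this is an assumed illustration of the semantics, not the custom op's actual implementation:

    # Assumed semantics (illustration only): one position id per packed token,
    # plus a 0/1 mask marking tokens that belong to prefill (encoder) requests.
    def reference_position_ids_and_mask(seq_lens_encoder, seq_lens_decoder,
                                        seq_lens_this_time):
        position_ids, mask = [], []
        for enc, dec, now in zip(seq_lens_encoder, seq_lens_decoder,
                                 seq_lens_this_time):
            if enc > 0:                       # prefill: positions 0..now-1
                position_ids.extend(range(now))
                mask.extend([1] * now)
            elif now > 0:                     # decode: continue past the cache
                position_ids.extend(range(dec, dec + now))
                mask.extend([0] * now)
        return position_ids, mask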
@@ -650,6 +624,27 @@ class DeepseekV3ForCausalLM(ModelForCasualLM):
         logits[:, self.ori_vocab_size:] = -float("inf")
         return logits

+    def pre_process(self, forward_meta):
+        """
+        """
+        seq_lens_encoder = forward_meta.seq_lens_encoder
+        seq_lens_decoder = forward_meta.seq_lens_decoder
+        seq_lens_this_time = forward_meta.seq_lens_this_time
+        position_ids_shape = paddle.sum(seq_lens_this_time)
+        position_ids = paddle.empty(shape=position_ids_shape,
+                                    dtype=seq_lens_encoder.dtype)
+        mask_encoder_batch = paddle.empty(
+            shape=position_ids_shape,
+            dtype=seq_lens_encoder.dtype).unsqueeze(1)
+
+        get_position_ids_and_mask_encoder_batch(seq_lens_encoder,
+                                                seq_lens_decoder,
+                                                seq_lens_this_time,
+                                                position_ids,
+                                                mask_encoder_batch)
+
+        return position_ids, mask_encoder_batch
+
     def forward(
         self,
         ids_remove_padding: paddle.Tensor,
@@ -657,7 +652,9 @@ class DeepseekV3ForCausalLM(ModelForCasualLM):
     ):
         """
         """
-        hidden_states = self.model(ids_remove_padding, forward_meta)
+        position_ids, mask_encoder_batch = self.pre_process(forward_meta)
+        hidden_states = self.model(ids_remove_padding=ids_remove_padding, forward_meta=forward_meta,
+                                   position_ids=position_ids, mask_encoder_batch=mask_encoder_batch)
         return hidden_states

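End to end, the wrapper now computes the position metadata once per step and threads it through by keyword. A usage sketch of the resulting call flow, assuming a constructed DeepseekV3ForCausalLM instance named model (setup elided):

    # Usage sketch of the new per-step call flow.
    position_ids, mask_encoder_batch = model.pre_process(forward_meta)
    hidden_states = model.model(ids_remove_padding=ids_remove_padding,
                                forward_meta=forward_meta,
                                position_ids=position_ids,
                                mask_encoder_batch=mask_encoder_batch)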