[BugFix] Fix low prediction accuracy of deepseekv3 (#2798)

This commit is contained in:
K11OntheBoat
2025-07-10 16:16:44 +08:00
committed by GitHub
parent 1e2319cbef
commit 24f934f1f9
2 changed files with 34 additions and 39 deletions

View File

@@ -41,7 +41,8 @@ from fastdeploy.config import FDConfig
from fastdeploy.model_executor.layers.attention.attention import Attention
from fastdeploy.model_executor.layers.attention.base_attention_backend import (
AttentionBackend, AttentionMetadata)
from fastdeploy.model_executor.layers.attention.utils import init_rank_and_device_id
from fastdeploy.model_executor.layers.attention.utils import \
init_rank_and_device_id
from fastdeploy.worker.forward_meta import ForwardMeta
@@ -185,6 +186,8 @@ class MLAAttentionBackend(AttentionBackend):
# MLA
metadata.max_enc_len_this_time = metadata.set_max_lengths[1]
metadata.max_dec_len_this_time = metadata.set_max_lengths[2]
forward_meta.max_enc_len_this_time = metadata.set_max_lengths[1]
forward_meta.max_dec_len_this_time = metadata.set_max_lengths[2]
# pd_disaggregation
metadata.kv_signal_data_list = [None] * self.num_layers
@@ -375,9 +378,6 @@ class MLAAttentionBackend(AttentionBackend):
speculate_decoder = self.speculative_method is not None
speculate_max_tokens = self.speculate_max_draft_token_num
decode_stage = forward_meta.is_decode_batch
prefill_stage = not (forward_meta.is_decode_batch)
if self.use_pd_disaggregation:
metadata.kv_signal_data_list[
layer.layer_id] = init_signal_layerwise(
@@ -387,8 +387,7 @@ class MLAAttentionBackend(AttentionBackend):
latent_cache = forward_meta.caches[layer.layer_id] if hasattr(
forward_meta, 'caches') else None
if prefill_stage:
# 写入缓存
if k is not None:
prefill_mla_write_cache(
compressed_kv,
k_pe,
@@ -419,8 +418,7 @@ class MLAAttentionBackend(AttentionBackend):
return fmha_out
# Decode
if decode_stage:
# mla写入缓存
if k is None:
decode_mla_write_cache(
compressed_kv,
k_pe,

View File

@@ -317,10 +317,7 @@ class DeepseekV3MLAAttention(nn.Layer):
],
dtype=layernorm_out.dtype)
decode_stage = forward_meta.is_decode_batch
prefill_stage = not (forward_meta.is_decode_batch)
if prefill_stage:
if forward_meta.max_enc_len_this_time:
query = self.q_a_proj(layernorm_out)
query = self.q_a_layernorm(query)
query = self.q_b_proj(query)
@@ -370,8 +367,7 @@ class DeepseekV3MLAAttention(nn.Layer):
fmha_out_prefill.dtype)
fmha_out = fmha_out + fmha_out_prefill
if decode_stage:
if forward_meta.max_dec_len_this_time:
query = self.q_a_proj(layernorm_out)
query = self.q_a_layernorm(query)
ln_out_or_q_c = query
@@ -554,28 +550,6 @@ class DeepSeekV3Model(nn.Layer):
prefix="deepseek_v3.norm",
)
def pre_process(self, forward_meta):
"""
"""
seq_lens_encoder = forward_meta.seq_lens_encoder
seq_lens_decoder = forward_meta.seq_lens_decoder
seq_lens_this_time = forward_meta.seq_lens_this_time
position_ids_shape = paddle.sum(seq_lens_this_time)
position_ids = paddle.empty(shape=position_ids_shape,
dtype=seq_lens_encoder.dtype)
mask_encoder_batch = paddle.empty(
shape=position_ids_shape,
dtype=seq_lens_encoder.dtype).unsqueeze(1)
get_position_ids_and_mask_encoder_batch(seq_lens_encoder,
seq_lens_decoder,
seq_lens_this_time,
position_ids,
mask_encoder_batch)
return position_ids, mask_encoder_batch
def load_state_dict(self, state_dict):
"""
Load model parameters from a given state dictionary.
@@ -590,13 +564,13 @@ class DeepSeekV3Model(nn.Layer):
self,
ids_remove_padding: paddle.Tensor,
forward_meta: ForwardMeta,
position_ids: paddle.Tensor,
mask_encoder_batch: paddle.Tensor,
):
"""
"""
hidden_states = self.embeddings(ids_remove_padding=ids_remove_padding)
position_ids, mask_encoder_batch = self.pre_process(forward_meta)
residual = None
for i in range(self.num_layers):
hidden_states, residual = self.decoder_layers[i](
@@ -650,6 +624,27 @@ class DeepseekV3ForCausalLM(ModelForCasualLM):
logits[:, self.ori_vocab_size:] = -float("inf")
return logits
def pre_process(self, forward_meta):
"""
"""
seq_lens_encoder = forward_meta.seq_lens_encoder
seq_lens_decoder = forward_meta.seq_lens_decoder
seq_lens_this_time = forward_meta.seq_lens_this_time
position_ids_shape = paddle.sum(seq_lens_this_time)
position_ids = paddle.empty(shape=position_ids_shape,
dtype=seq_lens_encoder.dtype)
mask_encoder_batch = paddle.empty(
shape=position_ids_shape,
dtype=seq_lens_encoder.dtype).unsqueeze(1)
get_position_ids_and_mask_encoder_batch(seq_lens_encoder,
seq_lens_decoder,
seq_lens_this_time,
position_ids,
mask_encoder_batch)
return position_ids, mask_encoder_batch
def forward(
self,
ids_remove_padding: paddle.Tensor,
@@ -657,7 +652,9 @@ class DeepseekV3ForCausalLM(ModelForCasualLM):
):
"""
"""
hidden_states = self.model(ids_remove_padding, forward_meta)
position_ids, mask_encoder_batch = self.pre_process(forward_meta)
hidden_states = self.model(ids_remove_padding=ids_remove_padding, forward_meta=forward_meta,
position_ids=position_ids, mask_encoder_batch=mask_encoder_batch)
return hidden_states