mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2025-10-06 00:57:33 +08:00
[BugFix] Fix low prediction accuracy of deepseekv3 (#2798)
This commit is contained in:
@@ -41,7 +41,8 @@ from fastdeploy.config import FDConfig
|
||||
from fastdeploy.model_executor.layers.attention.attention import Attention
|
||||
from fastdeploy.model_executor.layers.attention.base_attention_backend import (
|
||||
AttentionBackend, AttentionMetadata)
|
||||
from fastdeploy.model_executor.layers.attention.utils import init_rank_and_device_id
|
||||
from fastdeploy.model_executor.layers.attention.utils import \
|
||||
init_rank_and_device_id
|
||||
from fastdeploy.worker.forward_meta import ForwardMeta
|
||||
|
||||
|
||||
@@ -185,6 +186,8 @@ class MLAAttentionBackend(AttentionBackend):
|
||||
# MLA
|
||||
metadata.max_enc_len_this_time = metadata.set_max_lengths[1]
|
||||
metadata.max_dec_len_this_time = metadata.set_max_lengths[2]
|
||||
forward_meta.max_enc_len_this_time = metadata.set_max_lengths[1]
|
||||
forward_meta.max_dec_len_this_time = metadata.set_max_lengths[2]
|
||||
|
||||
# pd_disaggregation
|
||||
metadata.kv_signal_data_list = [None] * self.num_layers
|
||||
@@ -375,9 +378,6 @@ class MLAAttentionBackend(AttentionBackend):
|
||||
speculate_decoder = self.speculative_method is not None
|
||||
speculate_max_tokens = self.speculate_max_draft_token_num
|
||||
|
||||
decode_stage = forward_meta.is_decode_batch
|
||||
prefill_stage = not (forward_meta.is_decode_batch)
|
||||
|
||||
if self.use_pd_disaggregation:
|
||||
metadata.kv_signal_data_list[
|
||||
layer.layer_id] = init_signal_layerwise(
|
||||
@@ -387,8 +387,7 @@ class MLAAttentionBackend(AttentionBackend):
|
||||
latent_cache = forward_meta.caches[layer.layer_id] if hasattr(
|
||||
forward_meta, 'caches') else None
|
||||
|
||||
if prefill_stage:
|
||||
# 写入缓存
|
||||
if k is not None:
|
||||
prefill_mla_write_cache(
|
||||
compressed_kv,
|
||||
k_pe,
|
||||
@@ -419,8 +418,7 @@ class MLAAttentionBackend(AttentionBackend):
|
||||
return fmha_out
|
||||
|
||||
# Decode
|
||||
if decode_stage:
|
||||
# mla写入缓存
|
||||
if k is None:
|
||||
decode_mla_write_cache(
|
||||
compressed_kv,
|
||||
k_pe,
|
||||
|
@@ -317,10 +317,7 @@ class DeepseekV3MLAAttention(nn.Layer):
|
||||
],
|
||||
dtype=layernorm_out.dtype)
|
||||
|
||||
decode_stage = forward_meta.is_decode_batch
|
||||
prefill_stage = not (forward_meta.is_decode_batch)
|
||||
|
||||
if prefill_stage:
|
||||
if forward_meta.max_enc_len_this_time:
|
||||
query = self.q_a_proj(layernorm_out)
|
||||
query = self.q_a_layernorm(query)
|
||||
query = self.q_b_proj(query)
|
||||
@@ -370,8 +367,7 @@ class DeepseekV3MLAAttention(nn.Layer):
|
||||
fmha_out_prefill.dtype)
|
||||
|
||||
fmha_out = fmha_out + fmha_out_prefill
|
||||
|
||||
if decode_stage:
|
||||
if forward_meta.max_dec_len_this_time:
|
||||
query = self.q_a_proj(layernorm_out)
|
||||
query = self.q_a_layernorm(query)
|
||||
ln_out_or_q_c = query
|
||||
@@ -554,28 +550,6 @@ class DeepSeekV3Model(nn.Layer):
|
||||
prefix="deepseek_v3.norm",
|
||||
)
|
||||
|
||||
def pre_process(self, forward_meta):
|
||||
"""
|
||||
"""
|
||||
seq_lens_encoder = forward_meta.seq_lens_encoder
|
||||
seq_lens_decoder = forward_meta.seq_lens_decoder
|
||||
seq_lens_this_time = forward_meta.seq_lens_this_time
|
||||
position_ids_shape = paddle.sum(seq_lens_this_time)
|
||||
|
||||
position_ids = paddle.empty(shape=position_ids_shape,
|
||||
dtype=seq_lens_encoder.dtype)
|
||||
mask_encoder_batch = paddle.empty(
|
||||
shape=position_ids_shape,
|
||||
dtype=seq_lens_encoder.dtype).unsqueeze(1)
|
||||
|
||||
get_position_ids_and_mask_encoder_batch(seq_lens_encoder,
|
||||
seq_lens_decoder,
|
||||
seq_lens_this_time,
|
||||
position_ids,
|
||||
mask_encoder_batch)
|
||||
|
||||
return position_ids, mask_encoder_batch
|
||||
|
||||
def load_state_dict(self, state_dict):
|
||||
"""
|
||||
Load model parameters from a given state dictionary.
|
||||
@@ -590,13 +564,13 @@ class DeepSeekV3Model(nn.Layer):
|
||||
self,
|
||||
ids_remove_padding: paddle.Tensor,
|
||||
forward_meta: ForwardMeta,
|
||||
position_ids: paddle.Tensor,
|
||||
mask_encoder_batch: paddle.Tensor,
|
||||
):
|
||||
"""
|
||||
"""
|
||||
hidden_states = self.embeddings(ids_remove_padding=ids_remove_padding)
|
||||
|
||||
position_ids, mask_encoder_batch = self.pre_process(forward_meta)
|
||||
|
||||
residual = None
|
||||
for i in range(self.num_layers):
|
||||
hidden_states, residual = self.decoder_layers[i](
|
||||
@@ -650,6 +624,27 @@ class DeepseekV3ForCausalLM(ModelForCasualLM):
|
||||
logits[:, self.ori_vocab_size:] = -float("inf")
|
||||
return logits
|
||||
|
||||
def pre_process(self, forward_meta):
|
||||
"""
|
||||
"""
|
||||
seq_lens_encoder = forward_meta.seq_lens_encoder
|
||||
seq_lens_decoder = forward_meta.seq_lens_decoder
|
||||
seq_lens_this_time = forward_meta.seq_lens_this_time
|
||||
position_ids_shape = paddle.sum(seq_lens_this_time)
|
||||
position_ids = paddle.empty(shape=position_ids_shape,
|
||||
dtype=seq_lens_encoder.dtype)
|
||||
mask_encoder_batch = paddle.empty(
|
||||
shape=position_ids_shape,
|
||||
dtype=seq_lens_encoder.dtype).unsqueeze(1)
|
||||
|
||||
get_position_ids_and_mask_encoder_batch(seq_lens_encoder,
|
||||
seq_lens_decoder,
|
||||
seq_lens_this_time,
|
||||
position_ids,
|
||||
mask_encoder_batch)
|
||||
|
||||
return position_ids, mask_encoder_batch
|
||||
|
||||
def forward(
|
||||
self,
|
||||
ids_remove_padding: paddle.Tensor,
|
||||
@@ -657,7 +652,9 @@ class DeepseekV3ForCausalLM(ModelForCasualLM):
|
||||
):
|
||||
"""
|
||||
"""
|
||||
hidden_states = self.model(ids_remove_padding, forward_meta)
|
||||
position_ids, mask_encoder_batch = self.pre_process(forward_meta)
|
||||
hidden_states = self.model(ids_remove_padding=ids_remove_padding, forward_meta=forward_meta,
|
||||
position_ids=position_ids, mask_encoder_batch=mask_encoder_batch)
|
||||
return hidden_states
|
||||
|
||||
|
||||
|
Reference in New Issue
Block a user