mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2025-10-05 16:48:03 +08:00
[BugFix] Fix low prediction accuracy of deepseekv3 (#2798)
This commit is contained in:
@@ -41,7 +41,8 @@ from fastdeploy.config import FDConfig
|
||||
from fastdeploy.model_executor.layers.attention.attention import Attention
|
||||
from fastdeploy.model_executor.layers.attention.base_attention_backend import (
|
||||
AttentionBackend, AttentionMetadata)
|
||||
from fastdeploy.model_executor.layers.attention.utils import init_rank_and_device_id
|
||||
from fastdeploy.model_executor.layers.attention.utils import \
|
||||
init_rank_and_device_id
|
||||
from fastdeploy.worker.forward_meta import ForwardMeta
|
||||
|
||||
|
||||
@@ -185,6 +186,8 @@ class MLAAttentionBackend(AttentionBackend):
|
||||
# MLA
|
||||
metadata.max_enc_len_this_time = metadata.set_max_lengths[1]
|
||||
metadata.max_dec_len_this_time = metadata.set_max_lengths[2]
|
||||
forward_meta.max_enc_len_this_time = metadata.set_max_lengths[1]
|
||||
forward_meta.max_dec_len_this_time = metadata.set_max_lengths[2]
|
||||
|
||||
# pd_disaggregation
|
||||
metadata.kv_signal_data_list = [None] * self.num_layers
|
||||
@@ -375,9 +378,6 @@ class MLAAttentionBackend(AttentionBackend):
|
||||
speculate_decoder = self.speculative_method is not None
|
||||
speculate_max_tokens = self.speculate_max_draft_token_num
|
||||
|
||||
decode_stage = forward_meta.is_decode_batch
|
||||
prefill_stage = not (forward_meta.is_decode_batch)
|
||||
|
||||
if self.use_pd_disaggregation:
|
||||
metadata.kv_signal_data_list[
|
||||
layer.layer_id] = init_signal_layerwise(
|
||||
@@ -387,8 +387,7 @@ class MLAAttentionBackend(AttentionBackend):
|
||||
latent_cache = forward_meta.caches[layer.layer_id] if hasattr(
|
||||
forward_meta, 'caches') else None
|
||||
|
||||
if prefill_stage:
|
||||
# 写入缓存
|
||||
if k is not None:
|
||||
prefill_mla_write_cache(
|
||||
compressed_kv,
|
||||
k_pe,
|
||||
@@ -419,8 +418,7 @@ class MLAAttentionBackend(AttentionBackend):
|
||||
return fmha_out
|
||||
|
||||
# Decode
|
||||
if decode_stage:
|
||||
# mla写入缓存
|
||||
if k is None:
|
||||
decode_mla_write_cache(
|
||||
compressed_kv,
|
||||
k_pe,
|
||||
|
Reference in New Issue
Block a user