diff --git a/fastdeploy/model_executor/layers/attention/mla_attention_backend.py b/fastdeploy/model_executor/layers/attention/mla_attention_backend.py
index 4f5199cfa..a72518566 100644
--- a/fastdeploy/model_executor/layers/attention/mla_attention_backend.py
+++ b/fastdeploy/model_executor/layers/attention/mla_attention_backend.py
@@ -41,7 +41,8 @@
 from fastdeploy.config import FDConfig
 from fastdeploy.model_executor.layers.attention.attention import Attention
 from fastdeploy.model_executor.layers.attention.base_attention_backend import (
     AttentionBackend, AttentionMetadata)
-from fastdeploy.model_executor.layers.attention.utils import init_rank_and_device_id
+from fastdeploy.model_executor.layers.attention.utils import \
+    init_rank_and_device_id
 from fastdeploy.worker.forward_meta import ForwardMeta
 
@@ -185,6 +186,8 @@ class MLAAttentionBackend(AttentionBackend):
         # MLA
         metadata.max_enc_len_this_time = metadata.set_max_lengths[1]
         metadata.max_dec_len_this_time = metadata.set_max_lengths[2]
+        forward_meta.max_enc_len_this_time = metadata.set_max_lengths[1]
+        forward_meta.max_dec_len_this_time = metadata.set_max_lengths[2]
 
         # pd_disaggregation
         metadata.kv_signal_data_list = [None] * self.num_layers
@@ -375,9 +378,6 @@ class MLAAttentionBackend(AttentionBackend):
         speculate_decoder = self.speculative_method is not None
         speculate_max_tokens = self.speculate_max_draft_token_num
 
-        decode_stage = forward_meta.is_decode_batch
-        prefill_stage = not (forward_meta.is_decode_batch)
-
         if self.use_pd_disaggregation:
             metadata.kv_signal_data_list[
                 layer.layer_id] = init_signal_layerwise(
@@ -387,8 +387,7 @@ class MLAAttentionBackend(AttentionBackend):
         latent_cache = forward_meta.caches[layer.layer_id] if hasattr(
             forward_meta, 'caches') else None
 
-        if prefill_stage:
-            # 写入缓存
+        if k is not None:
             prefill_mla_write_cache(
                 compressed_kv,
                 k_pe,
@@ -419,8 +418,7 @@ class MLAAttentionBackend(AttentionBackend):
             return fmha_out
 
         # Decode
-        if decode_stage:
-            # mla写入缓存
+        if k is None:
             decode_mla_write_cache(
                 compressed_kv,
                 k_pe,
diff --git a/fastdeploy/model_executor/models/deepseek_v3.py b/fastdeploy/model_executor/models/deepseek_v3.py
index 73997c2ac..eac6ec9ec 100644
--- a/fastdeploy/model_executor/models/deepseek_v3.py
+++ b/fastdeploy/model_executor/models/deepseek_v3.py
@@ -317,10 +317,7 @@ class DeepseekV3MLAAttention(nn.Layer):
             ],
             dtype=layernorm_out.dtype)
 
-        decode_stage = forward_meta.is_decode_batch
-        prefill_stage = not (forward_meta.is_decode_batch)
-
-        if prefill_stage:
+        if forward_meta.max_enc_len_this_time:
             query = self.q_a_proj(layernorm_out)
             query = self.q_a_layernorm(query)
             query = self.q_b_proj(query)
@@ -370,8 +367,7 @@ class DeepseekV3MLAAttention(nn.Layer):
                 fmha_out_prefill.dtype)
 
             fmha_out = fmha_out + fmha_out_prefill
-
-        if decode_stage:
+        if forward_meta.max_dec_len_this_time:
             query = self.q_a_proj(layernorm_out)
             query = self.q_a_layernorm(query)
             ln_out_or_q_c = query
@@ -554,28 +550,6 @@ class DeepSeekV3Model(nn.Layer):
             prefix="deepseek_v3.norm",
         )
 
-    def pre_process(self, forward_meta):
-        """
-        """
-        seq_lens_encoder = forward_meta.seq_lens_encoder
-        seq_lens_decoder = forward_meta.seq_lens_decoder
-        seq_lens_this_time = forward_meta.seq_lens_this_time
-        position_ids_shape = paddle.sum(seq_lens_this_time)
-
-        position_ids = paddle.empty(shape=position_ids_shape,
-                                    dtype=seq_lens_encoder.dtype)
-        mask_encoder_batch = paddle.empty(
-            shape=position_ids_shape,
-            dtype=seq_lens_encoder.dtype).unsqueeze(1)
-
-        get_position_ids_and_mask_encoder_batch(seq_lens_encoder,
-                                                seq_lens_decoder,
-                                                seq_lens_this_time,
-                                                position_ids,
-                                                mask_encoder_batch)
-
-        return position_ids, mask_encoder_batch
-
     def load_state_dict(self, state_dict):
         """
         Load model parameters from a given state dictionary.
@@ -590,13 +564,13 @@ class DeepSeekV3Model(nn.Layer):
         self,
         ids_remove_padding: paddle.Tensor,
         forward_meta: ForwardMeta,
+        position_ids: paddle.Tensor,
+        mask_encoder_batch: paddle.Tensor,
     ):
         """
         """
         hidden_states = self.embeddings(ids_remove_padding=ids_remove_padding)
 
-        position_ids, mask_encoder_batch = self.pre_process(forward_meta)
-
         residual = None
         for i in range(self.num_layers):
             hidden_states, residual = self.decoder_layers[i](
@@ -650,6 +624,27 @@ class DeepseekV3ForCausalLM(ModelForCasualLM):
             logits[:, self.ori_vocab_size:] = -float("inf")
         return logits
 
+    def pre_process(self, forward_meta):
+        """
+        """
+        seq_lens_encoder = forward_meta.seq_lens_encoder
+        seq_lens_decoder = forward_meta.seq_lens_decoder
+        seq_lens_this_time = forward_meta.seq_lens_this_time
+        position_ids_shape = paddle.sum(seq_lens_this_time)
+        position_ids = paddle.empty(shape=position_ids_shape,
+                                    dtype=seq_lens_encoder.dtype)
+        mask_encoder_batch = paddle.empty(
+            shape=position_ids_shape,
+            dtype=seq_lens_encoder.dtype).unsqueeze(1)
+
+        get_position_ids_and_mask_encoder_batch(seq_lens_encoder,
+                                                seq_lens_decoder,
+                                                seq_lens_this_time,
+                                                position_ids,
+                                                mask_encoder_batch)
+
+        return position_ids, mask_encoder_batch
+
     def forward(
         self,
         ids_remove_padding: paddle.Tensor,
@@ -657,7 +652,9 @@ class DeepseekV3ForCausalLM(ModelForCasualLM):
     ):
         """
         """
-        hidden_states = self.model(ids_remove_padding, forward_meta)
+        position_ids, mask_encoder_batch = self.pre_process(forward_meta)
+        hidden_states = self.model(ids_remove_padding=ids_remove_padding, forward_meta=forward_meta,
+                                   position_ids=position_ids, mask_encoder_batch=mask_encoder_batch)
 
         return hidden_states
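
Below is a minimal sketch (not part of the diff) of the control-flow pattern this patch moves to: the attention backend infers the stage from whether full keys `k` were passed in, while the model gates its prefill/decode branches on the per-batch `max_enc_len_this_time` / `max_dec_len_this_time` values that the backend now copies onto `forward_meta`. The function and variable names here are illustrative stand-ins, not FastDeploy APIs.

```python
def attention_write_cache(compressed_kv, k_pe, k=None):
    """Dispatch on the presence of full keys, as the backend now does:
    prefill passes k, decode passes k=None."""
    if k is not None:
        # Prefill: full keys available, write the prefill KV cache.
        return "prefill_mla_write_cache"
    # Decode: no full keys, write the decode KV cache instead.
    return "decode_mla_write_cache"


def mla_attention(max_enc_len_this_time, max_dec_len_this_time):
    """Run each branch only when the batch actually contains that stage,
    so a mixed batch can execute both paths in one forward pass."""
    ran = []
    if max_enc_len_this_time:   # at least one prefill request in the batch
        ran.append("prefill branch")
    if max_dec_len_this_time:   # at least one decode request in the batch
        ran.append("decode branch")
    return ran


print(attention_write_cache(None, None, k="keys"))  # prefill_mla_write_cache
print(attention_write_cache(None, None, k=None))    # decode_mla_write_cache
print(mla_attention(128, 1))                        # both branches run
```

Compared with the removed `is_decode_batch` flag, which forced the whole batch into a single stage, the length-based gates let prefill and decode requests coexist in one batch; `pre_process` likewise moves up to `DeepseekV3ForCausalLM.forward` so `position_ids` and `mask_encoder_batch` are computed once and passed explicitly into the model.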