diff --git a/fastdeploy/model_executor/models/deepseek_v3.py b/fastdeploy/model_executor/models/deepseek_v3.py
index 4d75b03b9..e4b44f477 100644
--- a/fastdeploy/model_executor/models/deepseek_v3.py
+++ b/fastdeploy/model_executor/models/deepseek_v3.py
@@ -27,6 +27,9 @@ from paddleformers.utils.log import logger
 from fastdeploy.config import FDConfig
 from fastdeploy.distributed.communication import tensor_model_parallel_all_reduce
 from fastdeploy.model_executor.forward_meta import ForwardMeta
+from fastdeploy.model_executor.graph_optimization.decorator import (
+    support_graph_optimization,
+)
 from fastdeploy.model_executor.layers.activation import SiluAndMul
 from fastdeploy.model_executor.layers.attention.attention import Attention
 from fastdeploy.model_executor.layers.embeddings import VocabParallelEmbedding
@@ -501,6 +504,7 @@ class DeepSeekV3DecoderLayer(nn.Layer):
         return hidden_states, residual
 
 
+@support_graph_optimization
 class DeepSeekV3Model(nn.Layer):
     """
     DeepSeekV3Model
@@ -596,6 +600,10 @@ class DeepseekV3ForCausalLM(ModelForCasualLM):
             num_embeddings=fd_config.model_config.vocab_size,
             prefix="lm_head",
         )
+        self.position_ids_buffer = paddle.empty([fd_config.parallel_config.max_num_batched_tokens], dtype=paddle.int32)
+        self.mask_encoder_batch_buffer = paddle.empty(
+            [fd_config.parallel_config.max_num_batched_tokens, 1], dtype=paddle.int32
+        )
 
     @classmethod
     def name(cls):
@@ -622,9 +630,10 @@ class DeepseekV3ForCausalLM(ModelForCasualLM):
         seq_lens_encoder = forward_meta.seq_lens_encoder
         seq_lens_decoder = forward_meta.seq_lens_decoder
         seq_lens_this_time = forward_meta.seq_lens_this_time
-        position_ids_shape = paddle.sum(seq_lens_this_time)
-        position_ids = paddle.empty(shape=position_ids_shape, dtype=seq_lens_encoder.dtype)
-        mask_encoder_batch = paddle.empty(shape=position_ids_shape, dtype=seq_lens_encoder.dtype).unsqueeze(1)
+
+        current_total_tokens = paddle.sum(seq_lens_this_time)
+        position_ids = self.position_ids_buffer[:current_total_tokens]
+        mask_encoder_batch = self.mask_encoder_batch_buffer[:current_total_tokens]
 
         get_position_ids_and_mask_encoder_batch(
             seq_lens_encoder,
@@ -633,7 +642,6 @@ class DeepseekV3ForCausalLM(ModelForCasualLM):
             seq_lens_decoder,
             seq_lens_this_time,
             position_ids,
             mask_encoder_batch,
         )
-
         return position_ids, mask_encoder_batch

     def forward(
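
The change above replaces per-step paddle.empty allocations with slices of buffers allocated once in __init__, so the tensors seen on the hot path keep stable storage across steps, which is what graph-capture decorators like support_graph_optimization rely on. Below is a minimal, self-contained sketch of that buffer-reuse pattern; it is illustrative only: the class name PositionIdsBufferSketch, the 8192 capacity, and the int(...) cast are our assumptions (the diff itself slices with the summed tensor directly).

import paddle


class PositionIdsBufferSketch:
    """Illustrative sketch of the diff's buffer-reuse pattern (not FastDeploy code)."""

    def __init__(self, max_num_batched_tokens: int):
        # Allocate once, sized for the worst case, analogous to
        # self.position_ids_buffer in the diff above.
        self.position_ids_buffer = paddle.empty([max_num_batched_tokens], dtype=paddle.int32)

    def pre_process(self, seq_lens_this_time: paddle.Tensor) -> paddle.Tensor:
        # Per step: take a view of the shared buffer instead of calling
        # paddle.empty, so no fresh allocation happens on the hot path.
        current_total_tokens = int(paddle.sum(seq_lens_this_time))
        return self.position_ids_buffer[:current_total_tokens]


# Usage: each step's position_ids is a view into the same storage.
sketch = PositionIdsBufferSketch(max_num_batched_tokens=8192)
lens = paddle.to_tensor([3, 5], dtype=paddle.int32)
position_ids = sketch.pre_process(lens)  # shape [8], backed by the shared buffer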