Mirror of https://github.com/PaddlePaddle/FastDeploy.git
Synced 2025-10-31 11:56:44 +08:00
[Graph Optimization][Speculative Decoding] Fix the bug of CUDAGraph + MTP + EP (#4430)

Some checks failed:
CE Compile Job / ce_job_pre_check (push) Has been cancelled
CE Compile Job / print_ce_job_pre_check_outputs (push) Has been cancelled
CE Compile Job / FD-Clone-Linux (push) Has been cancelled
CE Compile Job / Show Code Archive Output (push) Has been cancelled
CE Compile Job / BUILD_SM8090 (push) Has been cancelled
CE Compile Job / BUILD_SM8689 (push) Has been cancelled
CE Compile Job / CE_UPLOAD (push) Has been cancelled
* Fix MTP dummy run bug
* Target Model and Draft Model use the same flag
* Avoid MoE bug in CUDAGraph padding
* In MTP, replace use_cudagraph with step_use_cudagraph
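The core of the change is that the draft (MTP) model no longer derives its own CUDAGraph decision (which previously involved an all-gather over EP ranks to detect prefill batches); it reuses the target model's per-step flag, so both models take the same capture/replay path and expert-parallel ranks stay in sync. Below is a minimal Python sketch of that flag hand-off, under assumed simplified names (MTPProposerSketch and its run signature are hypothetical, not the FastDeploy API):

# Minimal sketch of the flag hand-off, not the FastDeploy implementation.
class MTPProposerSketch:
    def __init__(self, use_cudagraph: bool):
        # Static capability flag: whether this proposer supports CUDA graphs at all.
        self.use_cudagraph = use_cudagraph

    def run(self, full_hidden_states, step_use_cudagraph: bool = False) -> str:
        # The draft step replays a CUDA graph only if the target model chose
        # the graph path for this step AND the proposer supports it.
        if step_use_cudagraph and self.use_cudagraph:
            return "draft step: CUDA graph replay path"
        return "draft step: eager path"


# Hypothetical call site mirroring the runner change: forward the target
# model's per-step flag instead of letting the draft model re-derive it.
proposer = MTPProposerSketch(use_cudagraph=True)
print(proposer.run(full_hidden_states=None, step_use_cudagraph=True))

Per the diff, forwarding the flag also removes the paddle.distributed.all_gather_object call from the proposer's _initialize_forward_meta, which is where the EP hang could occur.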
@@ -496,12 +496,12 @@ std::vector<paddle::Tensor> AppendAttention(
   paddle::Tensor fmha_out;
   if (out_linear_in_scale > 0.0) {
     if (fabs(quant_max_bound - 127.0f) < 0.000001) {
-      fmha_out = GetEmptyTensor(
+      fmha_out = paddle::zeros(
           {meta_data.token_nums, meta_data.q_num_heads * meta_data.head_dims},
           paddle::DataType::INT8,
           qkv.place());
     } else if (fabs(quant_max_bound - 448.0f) < 0.000001) {
-      fmha_out = GetEmptyTensor(
+      fmha_out = paddle::zeros(
           {meta_data.token_nums, meta_data.q_num_heads * meta_data.head_dims},
           paddle::DataType::FLOAT8_E4M3FN,
           qkv.place());
@@ -509,7 +509,7 @@ std::vector<paddle::Tensor> AppendAttention(
       PD_THROW("Only supported attr of quant_max_bound in ['127', '448'].");
     }
   } else {
-    fmha_out = GetEmptyTensor(
+    fmha_out = paddle::zeros(
         {meta_data.token_nums, meta_data.q_num_heads * meta_data.head_dims},
         dtype_id,
         qkv.place());

@@ -79,7 +79,7 @@ class MTPProposer(Proposer):
         self._init_model_inputs()

         # CUDA Graph
-        self.use_cudagraph = False  # self.graph_opt_config.use_cudagraph
+        self.use_cudagraph = False  # TODO(gongshaotian): Use Target Model flag
         self.cudagraph_capture_sizes = list(reversed(self.graph_opt_config.cudagraph_capture_sizes))
         self.sot_warmup_sizes = self.graph_opt_config.sot_warmup_sizes

@@ -117,6 +117,9 @@ class MTPProposer(Proposer):
             self.parallel_config.max_model_len - max_dec_len,
         )

+        if self.fd_config.parallel_config.enable_expert_parallel:
+            input_length = min(input_length, 32)
+
         block_num = (
             input_length + self.cache_config.block_size - 1
         ) // self.cache_config.block_size + self.cache_config.enc_dec_block_num
@@ -541,7 +544,7 @@ class MTPProposer(Proposer):
         self.model_inputs["not_need_stop"][0] = True
         self.model_inputs["seq_lens_this_time"] = self.seq_lens_this_time_buffer

-    def _initialize_forward_meta(self):
+    def _initialize_forward_meta(self, step_use_cudagraph: bool = False):
         """
         Initialize forward meta and attention meta data
         """
@@ -569,23 +572,8 @@ class MTPProposer(Proposer):
         for attn_backend in self.attn_backends:
             attn_backend.init_attention_metadata(self.forward_meta)

-        # Update Batch type for cuda graph
-        only_decode_batch = True
-        prefill_exists = None
-
-        # Mix ep in single node
-        if self.fd_config.parallel_config.use_ep and self.fd_config.parallel_config.splitwise_role == "mixed":
-            only_decode_batch_list = []
-            prefill_exists = self.exist_prefill()
-            paddle.distributed.all_gather_object(only_decode_batch_list, not prefill_exists)
-            only_decode_batch = all(only_decode_batch_list)
-            self.fd_config.model_config.moe_phase.phase = "decode" if only_decode_batch else "prefill"
-
-        self.forward_meta.step_use_cudagraph = (
-            self.use_cudagraph
-            and only_decode_batch
-            and not (prefill_exists if prefill_exists is not None else self.exist_prefill())
-        )
+        # TODO(gongshaotian): Use CUDAGraph with Draft Model
+        self.forward_meta.step_use_cudagraph = step_use_cudagraph and self.use_cudagraph

     def exist_prefill(self):
         """
@@ -671,9 +659,12 @@ class MTPProposer(Proposer):
                 self.parallel_config.use_ep,
             )

-    def _propose(self):
+    def _propose(self, step_use_cudagraph: bool = False):
         """
-        Main process for MTP inference
+        Main process for MTP inference.
+        Args:
+        step_use_cudagraph: bool
+            Whether to use cuda graph. Use the target model flag to avoid hanging problems with EP.
         """
         for substep in range(self.num_model_steps):
             if self.model_inputs["not_need_stop"]:
@@ -697,7 +688,7 @@ class MTPProposer(Proposer):

                 # Initialize forward meta data
                 self.model_inputs["ids_remove_padding"].copy_(ids_remove_padding, False)
-                self.model_inputs["batch_id_per_token"].copy_(batch_id_per_token, False)
+                self.model_inputs["batch_id_per_token"][:] = -1
                 self.model_inputs["cu_seqlens_q"].copy_(cu_seqlens_q, False)
                 self.model_inputs["cu_seqlens_k"].copy_(cu_seqlens_k, False)
                 # for speculative decoding
@@ -705,7 +696,8 @@ class MTPProposer(Proposer):
                 self.model_inputs["output_padding_offset"].copy_(output_padding_offset, False)

                 # Initialize forward meta data
-                self._initialize_forward_meta()
+                self._initialize_forward_meta(step_use_cudagraph=step_use_cudagraph)
+                self.forward_meta.batch_id_per_token.copy_(batch_id_per_token, False)

                 # Padding inputs for cuda graph
                 self.padding_cudagraph_inputs()
@@ -733,7 +725,7 @@ class MTPProposer(Proposer):
                     previous_hidden_states=self.model_inputs["target_hidden_states"],
                     forward_meta=self.forward_meta,
                 )
-                if self.use_cudagraph:
+                if self.forward_meta.step_use_cudagraph:
                     model_output = model_output[: self.real_token_num]

                 hidden_states = rebuild_padding(
@@ -861,10 +853,10 @@ class MTPProposer(Proposer):
         self.target_model_inputs["draft_tokens"][:] = draft_tokens.cuda()
         self.target_model_inputs["seq_lens_this_time"][:] = seq_lens_this_time.cuda()

-    def _run_impl(self, full_hidden_states):
-        """"""
+    def _run_impl(self, full_hidden_states: paddle.Tensor, step_use_cudagraph: bool = False):
+        """Execute Draft Model"""
         self._prepare_inputs(full_hidden_states)
-        self._propose()
+        self._propose(step_use_cudagraph=step_use_cudagraph)
         self._update_status()
         if self.hybrid_mode:
             self._extend_draft_token_with_ngram_match()
@@ -881,7 +873,7 @@ class MTPProposer(Proposer):
         # In init_attention_metadata, the decode buffer has already been cleared

         # To adapt to CUDA Graph, keep the forward pass at the maximum batch size.
-        if self.use_cudagraph:
+        if self.forward_meta.step_use_cudagraph:
             self.forward_meta.seq_lens_this_time = self.seq_lens_this_time_buffer
             self.real_token_num = self.forward_meta.ids_remove_padding.shape[0]
         return

@@ -1253,7 +1253,9 @@ class GPUModelRunner(ModelRunnerBase):

             if self.speculative_decoding:
                 if self.speculative_method == "mtp":
-                    self.proposer.run(full_hidden_states=model_output)
+                    self.proposer.run(
+                        full_hidden_states=model_output, step_use_cudagraph=self.forward_meta.step_use_cudagraph
+                    )
                 else:
                     self.proposer.run(share_inputs=self.share_inputs)

@@ -1600,7 +1602,9 @@ class GPUModelRunner(ModelRunnerBase):
         # 6. Speculative decode
         if self.speculative_decoding:
             if self.speculative_method == "mtp":
-                self.proposer.run(full_hidden_states=model_output)
+                self.proposer.run(
+                    full_hidden_states=model_output, step_use_cudagraph=self.forward_meta.step_use_cudagraph
+                )
             else:
                 self.proposer.run(share_inputs=self.share_inputs)