mirror of
				https://github.com/PaddlePaddle/FastDeploy.git
				synced 2025-10-31 20:02:53 +08:00 
			
		
		
		
	[NewFeature] add ep rollout model init and update/clear ep buffer (#3927)
* add ep rollout model init && add deepep buffer update/clear * fix test
This commit is contained in:
		| @@ -48,7 +48,7 @@ class DynamicWeightManager: | ||||
|  | ||||
|         logger.info( | ||||
|             f"✅ DynamicLoad model built successfully by {self.load_config.load_strategy}, " | ||||
|             f" rank={self.rank}, ranks={self.nranks}" | ||||
|             f" tp rank={self.rank}, dp rank={fd_config.parallel_config.local_data_parallel_id}, ep rank={fd_config.parallel_config.expert_parallel_rank}, ranks={self.nranks}, " | ||||
|         ) | ||||
|  | ||||
|     @paddle.no_grad() | ||||
| @@ -63,11 +63,21 @@ class DynamicWeightManager: | ||||
|         start_time = time.perf_counter() | ||||
|         paddle.device.cuda.empty_cache() | ||||
|  | ||||
|         # step1 : restart paddle process group | ||||
|         if not self.first_load: | ||||
|             paddle.distributed.restart_process_group(self.parallel_config.tp_group) | ||||
|             if self.parallel_config.enable_expert_parallel: | ||||
|                 paddle.distributed.restart_process_group(self.parallel_config.ep_group) | ||||
|  | ||||
|         # step2 : recreat deepep buffer when enable expert parallel | ||||
|         if self.parallel_config.enable_expert_parallel and not self.first_load: | ||||
|             from fastdeploy.model_executor.layers.moe.ep import DeepEPBufferManager | ||||
|  | ||||
|             DeepEPBufferManager.recreate_buffer() | ||||
|             # ep barrier | ||||
|             paddle.distributed.barrier(self.parallel_config.ep_group) | ||||
|  | ||||
|         # step3 : update model weight | ||||
|         strategy_handlers = { | ||||
|             "ipc_snapshot": self._update_ipc_snapshot, | ||||
|             "ipc": self._update_ipc, | ||||
| @@ -79,6 +89,10 @@ class DynamicWeightManager: | ||||
|             raise ValueError(f"Unsupported strategy: {self.load_config.load_strategy}") | ||||
|  | ||||
|         logger.info(f"Update parameters in {time.perf_counter()-start_time:.2f}s") | ||||
|         # steps in the runner | ||||
|         # step 4: reinitialze kv_cache | ||||
|         # step 5: recapture CUDAGraph | ||||
|         # step 6: update weight status signal | ||||
|  | ||||
|     def _update_ipc_snapshot(self): | ||||
|         """Update using IPC snapshot strategy for elastic recovery.""" | ||||
| @@ -106,18 +120,31 @@ class DynamicWeightManager: | ||||
|     def clear_parameters(self, pid: int = 0) -> None: | ||||
|         """Clear all model parameters and free memory.""" | ||||
|         logger.info("start clear paramaters") | ||||
|  | ||||
|         # step1: release deepep buffer | ||||
|         if self.parallel_config.enable_expert_parallel: | ||||
|             from fastdeploy.model_executor.layers.moe.ep import DeepEPBufferManager | ||||
|  | ||||
|             DeepEPBufferManager.clear_buffer() | ||||
|             # ep barrier | ||||
|             paddle.distributed.barrier(self.parallel_config.ep_group) | ||||
|             # shutdown ep group | ||||
|             paddle.distributed.shutdown_process_group(self.parallel_config.ep_group) | ||||
|  | ||||
|         paddle.device.cuda.empty_cache() | ||||
|         # step2: release model weight | ||||
|         for param in self.model.state_dict().values(): | ||||
|             param._clear_data() | ||||
|  | ||||
|         self._verify_parameters("clearance") | ||||
|  | ||||
|         if self.parallel_config.tensor_parallel_size > 1: | ||||
|             # tp barrier | ||||
|             paddle.distributed.barrier(self.parallel_config.tp_group) | ||||
|         paddle.distributed.shutdown_process_group(self.parallel_config.tp_group) | ||||
|         if self.parallel_config.enable_expert_parallel: | ||||
|             paddle.distributed.barrier(self.parallel_config.ep_group) | ||||
|             paddle.distributed.shutdown_process_group(self.parallel_config.ep_group) | ||||
|         paddle.distributed.shutdown_process_group() | ||||
|             # shutdown tp group | ||||
|             paddle.distributed.shutdown_process_group(self.parallel_config.tp_group) | ||||
|         # step3: update model weight signal | ||||
|         # step4: release kv cache in the runner | ||||
|         self._update_shared_status(pid, -2) | ||||
|  | ||||
|     def _update_model_from_state(self, state_dict: Dict[str, paddle.Tensor], src_type: str): | ||||
| @@ -146,10 +173,16 @@ class DynamicWeightManager: | ||||
|     def finalize_update(self, pid: int = 0): | ||||
|         """Finalize update process with verification.""" | ||||
|         self._verify_parameters("update") | ||||
|  | ||||
|         if self.parallel_config.tensor_parallel_size > 1: | ||||
|             paddle.distributed.barrier(self.parallel_config.tp_group) | ||||
|  | ||||
|         if self.parallel_config.enable_expert_parallel: | ||||
|             paddle.distributed.barrier(self.parallel_config.ep_group) | ||||
|  | ||||
|         if not self.first_load: | ||||
|             self._update_shared_status(pid, 0) | ||||
|  | ||||
|         self.first_load = False | ||||
|  | ||||
|     def _get_gpu_id(self) -> int: | ||||
|   | ||||
| @@ -24,13 +24,13 @@ class RolloutModelConfig: | ||||
|         max_model_len: int = 32768, | ||||
|         tensor_parallel_size: int = 4, | ||||
|         dynamic_load_weight: bool = True, | ||||
|         load_strategy: str = "ipc_snapshot", | ||||
|         load_strategy: str = "meta", | ||||
|         enable_mm: bool = False, | ||||
|         # Default values for all other parameters | ||||
|         max_num_seqs: int = 34, | ||||
|         total_block_num: int = 2000, | ||||
|         block_size: int = 64, | ||||
|         engine_worker_queue_port: int = 9923, | ||||
|         engine_worker_queue_port: str = "8002", | ||||
|         device_ids: str = "0", | ||||
|         dtype: str = "bfloat16", | ||||
|         enc_dec_block_num: int = 1, | ||||
|   | ||||
		Reference in New Issue
	
	Block a user
	 gaoziyuan
					gaoziyuan