mirror of https://github.com/PaddlePaddle/FastDeploy.git (synced 2025-10-05 08:37:06 +08:00)
[Feature] support model weight update in ep (#3765)
* support model weight update in ep
* support model weight update in ep
* support model weight update in ep
* support model weight update in ep
* Update fused_moe_backend_base.py
* Update worker_process.py
* Update worker_process.py
* Update dynamic_weight_manager.py
@@ -254,27 +254,26 @@ class PaddleDisWorkerProc:
         """
         # Currently, only support single node
         self.nnode = int((self.parallel_config.tensor_parallel_size + 7) // 8)
         mp_num_per_node = self.parallel_config.tensor_parallel_size // self.nnode
         req_ids = []
         num_running_requests = 0
+        local_rank = self.local_rank % self.parallel_config.tensor_parallel_size
-        while True:
-            if self.local_rank == 0:
-                if self.model_weights_status.value[0] != 0:
-                    self.exist_task_signal.value[0] = 2
-                else:
-                    self.exist_task_signal.value[0] = 0
-
-            if self.parallel_config.tensor_parallel_size > 1:
-                # Synchronize before updating weights
-                paddle.distributed.barrier(self.parallel_config.tp_group)
+        self.model_weights_signal = paddle.zeros([1], dtype=paddle.int32)
+        while True:
+            if self.local_rank % self.parallel_config.tensor_parallel_size == 0:
+                if self.model_weights_status.value[0] != 0:
+                    self.model_weights_signal[0] = int(self.model_weights_status.value[0])
+            if self.fd_config.load_config.dynamic_load_weight and self.parallel_config.enable_expert_parallel:
+                paddle.distributed.broadcast(self.model_weights_signal, src=0, group=self.parallel_config.ep_group)
+            if self.fd_config.load_config.dynamic_load_weight:
+                paddle.distributed.broadcast(self.model_weights_signal, src=0, group=self.parallel_config.tp_group)

             self.insert_step = False
             req_dicts = None
-            local_rank = self.local_rank % self.parallel_config.tensor_parallel_size
             self.worker_healthy_live_signal.value[local_rank % self.max_chips_per_node] = int(time.time())

             # The first worker detects whether there are tasks in the task queue
-            if self.local_rank % mp_num_per_node == 0:
+            if self.local_rank % self.parallel_config.tensor_parallel_size == 0:
                 if self.task_queue.num_tasks() > 0:
                     # VL only support 1 batch to prefill
                     if envs.ENABLE_V1_KVCACHE_SCHEDULER or not (
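In the hunk above, the old path (rank 0 flipping exist_task_signal to 2, followed by a bare TP barrier) is replaced by a model_weights_signal tensor that rank 0 fills from the shared model_weights_status and then broadcasts over the EP and/or TP group, so every rank takes the same reload decision. Below is a minimal sketch of that read-then-broadcast pattern, outside FastDeploy; the helper name, the plain engine_flag argument, and the use of the default process group are illustrative assumptions, and a real job would be launched with paddle.distributed.launch.

import paddle
import paddle.distributed as dist

def poll_weight_update_signal(engine_flag: int) -> int:
    # Rank 0 reads the engine-side flag; the broadcast makes every rank in the
    # group act on the same value, mirroring the worker loop above.
    signal = paddle.zeros([1], dtype=paddle.int32)
    if dist.get_rank() == 0:
        # engine_flag stands in for model_weights_status.value[0], which the
        # real worker reads from a shared-memory IPC signal.
        signal[0] = engine_flag
    dist.broadcast(signal, src=0)  # FastDeploy broadcasts over tp_group / ep_group instead
    return int(signal[0])

if __name__ == "__main__":
    dist.init_parallel_env()
    if poll_weight_update_signal(engine_flag=1) != 0:
        print(f"rank {dist.get_rank()} would reload or clear weights now")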
@@ -290,16 +289,24 @@ class PaddleDisWorkerProc:
                 paddle.distributed.barrier(self.parallel_config.tp_group)

             if self.fd_config.load_config.dynamic_load_weight:
-                if self.exist_task_signal.value[0] == 2:
+                if self.parallel_config.enable_expert_parallel:
+                    paddle.distributed.barrier(self.parallel_config.ep_group)
+                else:
+                    paddle.distributed.barrier(self.parallel_config.tp_group)
+                if self.model_weights_signal[0] != 0:
                     logger.info(f"Rank: {self.local_rank} has updated parameters.")
                     from fastdeploy.rl.dynamic_weight_manager import (
                         DynamicWeightManager,
                     )

+                    self.model_weights_status.value[0] = self.model_weights_signal[0]
                     DynamicWeightManager.check_model_weights_status(
                         self.model_weights_status,
+                        # model_weights_signal
                         self.worker.model_runner,
                         self.parallel_config.engine_pid,
+                        self.parallel_config.engine_worker_queue_port,
                     )
+                    self.model_weights_signal[0] = 0

             if self.exist_task_signal.value[0] == 1 or self.task_queue.read_finish_flag.get() == 1:
                 logger.info(f"Rank: {self.local_rank} Detected new requests.")
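Once a non-zero signal survives the barrier, the worker copies it back into model_weights_status, hands off to DynamicWeightManager.check_model_weights_status, and finally clears its local signal. The manager itself is not part of this diff; as a rough, hypothetical illustration of that kind of handshake (the multiprocessing flag and both function names below are invented for the sketch, not FastDeploy APIs), the waiting side simply blocks until the updater resets a shared status flag to 0.

import multiprocessing as mp
import time

def wait_for_weights_ready(status, poll_interval: float = 0.01) -> None:
    # Block while the shared flag is non-zero; the updater resets it to 0
    # once the new weights are in place (or cleared).
    while status[0] != 0:
        time.sleep(poll_interval)

def updater(status) -> None:
    time.sleep(0.1)  # stand-in for actually swapping the weights
    status[0] = 0    # report completion back to the waiting worker

if __name__ == "__main__":
    status = mp.Array("i", [1])  # 1 -> update requested
    p = mp.Process(target=updater, args=(status,))
    p.start()
    wait_for_weights_ready(status)
    p.join()
    print("weights updated, worker resumes its serving loop")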