Mirror of https://github.com/PaddlePaddle/FastDeploy.git (synced 2025-10-06 00:57:33 +08:00)
fix model_weights_signal (#4092)
Some checks failed
CE Compile Job / ce_job_pre_check (push) Has been cancelled
CE Compile Job / print_ce_job_pre_check_outputs (push) Has been cancelled
CE Compile Job / FD-Clone-Linux (push) Has been cancelled
CE Compile Job / Show Code Archive Output (push) Has been cancelled
CE Compile Job / BUILD_SM8090 (push) Has been cancelled
CE Compile Job / BUILD_SM8689 (push) Has been cancelled
CE Compile Job / CE_UPLOAD (push) Has been cancelled
* fix model_weights_signal
@@ -248,6 +248,11 @@ class PaddleDisWorkerProc:
             create=False,
         )
 
+    def _broadcast_model_weights_signal(self, src: int, group) -> int:
+        model_weights_signal_tensor = paddle.full(shape=[1], fill_value=self.model_weights_signal[0], dtype="int32")
+        paddle.distributed.broadcast(model_weights_signal_tensor, src=src, group=group)
+        return model_weights_signal_tensor.item()
+
     def event_loop_normal(self) -> None:
         """Main event loop for Paddle Distrubuted Workers.
         TODO(gongshaotian): support remote calling of functions that control worker.
@@ -257,15 +262,19 @@ class PaddleDisWorkerProc:
         req_ids = []
         num_running_requests = 0
         local_rank = self.local_rank % self.parallel_config.tensor_parallel_size
-        self.model_weights_signal = paddle.zeros([1], dtype=paddle.int32)
+        self.model_weights_signal = np.zeros([1], dtype=np.int32)
         while True:
             if local_rank == 0:
                 if self.model_weights_status.value[0] != 0:
                     self.model_weights_signal[0] = int(self.model_weights_status.value[0])
             if self.fd_config.load_config.dynamic_load_weight and self.parallel_config.enable_expert_parallel:
-                paddle.distributed.broadcast(self.model_weights_signal, src=0, group=self.parallel_config.ep_group)
-            if self.fd_config.load_config.dynamic_load_weight:
-                paddle.distributed.broadcast(self.model_weights_signal, src=0, group=self.parallel_config.tp_group)
+                self.model_weights_signal[0] = self._broadcast_model_weights_signal(
+                    src=0, group=self.parallel_config.ep_group
+                )
+            if self.fd_config.load_config.dynamic_load_weight and self.parallel_config.tensor_parallel_size > 1:
+                self.model_weights_signal[0] = self._broadcast_model_weights_signal(
+                    src=0, group=self.parallel_config.tp_group
+                )
 
             self.insert_step = False
             req_dicts = None
@@ -293,7 +302,9 @@ class PaddleDisWorkerProc:
             else:
                 paddle.distributed.barrier(self.parallel_config.tp_group)
             if self.model_weights_signal[0] != 0:
-                logger.info(f"Rank: {self.local_rank} has updated parameters.")
+                logger.info(
+                    f"Rank: {self.local_rank} to update or clear parameters, signal is {self.model_weights_signal[0]}, [-1:clear, 1:update]"
+                )
                 from fastdeploy.rl.dynamic_weight_manager import (
                     DynamicWeightManager,
                 )
@@ -305,6 +316,7 @@ class PaddleDisWorkerProc:
                     self.parallel_config.engine_worker_queue_port,
                 )
                 self.model_weights_signal[0] = 0
+                logger.info(f"Rank: {self.local_rank} has updated or cleared parameters.")
 
             if self.exist_task_signal.value[0] == 1 or self.task_queue.read_finish_flag.get() == 1:
                 logger.info(f"Rank: {self.local_rank} Detected new requests.")
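
The hunks above move `self.model_weights_signal` from a paddle tensor to a host-side NumPy array and route the collective through a helper that stages the value in a one-element int32 tensor, broadcasts it, and returns a plain Python int. A minimal standalone sketch of that pattern, not part of the commit: `broadcast_signal` and the `__main__` scaffolding are hypothetical names, and it assumes a parallel environment launched with `python -m paddle.distributed.launch --nproc_per_node=2`.

import numpy as np
import paddle
import paddle.distributed as dist


def broadcast_signal(signal: np.ndarray, src: int, group=None) -> int:
    # Same shape as the commit's _broadcast_model_weights_signal helper:
    # stage the host-side value in a one-element int32 tensor, broadcast
    # it from `src`, and return the received value as a Python int.
    tensor = paddle.full(shape=[1], fill_value=int(signal[0]), dtype="int32")
    dist.broadcast(tensor, src=src, group=group)
    return tensor.item()


if __name__ == "__main__":
    dist.init_parallel_env()
    signal = np.zeros([1], dtype=np.int32)
    if dist.get_rank() == 0:
        signal[0] = 1  # 1:update (or -1:clear), per the log message in the diff
    signal[0] = broadcast_signal(signal, src=0)
    print(f"rank {dist.get_rank()}: signal={signal[0]}")

One plausible reading of the switch: with the signal kept in NumPy, only the broadcast itself touches a device tensor, rather than broadcasting in place into a tensor that the event loop also reads and writes between collectives.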