Mirror of https://github.com/PaddlePaddle/FastDeploy.git (synced 2025-10-05 08:37:06 +08:00)
[fix] fix ep group all-reduce (#4140)
* [fix] fix ep group all-reduce
* [fix] fix clear/update lock not working when workers > 1
* [chore] add preemption triggered info log
* [fix] fix code style
* fix model_weights_signal (#4092)

Co-authored-by: Yuanle Liu <yuanlehome@163.com>
@@ -352,8 +352,12 @@ class ParallelConfig:
             )
             dist.collective._set_custom_gid(None)
             # same ep group id
-            dist.collective._set_custom_gid(self.data_parallel_size + tp_gid_offset)
-            self.ep_group = dist.new_group(range(self.expert_parallel_size))
+            # dist.collective._set_custom_gid(self.data_parallel_size + tp_gid_offset)
+            # self.ep_group = dist.new_group(range(self.expert_parallel_size))
+            if self.enable_expert_parallel:
+                dist.collective._set_custom_gid(self.data_parallel_size + tp_gid_offset)
+                self.ep_group = dist.new_group(range(self.expert_parallel_size))
+                dist.collective._set_custom_gid(None)
             logger.info(
                 f"data_parallel_size: {self.data_parallel_size}, tensor_parallel_size: {self.tensor_parallel_size}, expert_parallel_size: {self.expert_parallel_size}, data_parallel_rank: {self.data_parallel_rank}, tensor_parallel_rank: {self.tensor_parallel_rank}, expert_parallel_rank: {self.expert_parallel_rank}, tp_group: {self.tp_group}."
             )
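The hunk above only creates the expert-parallel communication group when expert parallelism is actually enabled, so collectives no longer land on a group that should not exist. A minimal sketch of the same guard, assuming an already-initialized paddle.distributed environment (the helper name and arguments are illustrative, not FastDeploy API):

    import paddle.distributed as dist

    def build_ep_group(enable_expert_parallel: bool, expert_parallel_size: int):
        # Only materialize the EP group when it will actually be used;
        # otherwise collectives keep running on the default / TP groups.
        if not enable_expert_parallel:
            return None
        return dist.new_group(list(range(expert_parallel_size)))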
@@ -120,6 +120,7 @@ class ResourceManagerV1(ResourceManager):
                 self._free_blocks(preempted_req)
                 preempted_req.cached_block_num = 0
                 self.to_be_rescheduled_request_id_set.add(preempted_req.request_id)
+                llm_logger.info(f"Preemption is triggered! Preempted request id: {preempted_req.request_id}")
                 preempted_reqs.append(preempted_req)
                 scheduled_reqs.append(self._prepare_preempt_task(preempted_req))
                 main_process_metrics.num_requests_waiting.inc(1)
@@ -16,12 +16,12 @@
 
 import inspect
 import os
-import threading
 import time
 import traceback
 import uuid
 
 import numpy as np
+from filelock import FileLock
 
 from fastdeploy import envs
 from fastdeploy.config import ModelConfig
@@ -132,7 +132,7 @@ class EngineClient:
             pid, max_connections=int(os.getenv("FD_DEALER_CONNECTIONS", 50))
         )
         self.connection_initialized = False
-        self.clear_update_lock = threading.Lock()
+        self.clear_update_lock = FileLock(f"/tmp/fd_weight_clear_update_lock__pid{pid}_port{port}.lock")
 
     def create_zmq_client(self, model, mode):
         """
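The lock swap above is the "clear/update lock not working when workers > 1" fix: threading.Lock only excludes threads inside a single process, while a FileLock is backed by a lock file on disk and therefore also excludes other worker processes. A minimal sketch of the pattern, using an illustrative lock path rather than the one in the diff:

    from filelock import FileLock

    # One lock file shared by every worker process on the machine (path is illustrative).
    clear_update_lock = FileLock("/tmp/example_weight_clear_update.lock")

    def run_exclusively(critical_section):
        # Blocks until no other process or thread holds the file lock.
        with clear_update_lock:
            critical_section()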
@@ -351,7 +351,9 @@ class EngineClient:
         if self.model_weights_status_signal.value[0] == ModelWeightsStatus.NORMAL:
             return True, ""
         if self.model_weights_status_signal.value[0] == ModelWeightsStatus.UPDATING:
-            return False, "updating model weight already"
+            return False, "worker is updating model weight already"
+        if self.model_weights_status_signal.value[0] == ModelWeightsStatus.CLEARING:
+            return False, "worker is clearing model weight, cannot update now"
 
         self.model_weights_status_signal.value[0] = ModelWeightsStatus.UPDATING
         if self.enable_prefix_caching or self.enable_splitwise:
@@ -395,7 +397,9 @@ class EngineClient:
         if self.model_weights_status_signal.value[0] == ModelWeightsStatus.CLEARED:
             return True, ""
         if self.model_weights_status_signal.value[0] == ModelWeightsStatus.CLEARING:
-            return False, "clearing model weight already"
+            return False, "worker is clearing model weight already"
+        if self.model_weights_status_signal.value[0] == ModelWeightsStatus.UPDATING:
+            return False, "worker is updating model weight, cannot clear now"
 
         self.model_weights_status_signal.value[0] = ModelWeightsStatus.CLEARING
         if self.enable_prefix_caching or self.enable_splitwise:
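The two hunks above make update and clear mutually exclusive: an update request is rejected while a clear is in flight and vice versa, instead of only rejecting a duplicate of the same operation. A small sketch of the guard for the update path, assuming status values consistent with the "[-1:clear, 1:update]" convention logged later in this commit (the enum values here are illustrative):

    from enum import IntEnum

    class ModelWeightsStatus(IntEnum):  # illustrative values, not the FastDeploy definition
        NORMAL = 0
        UPDATING = 1
        CLEARING = -1

    def can_start_update(status: int):
        # Mirror the checks in the diff: allow only when no other transition is running.
        if status == ModelWeightsStatus.NORMAL:
            return True, ""
        if status == ModelWeightsStatus.UPDATING:
            return False, "worker is updating model weight already"
        if status == ModelWeightsStatus.CLEARING:
            return False, "worker is clearing model weight, cannot update now"
        return False, "unexpected model weights status"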
@@ -297,7 +297,7 @@ class CutlassMoEMethod(UnquantizedFusedMoEMethod):
         )
 
         if layer.reduce_results and layer.tp_size > 1:
-            tensor_model_parallel_all_reduce(fused_moe_out)
+            tensor_model_parallel_all_reduce(fused_moe_out, layer.fd_config.parallel_config.tp_group)
 
         return fused_moe_out
 
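Passing layer.fd_config.parallel_config.tp_group explicitly is the core of the "fix ep group all-reduce" change: the fused MoE output is reduced across the tensor-parallel ranks rather than whatever group happens to be the implicit default once an EP group also exists. A generic sketch of an all-reduce over an explicit group, using plain paddle.distributed rather than the FastDeploy wrapper:

    import paddle
    import paddle.distributed as dist

    def reduce_across_group(tensor: paddle.Tensor, group) -> paddle.Tensor:
        # In-place SUM across exactly the ranks of `group` (e.g. the TP group).
        dist.all_reduce(tensor, group=group)
        return tensor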
@@ -220,23 +220,17 @@ class DynamicWeightManager:
         check model weights status
         """
         logger.info(f"dynamic weight manager is check model weights status! {model_weights_status.value[0]}")
-        is_stop = 0
         while model_weights_status.value[0] != ModelWeightsStatus.NORMAL:
             if model_weights_status.value[0] == ModelWeightsStatus.UPDATING:
                 logger.info("infer engine stopped! start to load new checkpoint...")
                 model_runner.update_parameters(pid)
+                while model_weights_status.value[0] != ModelWeightsStatus.NORMAL:
+                    time.sleep(0.01)
+                logger.info("finished loading new checkpoint")
             elif model_weights_status.value[0] == ModelWeightsStatus.CLEARING:
                 logger.info("infer engine stopped! start to clear checkpoint...")
                 model_runner.clear_parameters(pid)
-            while True:
-                if model_weights_status.value[0] == ModelWeightsStatus.NORMAL:
-                    logger.info("finished loading new checkpoint")
-                    break
-                elif is_stop == 1 or (model_weights_status.value[0] == ModelWeightsStatus.CLEARED and is_stop == 0):
-                    if is_stop == 0:
-                        logger.info("finished clearing checkpoint")
-                        is_stop = 1
-                    time.sleep(0.001)
-                    break
-                else:
-                    time.sleep(0.001)
+                while model_weights_status.value[0] != ModelWeightsStatus.CLEARED:
+                    time.sleep(0.01)
+                logger.info("finished clearing checkpoint")
+            time.sleep(0.01)
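The rewritten check_model_weights_status above drops the is_stop bookkeeping in favor of two straightforward polling loops: after asking the runner to update, it waits until the shared status returns to NORMAL; after asking it to clear, it waits until the status reaches CLEARED. The waiting pattern reduces to something like this sketch (names are illustrative; `status` stands in for the shared signal object):

    import time

    def wait_for_status(status, expected, poll_interval: float = 0.01):
        # Spin on the shared status slot, sleeping briefly between checks.
        while status.value[0] != expected:
            time.sleep(poll_interval)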
@@ -270,6 +270,11 @@ class PaddleDisWorkerProc:
             create=False,
         )
 
+    def _broadcast_model_weights_signal(self, src: int, group) -> int:
+        model_weights_signal_tensor = paddle.full(shape=[1], fill_value=self.model_weights_signal[0], dtype="int32")
+        paddle.distributed.broadcast(model_weights_signal_tensor, src=src, group=group)
+        return model_weights_signal_tensor.item()
+
     def event_loop_normal(self) -> None:
         """Main event loop for Paddle Distrubuted Workers.
         TODO(gongshaotian): support remote calling of functions that control worker.
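The new _broadcast_model_weights_signal helper wraps the host-side flag in a one-element int32 tensor, broadcasts it from the source rank over the given process group, and hands back a plain Python int. Roughly this pattern, assuming an initialized paddle.distributed environment and an existing group (the function name here is illustrative):

    import paddle
    import paddle.distributed as dist

    def broadcast_flag(flag: int, src: int, group) -> int:
        t = paddle.full(shape=[1], fill_value=flag, dtype="int32")
        dist.broadcast(t, src=src, group=group)  # every rank in `group` receives rank `src`'s value
        return int(t.item())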
@@ -279,15 +284,19 @@ class PaddleDisWorkerProc:
         req_ids = []
         num_running_requests = 0
         local_rank = self.local_rank % self.parallel_config.tensor_parallel_size
-        self.model_weights_signal = paddle.zeros([1], dtype=paddle.int32)
+        self.model_weights_signal = np.zeros([1], dtype=np.int32)
         while True:
             if self.local_rank % self.parallel_config.tensor_parallel_size == 0:
                 if self.model_weights_status.value[0] != ModelWeightsStatus.NORMAL:
                     self.model_weights_signal[0] = int(self.model_weights_status.value[0])
             if self.fd_config.load_config.dynamic_load_weight and self.parallel_config.enable_expert_parallel:
-                paddle.distributed.broadcast(self.model_weights_signal, src=0, group=self.parallel_config.ep_group)
-            if self.fd_config.load_config.dynamic_load_weight:
-                paddle.distributed.broadcast(self.model_weights_signal, src=0, group=self.parallel_config.tp_group)
+                self.model_weights_signal[0] = self._broadcast_model_weights_signal(
+                    src=0, group=self.parallel_config.ep_group
+                )
+            if self.fd_config.load_config.dynamic_load_weight and self.parallel_config.tensor_parallel_size > 1:
+                self.model_weights_signal[0] = self._broadcast_model_weights_signal(
+                    src=0, group=self.parallel_config.tp_group
+                )
 
             self.insert_step = False
             req_dicts = None
@@ -315,7 +324,9 @@ class PaddleDisWorkerProc:
             else:
                 paddle.distributed.barrier(self.parallel_config.tp_group)
             if self.model_weights_signal[0] != ModelWeightsStatus.NORMAL:
-                logger.info(f"Rank: {self.local_rank} has updated parameters.")
+                logger.info(
+                    f"Rank: {self.local_rank} to update or clear parameters, signal is {self.model_weights_signal[0]}, [-1:clear, 1:update]"
+                )
                 from fastdeploy.rl.dynamic_weight_manager import (
                     DynamicWeightManager,
                 )
@@ -327,6 +338,7 @@ class PaddleDisWorkerProc:
                     self.parallel_config.engine_worker_queue_port,
                 )
                 self.model_weights_signal[0] = ModelWeightsStatus.NORMAL
+                logger.info(f"Rank: {self.local_rank} has updated or cleared parameters.")
 
             if self.exist_task_signal.value[0] == ExistTaskStatus.EXIST or self.task_queue.read_finish_flag.get() == 1:
                 logger.info(f"Rank: {self.local_rank} Detected new requests.")