[fix] fix ep group all-reduce (#4140)

* [fix] fix ep group all-reduce

* [fix] fix clear/update lock not working when workers > 1

* [chore] add preemption triggered info log

* [fix] fix code style

* fix model_weights_signal (#4092)

* fix model_weights_signal

---------

Co-authored-by: Yuanle Liu <yuanlehome@163.com>
Author: 李泳桦
Date: 2025-09-18 10:34:49 +08:00 (committed by GitHub)
Parent: cffde70949
Commit: 0fa28b1068
6 changed files with 41 additions and 26 deletions


@@ -352,8 +352,12 @@ class ParallelConfig:
             )
             dist.collective._set_custom_gid(None)
             # same ep group id
-            dist.collective._set_custom_gid(self.data_parallel_size + tp_gid_offset)
-            self.ep_group = dist.new_group(range(self.expert_parallel_size))
+            # dist.collective._set_custom_gid(self.data_parallel_size + tp_gid_offset)
+            # self.ep_group = dist.new_group(range(self.expert_parallel_size))
+            if self.enable_expert_parallel:
+                dist.collective._set_custom_gid(self.data_parallel_size + tp_gid_offset)
+                self.ep_group = dist.new_group(range(self.expert_parallel_size))
+                dist.collective._set_custom_gid(None)
         logger.info(
             f"data_parallel_size: {self.data_parallel_size}, tensor_parallel_size: {self.tensor_parallel_size}, expert_parallel_size: {self.expert_parallel_size}, data_parallel_rank: {self.data_parallel_rank}, tensor_parallel_rank: {self.tensor_parallel_rank}, expert_parallel_rank: {self.expert_parallel_rank}, tp_group: {self.tp_group}."
         )
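For context, the hunk above now builds the expert-parallel group only when expert parallel is actually enabled, and resets the custom group id afterwards. A minimal sketch of the same guard using plain paddle.distributed calls (init_ep_group and its arguments are illustrative, not part of this commit):

import paddle.distributed as dist

def init_ep_group(enable_expert_parallel: bool, expert_parallel_size: int):
    # Build the expert-parallel communication group only when EP is enabled,
    # so ranks never hold (or all-reduce over) a group they do not use.
    if not enable_expert_parallel or expert_parallel_size <= 1:
        return None
    return dist.new_group(ranks=list(range(expert_parallel_size)))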
@@ -1339,7 +1343,7 @@ class FDConfig:
         )
         if self.scheduler_config is not None:
             self.scheduler_config.check()
         if int(envs.ENABLE_V1_KVCACHE_SCHEDULER) == 1:
             assert (
                 int(envs.FD_DISABLED_RECOVER) == 0


@@ -120,6 +120,7 @@ class ResourceManagerV1(ResourceManager):
                 self._free_blocks(preempted_req)
                 preempted_req.cached_block_num = 0
                 self.to_be_rescheduled_request_id_set.add(preempted_req.request_id)
+                llm_logger.info(f"Preemption is triggered! Preempted request id: {preempted_req.request_id}")
                 preempted_reqs.append(preempted_req)
                 scheduled_reqs.append(self._prepare_preempt_task(preempted_req))
                 main_process_metrics.num_requests_waiting.inc(1)


@@ -16,12 +16,12 @@
 import inspect
 import os
-import threading
 import time
 import traceback
 import uuid
 
 import numpy as np
+from filelock import FileLock
 
 from fastdeploy import envs
 from fastdeploy.config import ModelConfig
@@ -132,7 +132,7 @@ class EngineClient:
             pid, max_connections=int(os.getenv("FD_DEALER_CONNECTIONS", 50))
         )
         self.connection_initialized = False
-        self.clear_update_lock = threading.Lock()
+        self.clear_update_lock = FileLock(f"/tmp/fd_weight_clear_update_lock__pid{pid}_port{port}.lock")
 
     def create_zmq_client(self, model, mode):
         """
@@ -351,7 +351,9 @@ class EngineClient:
         if self.model_weights_status_signal.value[0] == ModelWeightsStatus.NORMAL:
             return True, ""
         if self.model_weights_status_signal.value[0] == ModelWeightsStatus.UPDATING:
-            return False, "updating model weight already"
+            return False, "worker is updating model weight already"
+        if self.model_weights_status_signal.value[0] == ModelWeightsStatus.CLEARING:
+            return False, "worker is clearing model weight, cannot update now"
 
         self.model_weights_status_signal.value[0] = ModelWeightsStatus.UPDATING
         if self.enable_prefix_caching or self.enable_splitwise:
@@ -395,7 +397,9 @@ class EngineClient:
         if self.model_weights_status_signal.value[0] == ModelWeightsStatus.CLEARED:
             return True, ""
         if self.model_weights_status_signal.value[0] == ModelWeightsStatus.CLEARING:
-            return False, "clearing model weight already"
+            return False, "worker is clearing model weight already"
+        if self.model_weights_status_signal.value[0] == ModelWeightsStatus.UPDATING:
+            return False, "worker is updating model weight, cannot clear now"
 
         self.model_weights_status_signal.value[0] = ModelWeightsStatus.CLEARING
         if self.enable_prefix_caching or self.enable_splitwise:
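The lock change matters because the API server can run with more than one worker process (see the commit message: "fix clear/update lock not working when workers > 1"): a threading.Lock only serializes threads inside a single process, so each worker would get its own independent lock, while a file-based lock is keyed by a path shared by all processes. Minimal sketch of the difference (the lock path here is illustrative only):

import threading
from filelock import FileLock

in_process_lock = threading.Lock()  # serializes threads of one process only
cross_process_lock = FileLock("/tmp/fd_weight_clear_update_example.lock")

with cross_process_lock:
    # only one worker process at a time can flip the clear/update
    # model-weights status signal while holding the file lock
    ...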


@@ -297,7 +297,7 @@ class CutlassMoEMethod(UnquantizedFusedMoEMethod):
         )
 
         if layer.reduce_results and layer.tp_size > 1:
-            tensor_model_parallel_all_reduce(fused_moe_out)
+            tensor_model_parallel_all_reduce(fused_moe_out, layer.fd_config.parallel_config.tp_group)
 
         return fused_moe_out
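Passing the TP group explicitly ensures the reduction runs over exactly the layer's tensor-parallel ranks instead of whatever default communicator the helper would otherwise fall back to. A minimal sketch of the idea with plain paddle.distributed (tp_all_reduce is an illustrative helper, not FastDeploy's tensor_model_parallel_all_reduce):

import paddle
import paddle.distributed as dist

def tp_all_reduce(tensor: paddle.Tensor, tp_group) -> paddle.Tensor:
    # Sum-reduce only across the tensor-parallel ranks in tp_group,
    # leaving ranks that belong to other groups untouched.
    dist.all_reduce(tensor, group=tp_group)
    return tensor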


@@ -220,23 +220,17 @@ class DynamicWeightManager:
         check model weights status
         """
         logger.info(f"dynamic weight manager is check model weights status! {model_weights_status.value[0]}")
-        is_stop = 0
         while model_weights_status.value[0] != ModelWeightsStatus.NORMAL:
             if model_weights_status.value[0] == ModelWeightsStatus.UPDATING:
                 logger.info("infer engine stopped! start to load new checkpoint...")
                 model_runner.update_parameters(pid)
+                while model_weights_status.value[0] != ModelWeightsStatus.NORMAL:
+                    time.sleep(0.01)
+                logger.info("finished loading new checkpoint")
             elif model_weights_status.value[0] == ModelWeightsStatus.CLEARING:
                 logger.info("infer engine stopped! start to clear checkpoint...")
                 model_runner.clear_parameters(pid)
-                while True:
-                    if model_weights_status.value[0] == ModelWeightsStatus.NORMAL:
-                        logger.info("finished loading new checkpoint")
-                        break
-                    elif is_stop == 1 or (model_weights_status.value[0] == ModelWeightsStatus.CLEARED and is_stop == 0):
-                        if is_stop == 0:
-                            logger.info("finished clearing checkpoint")
-                            is_stop = 1
-                        time.sleep(0.001)
-                        break
-                    else:
-                        time.sleep(0.001)
+                while model_weights_status.value[0] != ModelWeightsStatus.CLEARED:
+                    time.sleep(0.01)
+                logger.info("finished clearing checkpoint")
+            time.sleep(0.01)
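The rewritten loop drops the is_stop flag and simply polls the shared status until the expected terminal state is reached: NORMAL after an update, CLEARED after a clear. The waiting pattern, condensed (wait_for_status is an illustrative helper, not part of the commit):

import time

def wait_for_status(model_weights_status, target, poll_interval=0.01):
    # Block until the shared status value reaches the target state.
    while model_weights_status.value[0] != target:
        time.sleep(poll_interval)

# e.g. after model_runner.clear_parameters(pid):
#     wait_for_status(model_weights_status, ModelWeightsStatus.CLEARED)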


@@ -270,6 +270,11 @@ class PaddleDisWorkerProc:
             create=False,
         )
 
+    def _broadcast_model_weights_signal(self, src: int, group) -> int:
+        model_weights_signal_tensor = paddle.full(shape=[1], fill_value=self.model_weights_signal[0], dtype="int32")
+        paddle.distributed.broadcast(model_weights_signal_tensor, src=src, group=group)
+        return model_weights_signal_tensor.item()
+
     def event_loop_normal(self) -> None:
         """Main event loop for Paddle Distrubuted Workers.
         TODO(gongshaotian): support remote calling of functions that control worker.
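The helper is needed because model_weights_signal becomes a plain numpy array later in this commit, while paddle.distributed.broadcast operates on paddle tensors; building a fresh int32 tensor per call and returning the scalar keeps the shared numpy buffer as the single source of truth. A standalone sketch of the same idea (broadcast_signal is illustrative, not the method added above):

import numpy as np
import paddle
import paddle.distributed as dist

def broadcast_signal(signal: np.ndarray, src: int, group) -> int:
    # Wrap the local scalar in an int32 tensor, broadcast rank `src`'s value
    # over `group`, and return it so the caller can write it back into the
    # numpy signal array.
    tensor = paddle.full(shape=[1], fill_value=int(signal[0]), dtype="int32")
    dist.broadcast(tensor, src=src, group=group)
    return int(tensor.item())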
@@ -279,15 +284,19 @@ class PaddleDisWorkerProc:
         req_ids = []
         num_running_requests = 0
         local_rank = self.local_rank % self.parallel_config.tensor_parallel_size
-        self.model_weights_signal = paddle.zeros([1], dtype=paddle.int32)
+        self.model_weights_signal = np.zeros([1], dtype=np.int32)
         while True:
             if self.local_rank % self.parallel_config.tensor_parallel_size == 0:
                 if self.model_weights_status.value[0] != ModelWeightsStatus.NORMAL:
                     self.model_weights_signal[0] = int(self.model_weights_status.value[0])
             if self.fd_config.load_config.dynamic_load_weight and self.parallel_config.enable_expert_parallel:
-                paddle.distributed.broadcast(self.model_weights_signal, src=0, group=self.parallel_config.ep_group)
-            if self.fd_config.load_config.dynamic_load_weight:
-                paddle.distributed.broadcast(self.model_weights_signal, src=0, group=self.parallel_config.tp_group)
+                self.model_weights_signal[0] = self._broadcast_model_weights_signal(
+                    src=0, group=self.parallel_config.ep_group
+                )
+            if self.fd_config.load_config.dynamic_load_weight and self.parallel_config.tensor_parallel_size > 1:
+                self.model_weights_signal[0] = self._broadcast_model_weights_signal(
+                    src=0, group=self.parallel_config.tp_group
+                )
 
             self.insert_step = False
             req_dicts = None
@@ -315,7 +324,9 @@ class PaddleDisWorkerProc:
             else:
                 paddle.distributed.barrier(self.parallel_config.tp_group)
             if self.model_weights_signal[0] != ModelWeightsStatus.NORMAL:
-                logger.info(f"Rank: {self.local_rank} has updated parameters.")
+                logger.info(
+                    f"Rank: {self.local_rank} to update or clear parameters, signal is {self.model_weights_signal[0]}, [-1:clear, 1:update]"
+                )
                 from fastdeploy.rl.dynamic_weight_manager import (
                     DynamicWeightManager,
                 )
@@ -327,6 +338,7 @@ class PaddleDisWorkerProc:
                     self.parallel_config.engine_worker_queue_port,
                 )
                 self.model_weights_signal[0] = ModelWeightsStatus.NORMAL
+                logger.info(f"Rank: {self.local_rank} has updated or cleared parameters.")
 
             if self.exist_task_signal.value[0] == ExistTaskStatus.EXIST or self.task_queue.read_finish_flag.get() == 1:
                 logger.info(f"Rank: {self.local_rank} Detected new requests.")