[Bug fix] Fix zmq core bug (#3357)

* [Bug fix] Fix zmq core bug caused by concurrent use from multiple threads

* Fix zmq core bug caused by concurrent use from multiple threads
This commit is contained in:
chenjian
2025-08-13 20:24:39 +08:00
committed by GitHub
parent 7573802a88
commit 89177d881c
3 changed files with 20 additions and 17 deletions

View File

@@ -34,6 +34,7 @@ class InternalAdapter:
self.engine = engine
self.dp_rank = dp_rank
recv_control_cmd_ports = envs.FD_ZMQ_CONTROL_CMD_SERVER_PORTS.split(",")
self.response_lock = threading.Lock() # prevent to call send_multipart in zmq concurrently
self.recv_control_cmd_server = ZmqTcpServer(port=recv_control_cmd_ports[dp_rank], mode=zmq.ROUTER)
self.recv_external_instruct_thread = threading.Thread(
target=self._recv_external_module_control_instruct, daemon=True
@@ -43,7 +44,6 @@ class InternalAdapter:
target=self._response_external_module_control_instruct, daemon=True
)
self.response_external_instruct_thread.start()
self.response_lock = threading.Lock() # prevent to call send_multipart in zmq concurrently
def _get_current_server_info(self):
"""
@@ -71,13 +71,17 @@ class InternalAdapter:
"""
while True:
try:
task = self.recv_control_cmd_server.recv_control_cmd()
with self.response_lock:
task = self.recv_control_cmd_server.recv_control_cmd()
if task is None:
time.sleep(0.001)
continue
logger.info(f"Recieve control task: {task}")
task_id_str = task["task_id"]
if task["cmd"] == "get_payload":
payload_info = self._get_current_server_info()
result = {"task_id": task_id_str, "result": payload_info}
logger.info(f"Response for task: {task_id_str}")
logger.debug(f"Response for task: {task_id_str}")
with self.response_lock:
self.recv_control_cmd_server.response_for_control_cmd(task_id_str, result)
@@ -87,7 +91,7 @@ class InternalAdapter:
extra_register_func=lambda reg: main_process_metrics.register_all(reg, workers=1),
)
result = {"task_id": task_id_str, "result": metrics_text}
logger.info(f"Response for task: {task_id_str}")
logger.debug(f"Response for task: {task_id_str}")
with self.response_lock:
self.recv_control_cmd_server.response_for_control_cmd(task_id_str, result)
elif task["cmd"] == "connect_rdma":