Compare commits

...

5 Commits

Author          SHA1        Message                           Date
Nicolas Mowen   7a02a448cb  Add locking                       2025-09-22 13:59:30 -06:00
Nicolas Mowen   4ab8de91a9  Move ZMQ detector to onnx runner  2025-09-22 12:55:36 -06:00
Nicolas Mowen   fbcf64d7bd  Get working                       2025-09-22 12:32:11 -06:00
Nicolas Mowen   de960285f6  Get correct ports running         2025-09-22 12:00:45 -06:00
Nicolas Mowen   8b78c85bda  Implement broker/dealer router    2025-09-22 11:14:53 -06:00
5 changed files with 1117 additions and 199 deletions

View File

@@ -28,6 +28,9 @@ from frigate.comms.object_detector_signaler import DetectorProxy
from frigate.comms.webpush import WebPushClient
from frigate.comms.ws import WebSocketClient
from frigate.comms.zmq_proxy import ZmqProxy
from frigate.comms.zmq_req_router_broker import (
ZmqReqRouterBroker,
)
from frigate.config.camera.updater import CameraConfigUpdatePublisher
from frigate.config.config import FrigateConfig
from frigate.const import (
@@ -307,6 +310,14 @@ class FrigateApp:
self.event_metadata_updater = EventMetadataPublisher()
self.inter_zmq_proxy = ZmqProxy()
self.detection_proxy = DetectorProxy()
self.zmq_router_broker: ZmqReqRouterBroker | None = None
zmq_detectors = [
det for det in self.config.detectors.values() if det.type == "zmq"
]
if any(zmq_detectors):
backend_endpoint = zmq_detectors[0].endpoint
self.zmq_router_broker = ZmqReqRouterBroker(backend_endpoint)
def init_onvif(self) -> None:
self.onvif_controller = OnvifController(self.config, self.ptz_metrics)
@@ -644,6 +655,9 @@ class FrigateApp:
self.inter_zmq_proxy.stop()
self.detection_proxy.stop()
if self.zmq_router_broker:
self.zmq_router_broker.stop()
while len(self.detection_shms) > 0:
shm = self.detection_shms.pop()
shm.close()

View File

@@ -0,0 +1,61 @@
"""ZMQ REQ/ROUTER front-end to DEALER/REP back-end broker.
This module provides a small proxy that:
- Binds a ROUTER socket on a fixed local endpoint for REQ clients
- Binds a DEALER socket on the user-configured backend endpoint (REP workers connect to it)
Pattern: REQ -> ROUTER === proxy === DEALER -> REP
The goal is to let multiple REQ clients and/or multiple backend workers
share a single configured connection, enabling multiple models/runners
behind the same broker while local clients keep using a fixed local endpoint.
"""
from __future__ import annotations
import threading
import zmq
REQ_ROUTER_ENDPOINT = "ipc:///tmp/cache/zmq_detector_router"
class _RouterDealerRunner(threading.Thread):
def __init__(self, context: zmq.Context[zmq.Socket], backend_endpoint: str) -> None:
super().__init__(name="zmq_router_dealer_broker", daemon=True)
self.context = context
self.backend_endpoint = backend_endpoint
def run(self) -> None:
frontend = self.context.socket(zmq.ROUTER)
frontend.bind(REQ_ROUTER_ENDPOINT)
backend = self.context.socket(zmq.DEALER)
backend.bind(self.backend_endpoint)
try:
zmq.proxy(frontend, backend)
except zmq.ZMQError:
            # zmq.proxy() raises ETERM here when the owning broker destroys the context
pass
class ZmqReqRouterBroker:
"""Starts a ROUTER/DEALER proxy bridging local REQ clients to backend REP.
- ROUTER binds to REQ_ROUTER_ENDPOINT (constant, local)
    - DEALER binds to the provided backend_endpoint (user-configured); REP workers connect to it
"""
def __init__(self, backend_endpoint: str) -> None:
self.backend_endpoint = backend_endpoint
self.context = zmq.Context()
self.runner = _RouterDealerRunner(self.context, backend_endpoint)
self.runner.start()
def stop(self) -> None:
# Destroying the context signals the proxy to stop
try:
self.context.destroy()
finally:
self.runner.join()
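A minimal end-to-end sketch of the broker (assumptions: pyzmq is available, the ipc:///tmp/cache directory exists, and tcp://*:5556 is a hypothetical backend endpoint chosen only for this example). A throwaway REP worker echoes one request that travels REQ -> ROUTER -> DEALER -> REP and back:

import threading
import zmq
from frigate.comms.zmq_req_router_broker import REQ_ROUTER_ENDPOINT, ZmqReqRouterBroker

broker = ZmqReqRouterBroker("tcp://*:5556")  # DEALER binds the backend port

def echo_worker() -> None:
    rep = zmq.Context.instance().socket(zmq.REP)
    rep.connect("tcp://127.0.0.1:5556")  # workers connect to the DEALER side
    rep.send_multipart(rep.recv_multipart())  # echo one request back

threading.Thread(target=echo_worker, daemon=True).start()

req = zmq.Context.instance().socket(zmq.REQ)
req.connect(REQ_ROUTER_ENDPOINT)  # clients always use the fixed local IPC endpoint
req.send_multipart([b"ping"])
print(req.recv_multipart())  # [b'ping']
broker.stop()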

View File

@@ -1,13 +1,17 @@
"""Base runner implementation for ONNX models."""
import json
import logging
import os
import threading
from abc import ABC, abstractmethod
from typing import Any
import numpy as np
import onnxruntime as ort
import zmq
from frigate.comms.zmq_req_router_broker import REQ_ROUTER_ENDPOINT
from frigate.util.model import get_ort_providers
from frigate.util.rknn_converter import auto_convert_model, is_rknn_compatible
@@ -301,6 +305,187 @@ class OpenVINOModelRunner(BaseModelRunner):
return outputs
class ZmqIpcRunner(BaseModelRunner):
"""Runner that forwards inference over ZMQ REQ/ROUTER to backend workers.
This allows reusing the same interface as local runners while delegating
inference to the external ZMQ workers.
"""
def __init__(
self,
model_path: str,
model_type: str,
request_timeout_ms: int = 200,
linger_ms: int = 0,
endpoint: str = REQ_ROUTER_ENDPOINT,
):
self.model_type = model_type
self.model_path = model_path
self.model_name = os.path.basename(model_path)
self._endpoint = endpoint
self._context = zmq.Context()
self._socket = self._context.socket(zmq.REQ)
self._socket.setsockopt(zmq.RCVTIMEO, request_timeout_ms)
self._socket.setsockopt(zmq.SNDTIMEO, request_timeout_ms)
self._socket.setsockopt(zmq.LINGER, linger_ms)
self._socket.connect(self._endpoint)
self._model_ready = False
self._io_lock = threading.Lock()
@staticmethod
def is_complex_model(model_type: str) -> bool:
# Import here to avoid circular imports
from frigate.detectors.detector_config import ModelTypeEnum
from frigate.embeddings.types import EnrichmentModelTypeEnum
return model_type in [
ModelTypeEnum.yolonas.value,
EnrichmentModelTypeEnum.paddleocr.value,
EnrichmentModelTypeEnum.jina_v1.value,
EnrichmentModelTypeEnum.jina_v2.value,
]
def get_input_names(self) -> list[str]:
if "vision" in self.model_name:
return ["pixel_values"]
elif "arcface" in self.model_name:
return ["data"]
else:
return ["input"]
def get_input_width(self) -> int:
# Not known/required for ZMQ forwarding
return -1
def _build_header(self, tensor_input: np.ndarray) -> bytes:
header: dict[str, object] = {
"shape": list(tensor_input.shape),
"dtype": str(tensor_input.dtype.name),
"model_type": self.model_type,
"model_name": self.model_name,
}
return json.dumps(header).encode("utf-8")
def _decode_response(self, frames: list[bytes]) -> np.ndarray:
if len(frames) == 1:
buf = frames[0]
if len(buf) != 20 * 6 * 4:
raise ValueError(f"Unexpected payload size: {len(buf)}")
return np.frombuffer(buf, dtype=np.float32).reshape((20, 6))
if len(frames) >= 2:
header = json.loads(frames[0].decode("utf-8"))
shape = tuple(header.get("shape", []))
dtype = np.dtype(header.get("dtype", "float32"))
return np.frombuffer(frames[1], dtype=dtype).reshape(shape)
raise ValueError("Empty or malformed reply from ZMQ detector")
def run(self, input: dict[str, np.ndarray]) -> np.ndarray | None:
if not self._model_ready:
if not self.ensure_model_ready(self.model_path):
raise TimeoutError("ZMQ detector model is not ready after transfer")
self._model_ready = True
input_name = next(iter(input))
tensor = input[input_name]
header = self._build_header(tensor)
payload = memoryview(tensor.tobytes(order="C"))
try:
with self._io_lock:
self._socket.send_multipart([header, payload])
frames = self._socket.recv_multipart()
except zmq.Again as e:
raise TimeoutError("ZMQ detector request timed out") from e
except zmq.ZMQError as e:
raise RuntimeError(f"ZMQ error: {e}") from e
return self._decode_response(frames)
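    # Wire-format sketch (illustrative values; "yolov9.onnx" is a made-up name):
    #   request = [header_json, tensor_bytes], where header_json decodes to
    #             {"shape": [1, 3, 320, 320], "dtype": "uint8",
    #              "model_type": "yolo-generic", "model_name": "yolov9.onnx"}
    #             and tensor_bytes are the raw C-ordered tensor bytes
    #   reply   = either a single bare (20, 6) float32 buffer, or
    #             [header_json, result_bytes] with shape/dtype in the header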
def ensure_model_ready(self, model_path: str) -> bool:
"""Ensure the remote has the model and it is loaded.
1) Send model_request with model_name
2) If not available, send model_data with the file contents
3) Wait for loaded confirmation
Returns True on success.
"""
# Check model availability
req = {"model_request": True, "model_name": self.model_name}
with self._io_lock:
self._socket.send_multipart([json.dumps(req).encode("utf-8")])
# Temporarily extend timeout for model ops
original_rcv = self._socket.getsockopt(zmq.RCVTIMEO)
try:
self._socket.setsockopt(zmq.RCVTIMEO, max(30000, int(original_rcv or 0)))
resp_frames = self._socket.recv_multipart()
except zmq.Again:
self._socket.setsockopt(zmq.RCVTIMEO, original_rcv)
return False
finally:
try:
self._socket.setsockopt(zmq.RCVTIMEO, original_rcv)
except Exception:
pass
try:
if len(resp_frames) != 1:
return False
resp = json.loads(resp_frames[0].decode("utf-8"))
if resp.get("model_available") and resp.get("model_loaded"):
logger.info(f"ZMQ detector model {self.model_name} is ready")
return True
except Exception:
return False
try:
with open(model_path, "rb") as f:
model_bytes = f.read()
except Exception:
return False
header = {"model_data": True, "model_name": self.model_name}
with self._io_lock:
self._socket.send_multipart(
[json.dumps(header).encode("utf-8"), model_bytes]
)
original_rcv2 = self._socket.getsockopt(zmq.RCVTIMEO)
try:
self._socket.setsockopt(zmq.RCVTIMEO, max(30000, int(original_rcv2 or 0)))
resp2 = self._socket.recv_multipart()
except zmq.Again:
self._socket.setsockopt(zmq.RCVTIMEO, original_rcv2)
return False
finally:
try:
self._socket.setsockopt(zmq.RCVTIMEO, original_rcv2)
except Exception:
pass
try:
if len(resp2) != 1:
return False
j = json.loads(resp2[0].decode("utf-8"))
return bool(j.get("model_saved") and j.get("model_loaded"))
except Exception:
return False
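    # Handshake sketch (illustrative JSON; "yolov9.onnx" is a made-up name):
    #   -> [{"model_request": true, "model_name": "yolov9.onnx"}]
    #   <- [{"model_available": false, "model_name": "yolov9.onnx", "message": "..."}]
    # and only if the worker does not already have the file:
    #   -> [{"model_data": true, "model_name": "yolov9.onnx"}, <model file bytes>]
    #   <- [{"model_saved": true, "model_loaded": true, "model_name": "yolov9.onnx", ...}]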
def __del__(self) -> None:
try:
if self._socket is not None:
self._socket.close()
except Exception:
pass
try:
if self._context is not None:
self._context.term()
except Exception:
pass
class RKNNModelRunner(BaseModelRunner):
"""Run RKNN models for embeddings."""
@@ -415,6 +600,10 @@ def get_optimized_runner(
if rknn_path:
return RKNNModelRunner(rknn_path)
if device == "ZMQ" and not ZmqIpcRunner.is_complex_model(model_type):
logger.info(f"Using ZMQ detector model {model_path}")
return ZmqIpcRunner(model_path, model_type, **kwargs)
providers, options = get_ort_providers(device == "CPU", device, **kwargs)
if providers[0] == "CPUExecutionProvider":

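A minimal sketch of driving ZmqIpcRunner directly, assuming the ROUTER/DEALER broker and an external ONNX worker are already running; the model path and file name are hypothetical, and in Frigate the runner is normally obtained through get_optimized_runner with device == "ZMQ":

import numpy as np
from frigate.detectors.detection_runners import ZmqIpcRunner

runner = ZmqIpcRunner(
    model_path="/config/model_cache/yolov9.onnx",  # hypothetical path; transferred to the worker on first use
    model_type="yolo-generic",
    request_timeout_ms=200,
)

frame = np.zeros((1, 3, 320, 320), dtype=np.uint8)  # dummy NCHW tensor
detections = runner.run({"input": frame})  # (20, 6) float32 detections from the worker
print(detections.shape)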
View File

@@ -8,7 +8,9 @@ import zmq
from pydantic import Field
from typing_extensions import Literal
from frigate.comms.zmq_req_router_broker import REQ_ROUTER_ENDPOINT
from frigate.detectors.detection_api import DetectionApi
from frigate.detectors.detection_runners import ZmqIpcRunner
from frigate.detectors.detector_config import BaseDetectorConfig
logger = logging.getLogger(__name__)
@@ -49,9 +51,7 @@ class ZmqIpcDetector(DetectionApi):
On any error or timeout, this detector returns a zero array of shape (20, 6).
Model Management:
- On initialization, sends model request to check if model is available
- If model not available, sends model data via ZMQ
- Only starts inference after model is ready
- Model transfer/availability is handled by the runner automatically
"""
type_key = DETECTOR_KEY
@@ -60,21 +60,23 @@ class ZmqIpcDetector(DetectionApi):
super().__init__(detector_config)
self._context = zmq.Context()
self._endpoint = detector_config.endpoint
self._endpoint = REQ_ROUTER_ENDPOINT
self._request_timeout_ms = detector_config.request_timeout_ms
self._linger_ms = detector_config.linger_ms
self._socket = None
self._create_socket()
# Model management
self._model_ready = False
self._model_name = self._get_model_name()
# Initialize model if needed
self._initialize_model()
# Preallocate zero result for error paths
self._zero_result = np.zeros((20, 6), np.float32)
self._runner = ZmqIpcRunner(
model_path=self.detector_config.model.path,
model_type=str(self.detector_config.model.model_type.value),
request_timeout_ms=self._request_timeout_ms,
linger_ms=self._linger_ms,
endpoint=self._endpoint,
)
def _create_socket(self) -> None:
if self._socket is not None:
@@ -96,167 +98,12 @@ class ZmqIpcDetector(DetectionApi):
model_path = self.detector_config.model.path
return os.path.basename(model_path)
def _initialize_model(self) -> None:
"""Initialize the model by checking availability and transferring if needed."""
try:
logger.info(f"Initializing model: {self._model_name}")
# Check if model is available and transfer if needed
if self._check_and_transfer_model():
logger.info(f"Model {self._model_name} is ready")
self._model_ready = True
else:
logger.error(f"Failed to initialize model {self._model_name}")
except Exception as e:
logger.error(f"Failed to initialize model: {e}")
def _check_and_transfer_model(self) -> bool:
"""Check if model is available and transfer if needed in one atomic operation."""
try:
# Send model availability request
header = {"model_request": True, "model_name": self._model_name}
header_bytes = json.dumps(header).encode("utf-8")
self._socket.send_multipart([header_bytes])
# Temporarily increase timeout for model operations
original_timeout = self._socket.getsockopt(zmq.RCVTIMEO)
self._socket.setsockopt(zmq.RCVTIMEO, 30000)
try:
response_frames = self._socket.recv_multipart()
finally:
self._socket.setsockopt(zmq.RCVTIMEO, original_timeout)
if len(response_frames) == 1:
try:
response = json.loads(response_frames[0].decode("utf-8"))
model_available = response.get("model_available", False)
model_loaded = response.get("model_loaded", False)
if model_available and model_loaded:
return True
elif model_available and not model_loaded:
logger.error("Model exists but failed to load")
return False
else:
return self._send_model_data()
except json.JSONDecodeError:
logger.warning(
"Received non-JSON response for model availability check"
)
return False
else:
logger.warning(
"Received unexpected response format for model availability check"
)
return False
except Exception as e:
logger.error(f"Failed to check and transfer model: {e}")
return False
def _check_model_availability(self) -> bool:
"""Check if the model is available on the detector."""
try:
# Send model availability request
header = {"model_request": True, "model_name": self._model_name}
header_bytes = json.dumps(header).encode("utf-8")
self._socket.send_multipart([header_bytes])
# Receive response
response_frames = self._socket.recv_multipart()
# Check if this is a JSON response (model management)
if len(response_frames) == 1:
try:
response = json.loads(response_frames[0].decode("utf-8"))
model_available = response.get("model_available", False)
model_loaded = response.get("model_loaded", False)
logger.debug(
f"Model availability check: available={model_available}, loaded={model_loaded}"
)
return model_available and model_loaded
except json.JSONDecodeError:
logger.warning(
"Received non-JSON response for model availability check"
)
return False
else:
logger.warning(
"Received unexpected response format for model availability check"
)
return False
except Exception as e:
logger.error(f"Failed to check model availability: {e}")
return False
def _send_model_data(self) -> bool:
"""Send model data to the detector."""
try:
model_path = self.detector_config.model.path
if not os.path.exists(model_path):
logger.error(f"Model file not found: {model_path}")
return False
logger.info(f"Transferring model to detector: {self._model_name}")
with open(model_path, "rb") as f:
model_data = f.read()
header = {"model_data": True, "model_name": self._model_name}
header_bytes = json.dumps(header).encode("utf-8")
self._socket.send_multipart([header_bytes, model_data])
# Temporarily increase timeout for model loading (can take several seconds)
original_timeout = self._socket.getsockopt(zmq.RCVTIMEO)
self._socket.setsockopt(zmq.RCVTIMEO, 30000)
try:
# Receive response
response_frames = self._socket.recv_multipart()
finally:
# Restore original timeout
self._socket.setsockopt(zmq.RCVTIMEO, original_timeout)
# Check if this is a JSON response (model management)
if len(response_frames) == 1:
try:
response = json.loads(response_frames[0].decode("utf-8"))
model_saved = response.get("model_saved", False)
model_loaded = response.get("model_loaded", False)
if model_saved and model_loaded:
logger.info(
f"Model {self._model_name} transferred and loaded successfully"
)
else:
logger.error(
f"Model transfer failed: saved={model_saved}, loaded={model_loaded}"
)
return model_saved and model_loaded
except json.JSONDecodeError:
logger.warning("Received non-JSON response for model data transfer")
return False
else:
logger.warning(
"Received unexpected response format for model data transfer"
)
return False
except Exception as e:
logger.error(f"Failed to send model data: {e}")
return False
def _build_header(self, tensor_input: np.ndarray) -> bytes:
header: dict[str, Any] = {
"shape": list(tensor_input.shape),
"dtype": str(tensor_input.dtype.name),
"model_type": str(self.detector_config.model.model_type.name),
"model_name": self._model_name,
}
return json.dumps(header).encode("utf-8")
@@ -285,42 +132,11 @@ class ZmqIpcDetector(DetectionApi):
return self._zero_result
def detect_raw(self, tensor_input: np.ndarray) -> np.ndarray:
if not self._model_ready:
logger.warning("Model not ready, returning zero detections")
return self._zero_result
try:
header_bytes = self._build_header(tensor_input)
payload_bytes = memoryview(tensor_input.tobytes(order="C"))
# Send request
self._socket.send_multipart([header_bytes, payload_bytes])
# Receive reply
reply_frames = self._socket.recv_multipart()
detections = self._decode_response(reply_frames)
# Ensure output shape and dtype are exactly as expected
return detections
except zmq.Again:
# Timeout
logger.debug("ZMQ detector request timed out; resetting socket")
try:
self._create_socket()
self._initialize_model()
except Exception:
pass
return self._zero_result
except zmq.ZMQError as exc:
logger.error(f"ZMQ detector ZMQError: {exc}; resetting socket")
try:
self._create_socket()
self._initialize_model()
except Exception:
pass
return self._zero_result
result = self._runner.run({"input": tensor_input})
return result if isinstance(result, np.ndarray) else self._zero_result
except Exception as exc: # noqa: BLE001
logger.error(f"ZMQ detector unexpected error: {exc}")
logger.error(f"ZMQ IPC runner error: {exc}")
return self._zero_result
def __del__(self) -> None: # pragma: no cover - best-effort cleanup

View File

@@ -0,0 +1,838 @@
#!/usr/bin/env python3
"""
ZMQ TCP ONNX Runtime Client
This client connects to the ZMQ TCP proxy, accepts tensor inputs,
runs inference via ONNX Runtime, and returns detection results.
Protocol:
- Receives multipart messages: [header_json_bytes, tensor_bytes]
- Header contains shape and dtype information
- Runs ONNX inference on the tensor
- Returns results in the expected format: [20, 6] float32 array
Note: Timeouts are normal when Frigate has no motion to detect.
The server will continue running and waiting for requests.
"""
import json
import logging
import os
import threading
import time
from typing import Dict, List, Optional, Tuple
import numpy as np
import onnxruntime as ort
import zmq
from model_util import post_process_dfine, post_process_rfdetr, post_process_yolo
# Configure logging
logging.basicConfig(
level=logging.INFO, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s"
)
logger = logging.getLogger(__name__)
class ZmqOnnxWorker(threading.Thread):
"""
A worker thread that connects a REP socket to the endpoint and processes
requests using a shared model session map. This mirrors the single-runner
    logic, but the ONNX Runtime session is fetched from the shared map and
    created on demand if missing.
"""
def __init__(
self,
worker_id: int,
context: zmq.Context,
endpoint: str,
models_dir: str,
model_sessions: Dict[str, ort.InferenceSession],
model_lock: threading.Lock,
providers: Optional[List[str]],
zero_result: np.ndarray,
) -> None:
super().__init__(name=f"onnx_worker_{worker_id}", daemon=True)
self.worker_id = worker_id
self.context = context
self.endpoint = self._normalize_endpoint(endpoint)
self.models_dir = models_dir
self.model_sessions = model_sessions
self.model_lock = model_lock
self.providers = providers
self.zero_result = zero_result
self.socket: Optional[zmq.Socket] = None
def _normalize_endpoint(self, endpoint: str) -> str:
if endpoint.startswith("tcp://*:"):
port = endpoint.split(":", 2)[-1]
return f"tcp://127.0.0.1:{port}"
return endpoint
# --- ZMQ helpers ---
def _create_socket(self) -> zmq.Socket:
sock = self.context.socket(zmq.REP)
sock.setsockopt(zmq.RCVTIMEO, 5000)
sock.setsockopt(zmq.SNDTIMEO, 5000)
sock.setsockopt(zmq.LINGER, 0)
sock.connect(self.endpoint)
return sock
def _decode_request(self, frames: List[bytes]) -> Tuple[Optional[np.ndarray], dict]:
if len(frames) < 1:
raise ValueError(f"Expected at least 1 frame, got {len(frames)}")
header_bytes = frames[0]
header = json.loads(header_bytes.decode("utf-8"))
if "model_request" in header:
return None, header
if "model_data" in header:
return None, header
if len(frames) < 2:
raise ValueError(f"Tensor request expected 2 frames, got {len(frames)}")
tensor_bytes = frames[1]
shape = tuple(header.get("shape", []))
dtype_str = header.get("dtype", "uint8")
dtype = np.dtype(dtype_str)
tensor = np.frombuffer(tensor_bytes, dtype=dtype).reshape(shape)
return tensor, header
def _build_response(self, result: np.ndarray) -> List[bytes]:
header = {
"shape": list(result.shape),
"dtype": str(result.dtype.name),
"timestamp": time.time(),
}
return [json.dumps(header).encode("utf-8"), result.tobytes(order="C")]
def _build_error_response(self, error_msg: str) -> List[bytes]:
error_header = {"shape": [20, 6], "dtype": "float32", "error": error_msg}
return [
json.dumps(error_header).encode("utf-8"),
self.zero_result.tobytes(order="C"),
]
# --- Model/session helpers ---
def _check_model_exists(self, model_name: str) -> bool:
return os.path.exists(os.path.join(self.models_dir, model_name))
def _save_model(self, model_name: str, model_data: bytes) -> bool:
try:
os.makedirs(self.models_dir, exist_ok=True)
with open(os.path.join(self.models_dir, model_name), "wb") as f:
f.write(model_data)
return True
except Exception as e:
logger.error(
f"Worker {self.worker_id} failed to save model {model_name}: {e}"
)
return False
def _get_or_create_session(self, model_name: str) -> Optional[ort.InferenceSession]:
with self.model_lock:
session = self.model_sessions.get(model_name)
if session is not None:
return session
try:
providers = self.providers or ["CoreMLExecutionProvider"]
session = ort.InferenceSession(
os.path.join(self.models_dir, model_name), providers=providers
)
self.model_sessions[model_name] = session
return session
except Exception as e:
logger.error(
f"Worker {self.worker_id} failed to load model {model_name}: {e}"
)
return None
# --- Inference helpers ---
def _extract_input_hw(self, header: dict) -> Tuple[int, int]:
try:
if "width" in header and "height" in header:
return int(header["width"]), int(header["height"])
shape = tuple(header.get("shape", []))
layout = header.get("layout") or header.get("order")
if layout and shape:
layout = str(layout).upper()
if len(shape) == 4:
if layout == "NCHW":
return int(shape[3]), int(shape[2])
if layout == "NHWC":
return int(shape[2]), int(shape[1])
if len(shape) == 3:
if layout == "CHW":
return int(shape[2]), int(shape[1])
if layout == "HWC":
return int(shape[1]), int(shape[0])
if shape:
if len(shape) == 4:
_, d1, d2, d3 = shape
if d1 in (1, 3):
return int(d3), int(d2)
if d3 in (1, 3):
return int(d2), int(d1)
return int(d2), int(d1)
if len(shape) == 3:
d0, d1, d2 = shape
if d0 in (1, 3):
return int(d2), int(d1)
if d2 in (1, 3):
return int(d1), int(d0)
return int(d1), int(d0)
if len(shape) == 2:
h, w = shape
return int(w), int(h)
except Exception:
pass
return 320, 320
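    # Heuristic examples (illustrative): shape (1, 3, 640, 640) or layout "NCHW"
    # -> (640, 640); shape (320, 320, 3) or layout "HWC" -> (320, 320); anything
    # unrecognized falls back to (320, 320).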
def _run_inference(
self, session: ort.InferenceSession, tensor: np.ndarray, header: dict
) -> np.ndarray:
try:
model_type = header.get("model_type")
width, height = self._extract_input_hw(header)
if model_type == "dfine":
input_data = {
"images": tensor.astype(np.float32),
"orig_target_sizes": np.array([[height, width]], dtype=np.int64),
}
else:
input_name = session.get_inputs()[0].name
input_data = {input_name: tensor}
if logger.isEnabledFor(logging.DEBUG):
t_start = time.perf_counter()
outputs = session.run(None, input_data)
if logger.isEnabledFor(logging.DEBUG):
t_after_onnx = time.perf_counter()
if model_type == "yolo-generic" or model_type == "yologeneric":
result = post_process_yolo(outputs, width, height)
elif model_type == "dfine":
result = post_process_dfine(outputs, width, height)
elif model_type == "rfdetr":
result = post_process_rfdetr(outputs)
else:
result = np.zeros((20, 6), dtype=np.float32)
if logger.isEnabledFor(logging.DEBUG):
t_after_post = time.perf_counter()
onnx_ms = (t_after_onnx - t_start) * 1000.0
post_ms = (t_after_post - t_after_onnx) * 1000.0
total_ms = (t_after_post - t_start) * 1000.0
logger.debug(
f"Worker {self.worker_id} timing: onnx={onnx_ms:.2f}ms, post={post_ms:.2f}ms, total={total_ms:.2f}ms"
)
return result.astype(np.float32)
except Exception as e:
logger.error(f"Worker {self.worker_id} ONNX inference failed: {e}")
return self.zero_result
# --- Message handlers ---
def _handle_model_request(self, header: dict) -> List[bytes]:
model_name = header.get("model_name")
if not model_name:
return self._build_error_response("Model request missing model_name")
if self._check_model_exists(model_name):
# Ensure session exists
if self._get_or_create_session(model_name) is not None:
response_header = {
"model_available": True,
"model_loaded": True,
"model_name": model_name,
"message": f"Model {model_name} loaded successfully",
}
else:
response_header = {
"model_available": True,
"model_loaded": False,
"model_name": model_name,
"message": f"Model {model_name} exists but failed to load",
}
else:
response_header = {
"model_available": False,
"model_name": model_name,
"message": f"Model {model_name} not found, please send model data",
}
return [json.dumps(response_header).encode("utf-8")]
def _handle_model_data(self, header: dict, model_data: bytes) -> List[bytes]:
model_name = header.get("model_name")
if not model_name:
return self._build_error_response("Model data missing model_name")
if self._save_model(model_name, model_data):
# Ensure session is created
if self._get_or_create_session(model_name) is not None:
response_header = {
"model_saved": True,
"model_loaded": True,
"model_name": model_name,
"message": f"Model {model_name} saved and loaded successfully",
}
else:
response_header = {
"model_saved": True,
"model_loaded": False,
"model_name": model_name,
"message": f"Model {model_name} saved but failed to load",
}
else:
response_header = {
"model_saved": False,
"model_loaded": False,
"model_name": model_name,
"message": f"Failed to save model {model_name}",
}
return [json.dumps(response_header).encode("utf-8")]
# --- Thread run ---
def run(self) -> None: # pragma: no cover - runtime loop
try:
self.socket = self._create_socket()
logger.info(
f"Worker {self.worker_id} connected REP to endpoint: {self.endpoint}"
)
while True:
try:
frames = self.socket.recv_multipart()
tensor, header = self._decode_request(frames)
if "model_request" in header:
response = self._handle_model_request(header)
self.socket.send_multipart(response)
continue
if "model_data" in header and len(frames) >= 2:
model_data = frames[1]
response = self._handle_model_data(header, model_data)
self.socket.send_multipart(response)
continue
if tensor is not None:
model_name = header.get("model_name")
session = None
if model_name:
session = self._get_or_create_session(model_name)
if session is None:
result = self.zero_result
else:
result = self._run_inference(session, tensor, header)
self.socket.send_multipart(self._build_response(result))
continue
# Unknown message: reply with zeros
self.socket.send_multipart(self._build_response(self.zero_result))
except zmq.Again:
continue
except zmq.ZMQError as e:
logger.error(f"Worker {self.worker_id} ZMQ error: {e}")
# Recreate socket on transient errors
try:
if self.socket:
self.socket.close(linger=0)
finally:
self.socket = self._create_socket()
except Exception as e:
logger.error(f"Worker {self.worker_id} unexpected error: {e}")
try:
self.socket.send_multipart(self._build_error_response(str(e)))
except Exception:
pass
finally:
try:
if self.socket:
self.socket.close(linger=0)
except Exception:
pass
class ZmqOnnxClient:
"""
ZMQ TCP client that runs ONNX inference on received tensors.
"""
def __init__(
self,
endpoint: str = "tcp://*:5555",
model_path: Optional[str] = "AUTO",
providers: Optional[List[str]] = None,
session_options: Optional[ort.SessionOptions] = None,
):
"""
Initialize the ZMQ ONNX client.
Args:
            endpoint: ZMQ TCP backend endpoint of the broker's DEALER socket (REP workers connect to it)
model_path: Path to ONNX model file or "AUTO" for automatic model management
providers: ONNX Runtime execution providers
session_options: ONNX Runtime session options
"""
self.endpoint = endpoint
self.model_path = model_path
self.current_model = None
self.model_ready = False
self.models_dir = os.path.join(
os.path.dirname(os.path.dirname(__file__)), "models"
)
# Shared ZMQ context and shared session map across workers
self.context = zmq.Context()
self.model_sessions: Dict[str, ort.InferenceSession] = {}
self.model_lock = threading.Lock()
self.providers = providers
# Preallocate zero result for error cases
        self.zero_result = np.zeros((20, 6), dtype=np.float32)
        # Legacy single-socket attributes; kept only so cleanup() and the unused
        # compatibility methods below do not raise AttributeError
        self.socket: Optional[zmq.Socket] = None
        self.session: Optional[ort.InferenceSession] = None
logger.info(f"ZMQ ONNX client will start workers on endpoint: {endpoint}")
def start_server(self, num_workers: int = 4) -> None:
workers: list[ZmqOnnxWorker] = []
for i in range(num_workers):
w = ZmqOnnxWorker(
worker_id=i,
context=self.context,
endpoint=self.endpoint,
models_dir=self.models_dir,
model_sessions=self.model_sessions,
model_lock=self.model_lock,
providers=self.providers,
zero_result=self.zero_result,
)
w.start()
workers.append(w)
logger.info(f"Started {num_workers} ZMQ REP workers on backend {self.endpoint}")
try:
for w in workers:
w.join()
except KeyboardInterrupt:
logger.info("Shutting down workers...")
def _check_model_exists(self, model_name: str) -> bool:
"""
Check if a model exists in the models directory.
Args:
model_name: Name of the model file to check
Returns:
True if model exists, False otherwise
"""
model_path = os.path.join(self.models_dir, model_name)
return os.path.exists(model_path)
# These methods remain for compatibility but are unused in worker mode
def _save_model(self, model_name: str, model_data: bytes) -> bool:
"""
Save model data to the models directory.
Args:
model_name: Name of the model file to save
model_data: Binary model data
Returns:
True if model saved successfully, False otherwise
"""
try:
# Ensure models directory exists
os.makedirs(self.models_dir, exist_ok=True)
model_path = os.path.join(self.models_dir, model_name)
logger.info(f"Saving model to: {model_path}")
with open(model_path, "wb") as f:
f.write(model_data)
logger.info(f"Model saved successfully: {model_name}")
return True
except Exception as e:
logger.error(f"Failed to save model {model_name}: {e}")
return False
def _decode_request(self, frames: List[bytes]) -> Tuple[np.ndarray, dict]:
"""
Decode the incoming request frames.
Args:
frames: List of message frames
Returns:
Tuple of (tensor, header_dict)
"""
try:
if len(frames) < 1:
raise ValueError(f"Expected at least 1 frame, got {len(frames)}")
# Parse header
header_bytes = frames[0]
header = json.loads(header_bytes.decode("utf-8"))
if "model_request" in header:
return None, header
if "model_data" in header:
if len(frames) < 2:
raise ValueError(
f"Model data request expected 2 frames, got {len(frames)}"
)
return None, header
if len(frames) < 2:
raise ValueError(f"Tensor request expected 2 frames, got {len(frames)}")
tensor_bytes = frames[1]
shape = tuple(header.get("shape", []))
dtype_str = header.get("dtype", "uint8")
dtype = np.dtype(dtype_str)
tensor = np.frombuffer(tensor_bytes, dtype=dtype).reshape(shape)
return tensor, header
except Exception as e:
logger.error(f"Failed to decode request: {e}")
raise
def _run_inference(self, tensor: np.ndarray, header: dict) -> np.ndarray:
"""
Run ONNX inference on the input tensor.
Args:
tensor: Input tensor
header: Request header containing metadata (e.g., shape, layout)
Returns:
Detection results as numpy array
Raises:
RuntimeError: If no ONNX session is available or inference fails
"""
if self.session is None:
logger.warning("No ONNX session available, returning zero results")
return self.zero_result
try:
# Prepare input for ONNX Runtime
# Determine input spatial size (W, H) from header/shape/layout
model_type = header.get("model_type")
width, height = self._extract_input_hw(header)
if model_type == "dfine":
# DFine model requires both images and orig_target_sizes inputs
input_data = {
"images": tensor.astype(np.float32),
"orig_target_sizes": np.array([[height, width]], dtype=np.int64),
}
else:
# Other models use single input
input_name = self.session.get_inputs()[0].name
input_data = {input_name: tensor}
# Run inference
if logger.isEnabledFor(logging.DEBUG):
t_start = time.perf_counter()
outputs = self.session.run(None, input_data)
if logger.isEnabledFor(logging.DEBUG):
t_after_onnx = time.perf_counter()
if model_type == "yolo-generic" or model_type == "yologeneric":
result = post_process_yolo(outputs, width, height)
elif model_type == "dfine":
result = post_process_dfine(outputs, width, height)
elif model_type == "rfdetr":
result = post_process_rfdetr(outputs)
if logger.isEnabledFor(logging.DEBUG):
t_after_post = time.perf_counter()
onnx_ms = (t_after_onnx - t_start) * 1000.0
post_ms = (t_after_post - t_after_onnx) * 1000.0
total_ms = (t_after_post - t_start) * 1000.0
logger.debug(
f"Inference timing: onnx={onnx_ms:.2f}ms, post={post_ms:.2f}ms, total={total_ms:.2f}ms"
)
# Ensure float32 dtype
result = result.astype(np.float32)
return result
except Exception as e:
logger.error(f"ONNX inference failed: {e}")
return self.zero_result
def _extract_input_hw(self, header: dict) -> Tuple[int, int]:
"""
Extract (width, height) from the header and/or tensor shape, supporting
NHWC/NCHW as well as 3D/4D inputs. Falls back to 320x320 if unknown.
Preference order:
1) Explicit header keys: width/height
2) Use provided layout to interpret shape
3) Heuristics on shape
"""
try:
if "width" in header and "height" in header:
return int(header["width"]), int(header["height"])
shape = tuple(header.get("shape", []))
layout = header.get("layout") or header.get("order")
if layout and shape:
layout = str(layout).upper()
if len(shape) == 4:
if layout == "NCHW":
return int(shape[3]), int(shape[2])
if layout == "NHWC":
return int(shape[2]), int(shape[1])
if len(shape) == 3:
if layout == "CHW":
return int(shape[2]), int(shape[1])
if layout == "HWC":
return int(shape[1]), int(shape[0])
if shape:
if len(shape) == 4:
_, d1, d2, d3 = shape
if d1 in (1, 3):
return int(d3), int(d2)
if d3 in (1, 3):
return int(d2), int(d1)
return int(d2), int(d1)
if len(shape) == 3:
d0, d1, d2 = shape
if d0 in (1, 3):
return int(d2), int(d1)
if d2 in (1, 3):
return int(d1), int(d0)
return int(d1), int(d0)
if len(shape) == 2:
h, w = shape
return int(w), int(h)
except Exception as e:
logger.debug(f"Failed to extract input size from header: {e}")
logger.debug("Falling back to default input size (320x320)")
return 320, 320
def _build_response(self, result: np.ndarray) -> List[bytes]:
"""
Build the response message.
Args:
result: Detection results
Returns:
List of response frames
"""
try:
# Build header
header = {
"shape": list(result.shape),
"dtype": str(result.dtype.name),
"timestamp": time.time(),
}
header_bytes = json.dumps(header).encode("utf-8")
# Convert result to bytes
result_bytes = result.tobytes(order="C")
return [header_bytes, result_bytes]
except Exception as e:
logger.error(f"Failed to build response: {e}")
# Return zero result as fallback
header = {
"shape": [20, 6],
"dtype": "float32",
"error": "Failed to build response",
}
header_bytes = json.dumps(header).encode("utf-8")
result_bytes = self.zero_result.tobytes(order="C")
return [header_bytes, result_bytes]
def _handle_model_request(self, header: dict) -> List[bytes]:
"""
Handle model availability request.
Args:
header: Request header containing model information
Returns:
Response message indicating model availability
"""
model_name = header.get("model_name")
if not model_name:
logger.error("Model request missing model_name")
return self._build_error_response("Model request missing model_name")
logger.info(f"Model availability request for: {model_name}")
if self._check_model_exists(model_name):
logger.info(f"Model {model_name} exists locally")
# Try to load the model
if self._load_model(model_name):
response_header = {
"model_available": True,
"model_loaded": True,
"model_name": model_name,
"message": f"Model {model_name} loaded successfully",
}
else:
response_header = {
"model_available": True,
"model_loaded": False,
"model_name": model_name,
"message": f"Model {model_name} exists but failed to load",
}
else:
logger.info(f"Model {model_name} not found, requesting transfer")
response_header = {
"model_available": False,
"model_name": model_name,
"message": f"Model {model_name} not found, please send model data",
}
return [json.dumps(response_header).encode("utf-8")]
def _handle_model_data(self, header: dict, model_data: bytes) -> List[bytes]:
"""
Handle model data transfer.
Args:
header: Request header containing model information
model_data: Binary model data
Returns:
Response message indicating save success/failure
"""
model_name = header.get("model_name")
if not model_name:
logger.error("Model data missing model_name")
return self._build_error_response("Model data missing model_name")
logger.info(f"Received model data for: {model_name}")
if self._save_model(model_name, model_data):
# Try to load the model
if self._load_model(model_name):
response_header = {
"model_saved": True,
"model_loaded": True,
"model_name": model_name,
"message": f"Model {model_name} saved and loaded successfully",
}
else:
response_header = {
"model_saved": True,
"model_loaded": False,
"model_name": model_name,
"message": f"Model {model_name} saved but failed to load",
}
else:
response_header = {
"model_saved": False,
"model_loaded": False,
"model_name": model_name,
"message": f"Failed to save model {model_name}",
}
return [json.dumps(response_header).encode("utf-8")]
def _build_error_response(self, error_msg: str) -> List[bytes]:
"""Build an error response message."""
error_header = {"error": error_msg}
return [json.dumps(error_header).encode("utf-8")]
# Removed legacy single-thread start_server implementation in favor of worker pool
def _send_error_response(self, error_msg: str):
"""Send an error response to the client."""
try:
error_header = {"shape": [20, 6], "dtype": "float32", "error": error_msg}
error_response = [
json.dumps(error_header).encode("utf-8"),
self.zero_result.tobytes(order="C"),
]
self.socket.send_multipart(error_response)
except Exception as send_error:
logger.error(f"Failed to send error response: {send_error}")
def cleanup(self):
"""Clean up resources."""
try:
if self.socket:
self.socket.close()
self.socket = None
if self.context:
self.context.term()
self.context = None
logger.info("Cleanup completed")
except Exception as e:
logger.error(f"Cleanup error: {e}")
def main():
"""Main function to run the ZMQ ONNX client."""
import argparse
parser = argparse.ArgumentParser(description="ZMQ TCP ONNX Runtime Client")
parser.add_argument(
"--endpoint",
default="tcp://*:5555",
help="ZMQ TCP endpoint (default: tcp://*:5555)",
)
parser.add_argument(
"--model",
default="AUTO",
help="Path to ONNX model file or AUTO for automatic model management",
)
parser.add_argument(
"--providers",
nargs="+",
default=["CoreMLExecutionProvider"],
help="ONNX Runtime execution providers",
)
parser.add_argument(
"--workers",
type=int,
default=4,
help="Number of REP worker threads",
)
parser.add_argument(
"--verbose", "-v", action="store_true", help="Enable verbose logging"
)
args = parser.parse_args()
if args.verbose:
logging.getLogger().setLevel(logging.DEBUG)
# Create and start client
client = ZmqOnnxClient(
endpoint=args.endpoint, model_path=args.model, providers=args.providers
)
try:
client.start_server(num_workers=args.workers)
except KeyboardInterrupt:
logger.info("Interrupted by user")
finally:
client.cleanup()
if __name__ == "__main__":
main()
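A minimal sketch of a stand-in client for smoke-testing the worker without Frigate (assumptions: the worker runs locally with its default tcp://*:5555 endpoint, and a model named yolov9.onnx, a made-up name here, is already present and loadable on the worker side):

import json

import numpy as np
import zmq

ctx = zmq.Context()
req = ctx.socket(zmq.REQ)
req.bind("tcp://*:5555")  # the worker's REP sockets connect to this port

tensor = np.zeros((1, 3, 320, 320), dtype=np.uint8)
header = {
    "shape": list(tensor.shape),
    "dtype": "uint8",
    "model_type": "yolo-generic",
    "model_name": "yolov9.onnx",  # hypothetical; must already exist on the worker
}
req.send_multipart([json.dumps(header).encode("utf-8"), tensor.tobytes(order="C")])

reply_header, reply_bytes = req.recv_multipart()
meta = json.loads(reply_header.decode("utf-8"))
result = np.frombuffer(reply_bytes, dtype=meta["dtype"]).reshape(meta["shape"])
print(result.shape)  # expected: (20, 6)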