mirror of https://github.com/blakeblackshear/frigate.git
synced 2025-10-04 15:13:22 +08:00

Compare commits
5 commits: live-view-... ... zmq-model-...

Commits:
- 7a02a448cb
- 4ab8de91a9
- fbcf64d7bd
- de960285f6
- 8b78c85bda
@@ -28,6 +28,9 @@ from frigate.comms.object_detector_signaler import DetectorProxy
from frigate.comms.webpush import WebPushClient
from frigate.comms.ws import WebSocketClient
from frigate.comms.zmq_proxy import ZmqProxy
from frigate.comms.zmq_req_router_broker import (
    ZmqReqRouterBroker,
)
from frigate.config.camera.updater import CameraConfigUpdatePublisher
from frigate.config.config import FrigateConfig
from frigate.const import (

@@ -307,6 +310,14 @@ class FrigateApp:
        self.event_metadata_updater = EventMetadataPublisher()
        self.inter_zmq_proxy = ZmqProxy()
        self.detection_proxy = DetectorProxy()
        self.zmq_router_broker: ZmqReqRouterBroker | None = None

        zmq_detectors = [
            det for det in self.config.detectors.values() if det.type == "zmq"
        ]
        if any(zmq_detectors):
            backend_endpoint = zmq_detectors[0].endpoint
            self.zmq_router_broker = ZmqReqRouterBroker(backend_endpoint)

    def init_onvif(self) -> None:
        self.onvif_controller = OnvifController(self.config, self.ptz_metrics)

@@ -644,6 +655,9 @@ class FrigateApp:
        self.inter_zmq_proxy.stop()
        self.detection_proxy.stop()

        if self.zmq_router_broker:
            self.zmq_router_broker.stop()

        while len(self.detection_shms) > 0:
            shm = self.detection_shms.pop()
            shm.close()
frigate/comms/zmq_req_router_broker.py (new file, 61 lines)
@@ -0,0 +1,61 @@
"""ZMQ REQ/ROUTER front-end to DEALER/REP back-end broker.

This module provides a small proxy that:
- Binds a ROUTER socket on a fixed local endpoint for REQ clients
- Connects a DEALER socket to the user-configured backend endpoint (REP servers)

Pattern: REQ -> ROUTER === proxy === DEALER -> REP

The goal is to allow multiple REQ clients and/or multiple backend workers
to share a single configured connection, enabling multiple models/runners
behind the same broker while keeping local clients stable via constants.
"""

from __future__ import annotations

import threading

import zmq

REQ_ROUTER_ENDPOINT = "ipc:///tmp/cache/zmq_detector_router"


class _RouterDealerRunner(threading.Thread):
    def __init__(self, context: zmq.Context[zmq.Socket], backend_endpoint: str) -> None:
        super().__init__(name="zmq_router_dealer_broker", daemon=True)
        self.context = context
        self.backend_endpoint = backend_endpoint

    def run(self) -> None:
        frontend = self.context.socket(zmq.ROUTER)
        frontend.bind(REQ_ROUTER_ENDPOINT)

        backend = self.context.socket(zmq.DEALER)
        backend.bind(self.backend_endpoint)

        try:
            zmq.proxy(frontend, backend)
        except zmq.ZMQError:
            # Unblocked when context is destroyed in the controller
            pass


class ZmqReqRouterBroker:
    """Starts a ROUTER/DEALER proxy bridging local REQ clients to backend REP.

    - ROUTER binds to REQ_ROUTER_ENDPOINT (constant, local)
    - DEALER connects to the provided backend_endpoint (user-configured)
    """

    def __init__(self, backend_endpoint: str) -> None:
        self.backend_endpoint = backend_endpoint
        self.context = zmq.Context()
        self.runner = _RouterDealerRunner(self.context, backend_endpoint)
        self.runner.start()

    def stop(self) -> None:
        # Destroying the context signals the proxy to stop
        try:
            self.context.destroy()
        finally:
            self.runner.join()
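A minimal sketch of how the two ends attach to this broker (illustrative, not part of the commit; assumes pyzmq, an existing /tmp/cache directory for the ipc endpoint, and a hypothetical tcp backend endpoint):

import zmq
from frigate.comms.zmq_req_router_broker import REQ_ROUTER_ENDPOINT, ZmqReqRouterBroker

# ROUTER binds the fixed local endpoint, DEALER binds the backend endpoint.
broker = ZmqReqRouterBroker("tcp://127.0.0.1:5555")  # hypothetical backend endpoint

ctx = zmq.Context()

# Frigate-side client: a REQ socket connects to the constant ROUTER endpoint.
req = ctx.socket(zmq.REQ)
req.connect(REQ_ROUTER_ENDPOINT)

# Worker side: a REP socket connects to the endpoint the DEALER bound above.
rep = ctx.socket(zmq.REP)
rep.connect("tcp://127.0.0.1:5555")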
@@ -1,13 +1,17 @@
"""Base runner implementation for ONNX models."""

import json
import logging
import os
import threading
from abc import ABC, abstractmethod
from typing import Any

import numpy as np
import onnxruntime as ort
import zmq

from frigate.comms.zmq_req_router_broker import REQ_ROUTER_ENDPOINT
from frigate.util.model import get_ort_providers
from frigate.util.rknn_converter import auto_convert_model, is_rknn_compatible
@@ -301,6 +305,187 @@ class OpenVINOModelRunner(BaseModelRunner):
        return outputs


class ZmqIpcRunner(BaseModelRunner):
    """Runner that forwards inference over ZMQ REQ/ROUTER to backend workers.

    This allows reusing the same interface as local runners while delegating
    inference to the external ZMQ workers.
    """

    def __init__(
        self,
        model_path: str,
        model_type: str,
        request_timeout_ms: int = 200,
        linger_ms: int = 0,
        endpoint: str = REQ_ROUTER_ENDPOINT,
    ):
        self.model_type = model_type
        self.model_path = model_path
        self.model_name = os.path.basename(model_path)
        self._endpoint = endpoint
        self._context = zmq.Context()
        self._socket = self._context.socket(zmq.REQ)
        self._socket.setsockopt(zmq.RCVTIMEO, request_timeout_ms)
        self._socket.setsockopt(zmq.SNDTIMEO, request_timeout_ms)
        self._socket.setsockopt(zmq.LINGER, linger_ms)
        self._socket.connect(self._endpoint)
        self._model_ready = False
        self._io_lock = threading.Lock()

    @staticmethod
    def is_complex_model(model_type: str) -> bool:
        # Import here to avoid circular imports
        from frigate.detectors.detector_config import ModelTypeEnum
        from frigate.embeddings.types import EnrichmentModelTypeEnum

        return model_type in [
            ModelTypeEnum.yolonas.value,
            EnrichmentModelTypeEnum.paddleocr.value,
            EnrichmentModelTypeEnum.jina_v1.value,
            EnrichmentModelTypeEnum.jina_v2.value,
        ]

    def get_input_names(self) -> list[str]:
        if "vision" in self.model_name:
            return ["pixel_values"]
        elif "arcface" in self.model_name:
            return ["data"]
        else:
            return ["input"]

    def get_input_width(self) -> int:
        # Not known/required for ZMQ forwarding
        return -1

    def _build_header(self, tensor_input: np.ndarray) -> bytes:
        header: dict[str, object] = {
            "shape": list(tensor_input.shape),
            "dtype": str(tensor_input.dtype.name),
            "model_type": self.model_type,
            "model_name": self.model_name,
        }
        return json.dumps(header).encode("utf-8")

    def _decode_response(self, frames: list[bytes]) -> np.ndarray:
        if len(frames) == 1:
            buf = frames[0]
            if len(buf) != 20 * 6 * 4:
                raise ValueError(f"Unexpected payload size: {len(buf)}")
            return np.frombuffer(buf, dtype=np.float32).reshape((20, 6))

        if len(frames) >= 2:
            header = json.loads(frames[0].decode("utf-8"))
            shape = tuple(header.get("shape", []))
            dtype = np.dtype(header.get("dtype", "float32"))
            return np.frombuffer(frames[1], dtype=dtype).reshape(shape)

        raise ValueError("Empty or malformed reply from ZMQ detector")

    def run(self, input: dict[str, np.ndarray]) -> np.ndarray | None:
        if not self._model_ready:
            if not self.ensure_model_ready(self.model_path):
                raise TimeoutError("ZMQ detector model is not ready after transfer")
            self._model_ready = True

        input_name = next(iter(input))
        tensor = input[input_name]
        header = self._build_header(tensor)
        payload = memoryview(tensor.tobytes(order="C"))
        try:
            with self._io_lock:
                self._socket.send_multipart([header, payload])
                frames = self._socket.recv_multipart()
        except zmq.Again as e:
            raise TimeoutError("ZMQ detector request timed out") from e
        except zmq.ZMQError as e:
            raise RuntimeError(f"ZMQ error: {e}") from e

        return self._decode_response(frames)

    def ensure_model_ready(self, model_path: str) -> bool:
        """Ensure the remote has the model and it is loaded.

        1) Send model_request with model_name
        2) If not available, send model_data with the file contents
        3) Wait for loaded confirmation
        Returns True on success.
        """
        # Check model availability
        req = {"model_request": True, "model_name": self.model_name}
        with self._io_lock:
            self._socket.send_multipart([json.dumps(req).encode("utf-8")])

        # Temporarily extend timeout for model ops
        original_rcv = self._socket.getsockopt(zmq.RCVTIMEO)
        try:
            self._socket.setsockopt(zmq.RCVTIMEO, max(30000, int(original_rcv or 0)))
            resp_frames = self._socket.recv_multipart()
        except zmq.Again:
            self._socket.setsockopt(zmq.RCVTIMEO, original_rcv)
            return False
        finally:
            try:
                self._socket.setsockopt(zmq.RCVTIMEO, original_rcv)
            except Exception:
                pass

        try:
            if len(resp_frames) != 1:
                return False
            resp = json.loads(resp_frames[0].decode("utf-8"))
            if resp.get("model_available") and resp.get("model_loaded"):
                logger.info(f"ZMQ detector model {self.model_name} is ready")
                return True
        except Exception:
            return False

        try:
            with open(model_path, "rb") as f:
                model_bytes = f.read()
        except Exception:
            return False

        header = {"model_data": True, "model_name": self.model_name}
        with self._io_lock:
            self._socket.send_multipart(
                [json.dumps(header).encode("utf-8"), model_bytes]
            )

        original_rcv2 = self._socket.getsockopt(zmq.RCVTIMEO)
        try:
            self._socket.setsockopt(zmq.RCVTIMEO, max(30000, int(original_rcv2 or 0)))
            resp2 = self._socket.recv_multipart()
        except zmq.Again:
            self._socket.setsockopt(zmq.RCVTIMEO, original_rcv2)
            return False
        finally:
            try:
                self._socket.setsockopt(zmq.RCVTIMEO, original_rcv2)
            except Exception:
                pass

        try:
            if len(resp2) != 1:
                return False
            j = json.loads(resp2[0].decode("utf-8"))
            return bool(j.get("model_saved") and j.get("model_loaded"))
        except Exception:
            return False

    def __del__(self) -> None:
        try:
            if self._socket is not None:
                self._socket.close()
        except Exception:
            pass
        try:
            if self._context is not None:
                self._context.term()
        except Exception:
            pass


class RKNNModelRunner(BaseModelRunner):
    """Run RKNN models for embeddings."""

@@ -415,6 +600,10 @@ def get_optimized_runner(
    if rknn_path:
        return RKNNModelRunner(rknn_path)

    if device == "ZMQ" and not ZmqIpcRunner.is_complex_model(model_type):
        logger.info(f"Using ZMQ detector model {model_path}")
        return ZmqIpcRunner(model_path, model_type, **kwargs)

    providers, options = get_ort_providers(device == "CPU", device, **kwargs)

    if providers[0] == "CPUExecutionProvider":
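As a rough illustration of the framing this runner speaks over the broker (a sketch, not from the commit): a request is [header_json, tensor_bytes] and a reply is either a bare 20x6 float32 buffer or [header_json, result_bytes], mirroring _build_header and _decode_response above. The model_type and model_name values here are illustrative.

import json
import numpy as np

tensor = np.zeros((1, 3, 320, 320), dtype=np.float32)  # example input
header = {
    "shape": list(tensor.shape),
    "dtype": tensor.dtype.name,
    "model_type": "yolo-generic",  # illustrative
    "model_name": "model.onnx",    # illustrative
}
request = [json.dumps(header).encode("utf-8"), tensor.tobytes(order="C")]

def decode_reply(frames):
    # Single frame: raw (20, 6) float32 detections.
    if len(frames) == 1:
        return np.frombuffer(frames[0], dtype=np.float32).reshape((20, 6))
    # Two frames: JSON header describing shape/dtype, then the payload.
    meta = json.loads(frames[0].decode("utf-8"))
    return np.frombuffer(frames[1], dtype=np.dtype(meta["dtype"])).reshape(tuple(meta["shape"]))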
@@ -8,7 +8,9 @@ import zmq
from pydantic import Field
from typing_extensions import Literal

from frigate.comms.zmq_req_router_broker import REQ_ROUTER_ENDPOINT
from frigate.detectors.detection_api import DetectionApi
from frigate.detectors.detection_runners import ZmqIpcRunner
from frigate.detectors.detector_config import BaseDetectorConfig

logger = logging.getLogger(__name__)

@@ -49,9 +51,7 @@ class ZmqIpcDetector(DetectionApi):
    On any error or timeout, this detector returns a zero array of shape (20, 6).

    Model Management:
    - On initialization, sends model request to check if model is available
    - If model not available, sends model data via ZMQ
    - Only starts inference after model is ready
    - Model transfer/availability is handled by the runner automatically
    """

    type_key = DETECTOR_KEY

@@ -60,21 +60,23 @@ class ZmqIpcDetector(DetectionApi):
        super().__init__(detector_config)

        self._context = zmq.Context()
        self._endpoint = detector_config.endpoint
        self._endpoint = REQ_ROUTER_ENDPOINT
        self._request_timeout_ms = detector_config.request_timeout_ms
        self._linger_ms = detector_config.linger_ms
        self._socket = None
        self._create_socket()

        # Model management
        self._model_ready = False
        self._model_name = self._get_model_name()

        # Initialize model if needed
        self._initialize_model()

        # Preallocate zero result for error paths
        self._zero_result = np.zeros((20, 6), np.float32)
        self._runner = ZmqIpcRunner(
            model_path=self.detector_config.model.path,
            model_type=str(self.detector_config.model.model_type.value),
            request_timeout_ms=self._request_timeout_ms,
            linger_ms=self._linger_ms,
            endpoint=self._endpoint,
        )

    def _create_socket(self) -> None:
        if self._socket is not None:

@@ -96,167 +98,12 @@ class ZmqIpcDetector(DetectionApi):
        model_path = self.detector_config.model.path
        return os.path.basename(model_path)

    def _initialize_model(self) -> None:
        """Initialize the model by checking availability and transferring if needed."""
        try:
            logger.info(f"Initializing model: {self._model_name}")

            # Check if model is available and transfer if needed
            if self._check_and_transfer_model():
                logger.info(f"Model {self._model_name} is ready")
                self._model_ready = True
            else:
                logger.error(f"Failed to initialize model {self._model_name}")

        except Exception as e:
            logger.error(f"Failed to initialize model: {e}")

    def _check_and_transfer_model(self) -> bool:
        """Check if model is available and transfer if needed in one atomic operation."""
        try:
            # Send model availability request
            header = {"model_request": True, "model_name": self._model_name}
            header_bytes = json.dumps(header).encode("utf-8")

            self._socket.send_multipart([header_bytes])

            # Temporarily increase timeout for model operations
            original_timeout = self._socket.getsockopt(zmq.RCVTIMEO)
            self._socket.setsockopt(zmq.RCVTIMEO, 30000)

            try:
                response_frames = self._socket.recv_multipart()
            finally:
                self._socket.setsockopt(zmq.RCVTIMEO, original_timeout)

            if len(response_frames) == 1:
                try:
                    response = json.loads(response_frames[0].decode("utf-8"))
                    model_available = response.get("model_available", False)
                    model_loaded = response.get("model_loaded", False)

                    if model_available and model_loaded:
                        return True
                    elif model_available and not model_loaded:
                        logger.error("Model exists but failed to load")
                        return False
                    else:
                        return self._send_model_data()

                except json.JSONDecodeError:
                    logger.warning(
                        "Received non-JSON response for model availability check"
                    )
                    return False
            else:
                logger.warning(
                    "Received unexpected response format for model availability check"
                )
                return False

        except Exception as e:
            logger.error(f"Failed to check and transfer model: {e}")
            return False

    def _check_model_availability(self) -> bool:
        """Check if the model is available on the detector."""
        try:
            # Send model availability request
            header = {"model_request": True, "model_name": self._model_name}
            header_bytes = json.dumps(header).encode("utf-8")

            self._socket.send_multipart([header_bytes])

            # Receive response
            response_frames = self._socket.recv_multipart()

            # Check if this is a JSON response (model management)
            if len(response_frames) == 1:
                try:
                    response = json.loads(response_frames[0].decode("utf-8"))
                    model_available = response.get("model_available", False)
                    model_loaded = response.get("model_loaded", False)
                    logger.debug(
                        f"Model availability check: available={model_available}, loaded={model_loaded}"
                    )
                    return model_available and model_loaded
                except json.JSONDecodeError:
                    logger.warning(
                        "Received non-JSON response for model availability check"
                    )
                    return False
            else:
                logger.warning(
                    "Received unexpected response format for model availability check"
                )
                return False

        except Exception as e:
            logger.error(f"Failed to check model availability: {e}")
            return False

    def _send_model_data(self) -> bool:
        """Send model data to the detector."""
        try:
            model_path = self.detector_config.model.path

            if not os.path.exists(model_path):
                logger.error(f"Model file not found: {model_path}")
                return False

            logger.info(f"Transferring model to detector: {self._model_name}")
            with open(model_path, "rb") as f:
                model_data = f.read()

            header = {"model_data": True, "model_name": self._model_name}
            header_bytes = json.dumps(header).encode("utf-8")

            self._socket.send_multipart([header_bytes, model_data])

            # Temporarily increase timeout for model loading (can take several seconds)
            original_timeout = self._socket.getsockopt(zmq.RCVTIMEO)
            self._socket.setsockopt(zmq.RCVTIMEO, 30000)

            try:
                # Receive response
                response_frames = self._socket.recv_multipart()
            finally:
                # Restore original timeout
                self._socket.setsockopt(zmq.RCVTIMEO, original_timeout)

            # Check if this is a JSON response (model management)
            if len(response_frames) == 1:
                try:
                    response = json.loads(response_frames[0].decode("utf-8"))
                    model_saved = response.get("model_saved", False)
                    model_loaded = response.get("model_loaded", False)
                    if model_saved and model_loaded:
                        logger.info(
                            f"Model {self._model_name} transferred and loaded successfully"
                        )
                    else:
                        logger.error(
                            f"Model transfer failed: saved={model_saved}, loaded={model_loaded}"
                        )
                    return model_saved and model_loaded
                except json.JSONDecodeError:
                    logger.warning("Received non-JSON response for model data transfer")
                    return False
            else:
                logger.warning(
                    "Received unexpected response format for model data transfer"
                )
                return False

        except Exception as e:
            logger.error(f"Failed to send model data: {e}")
            return False

    def _build_header(self, tensor_input: np.ndarray) -> bytes:
        header: dict[str, Any] = {
            "shape": list(tensor_input.shape),
            "dtype": str(tensor_input.dtype.name),
            "model_type": str(self.detector_config.model.model_type.name),
            "model_name": self._model_name,
        }
        return json.dumps(header).encode("utf-8")

@@ -285,42 +132,11 @@ class ZmqIpcDetector(DetectionApi):
        return self._zero_result

    def detect_raw(self, tensor_input: np.ndarray) -> np.ndarray:
        if not self._model_ready:
            logger.warning("Model not ready, returning zero detections")
            return self._zero_result

        try:
            header_bytes = self._build_header(tensor_input)
            payload_bytes = memoryview(tensor_input.tobytes(order="C"))

            # Send request
            self._socket.send_multipart([header_bytes, payload_bytes])

            # Receive reply
            reply_frames = self._socket.recv_multipart()
            detections = self._decode_response(reply_frames)

            # Ensure output shape and dtype are exactly as expected
            return detections
        except zmq.Again:
            # Timeout
            logger.debug("ZMQ detector request timed out; resetting socket")
            try:
                self._create_socket()
                self._initialize_model()
            except Exception:
                pass
            return self._zero_result
        except zmq.ZMQError as exc:
            logger.error(f"ZMQ detector ZMQError: {exc}; resetting socket")
            try:
                self._create_socket()
                self._initialize_model()
            except Exception:
                pass
            return self._zero_result
            result = self._runner.run({"input": tensor_input})
            return result if isinstance(result, np.ndarray) else self._zero_result
        except Exception as exc:  # noqa: BLE001
            logger.error(f"ZMQ detector unexpected error: {exc}")
            logger.error(f"ZMQ IPC runner error: {exc}")
            return self._zero_result

    def __del__(self) -> None:  # pragma: no cover - best-effort cleanup
frigate/detectors/zmq_client.py (new file, 838 lines)
@@ -0,0 +1,838 @@
#!/usr/bin/env python3
"""
ZMQ TCP ONNX Runtime Client

This client connects to the ZMQ TCP proxy, accepts tensor inputs,
runs inference via ONNX Runtime, and returns detection results.

Protocol:
- Receives multipart messages: [header_json_bytes, tensor_bytes]
- Header contains shape and dtype information
- Runs ONNX inference on the tensor
- Returns results in the expected format: [20, 6] float32 array

Note: Timeouts are normal when Frigate has no motion to detect.
The server will continue running and waiting for requests.
"""

import json
import logging
import os
import threading
import time
from typing import Dict, List, Optional, Tuple

import numpy as np
import onnxruntime as ort
import zmq
from model_util import post_process_dfine, post_process_rfdetr, post_process_yolo

# Configure logging
logging.basicConfig(
    level=logging.INFO, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s"
)
logger = logging.getLogger(__name__)


class ZmqOnnxWorker(threading.Thread):
    """
    A worker thread that connects a REP socket to the endpoint and processes
    requests using a shared model session map. This mirrors the single-runner
    logic, but the ONNX Runtime session is fetched from the shared map, and
    created on-demand if missing.
    """

    def __init__(
        self,
        worker_id: int,
        context: zmq.Context,
        endpoint: str,
        models_dir: str,
        model_sessions: Dict[str, ort.InferenceSession],
        model_lock: threading.Lock,
        providers: Optional[List[str]],
        zero_result: np.ndarray,
    ) -> None:
        super().__init__(name=f"onnx_worker_{worker_id}", daemon=True)
        self.worker_id = worker_id
        self.context = context
        self.endpoint = self._normalize_endpoint(endpoint)
        self.models_dir = models_dir
        self.model_sessions = model_sessions
        self.model_lock = model_lock
        self.providers = providers
        self.zero_result = zero_result
        self.socket: Optional[zmq.Socket] = None

    def _normalize_endpoint(self, endpoint: str) -> str:
        if endpoint.startswith("tcp://*:"):
            port = endpoint.split(":", 2)[-1]
            return f"tcp://127.0.0.1:{port}"
        return endpoint

    # --- ZMQ helpers ---
    def _create_socket(self) -> zmq.Socket:
        sock = self.context.socket(zmq.REP)
        sock.setsockopt(zmq.RCVTIMEO, 5000)
        sock.setsockopt(zmq.SNDTIMEO, 5000)
        sock.setsockopt(zmq.LINGER, 0)
        sock.connect(self.endpoint)
        return sock

    def _decode_request(self, frames: List[bytes]) -> Tuple[Optional[np.ndarray], dict]:
        if len(frames) < 1:
            raise ValueError(f"Expected at least 1 frame, got {len(frames)}")

        header_bytes = frames[0]
        header = json.loads(header_bytes.decode("utf-8"))

        if "model_request" in header:
            return None, header
        if "model_data" in header:
            return None, header
        if len(frames) < 2:
            raise ValueError(f"Tensor request expected 2 frames, got {len(frames)}")

        tensor_bytes = frames[1]
        shape = tuple(header.get("shape", []))
        dtype_str = header.get("dtype", "uint8")

        dtype = np.dtype(dtype_str)
        tensor = np.frombuffer(tensor_bytes, dtype=dtype).reshape(shape)
        return tensor, header

    def _build_response(self, result: np.ndarray) -> List[bytes]:
        header = {
            "shape": list(result.shape),
            "dtype": str(result.dtype.name),
            "timestamp": time.time(),
        }
        return [json.dumps(header).encode("utf-8"), result.tobytes(order="C")]

    def _build_error_response(self, error_msg: str) -> List[bytes]:
        error_header = {"shape": [20, 6], "dtype": "float32", "error": error_msg}
        return [
            json.dumps(error_header).encode("utf-8"),
            self.zero_result.tobytes(order="C"),
        ]

    # --- Model/session helpers ---
    def _check_model_exists(self, model_name: str) -> bool:
        return os.path.exists(os.path.join(self.models_dir, model_name))

    def _save_model(self, model_name: str, model_data: bytes) -> bool:
        try:
            os.makedirs(self.models_dir, exist_ok=True)
            with open(os.path.join(self.models_dir, model_name), "wb") as f:
                f.write(model_data)
            return True
        except Exception as e:
            logger.error(
                f"Worker {self.worker_id} failed to save model {model_name}: {e}"
            )
            return False

    def _get_or_create_session(self, model_name: str) -> Optional[ort.InferenceSession]:
        with self.model_lock:
            session = self.model_sessions.get(model_name)
            if session is not None:
                return session
            try:
                providers = self.providers or ["CoreMLExecutionProvider"]
                session = ort.InferenceSession(
                    os.path.join(self.models_dir, model_name), providers=providers
                )
                self.model_sessions[model_name] = session
                return session
            except Exception as e:
                logger.error(
                    f"Worker {self.worker_id} failed to load model {model_name}: {e}"
                )
                return None

    # --- Inference helpers ---
    def _extract_input_hw(self, header: dict) -> Tuple[int, int]:
        try:
            if "width" in header and "height" in header:
                return int(header["width"]), int(header["height"])
            shape = tuple(header.get("shape", []))
            layout = header.get("layout") or header.get("order")
            if layout and shape:
                layout = str(layout).upper()
                if len(shape) == 4:
                    if layout == "NCHW":
                        return int(shape[3]), int(shape[2])
                    if layout == "NHWC":
                        return int(shape[2]), int(shape[1])
                if len(shape) == 3:
                    if layout == "CHW":
                        return int(shape[2]), int(shape[1])
                    if layout == "HWC":
                        return int(shape[1]), int(shape[0])
            if shape:
                if len(shape) == 4:
                    _, d1, d2, d3 = shape
                    if d1 in (1, 3):
                        return int(d3), int(d2)
                    if d3 in (1, 3):
                        return int(d2), int(d1)
                    return int(d2), int(d1)
                if len(shape) == 3:
                    d0, d1, d2 = shape
                    if d0 in (1, 3):
                        return int(d2), int(d1)
                    if d2 in (1, 3):
                        return int(d1), int(d0)
                    return int(d1), int(d0)
                if len(shape) == 2:
                    h, w = shape
                    return int(w), int(h)
        except Exception:
            pass
        return 320, 320

    def _run_inference(
        self, session: ort.InferenceSession, tensor: np.ndarray, header: dict
    ) -> np.ndarray:
        try:
            model_type = header.get("model_type")
            width, height = self._extract_input_hw(header)

            if model_type == "dfine":
                input_data = {
                    "images": tensor.astype(np.float32),
                    "orig_target_sizes": np.array([[height, width]], dtype=np.int64),
                }
            else:
                input_name = session.get_inputs()[0].name
                input_data = {input_name: tensor}

            if logger.isEnabledFor(logging.DEBUG):
                t_start = time.perf_counter()

            outputs = session.run(None, input_data)

            if logger.isEnabledFor(logging.DEBUG):
                t_after_onnx = time.perf_counter()

            if model_type == "yolo-generic" or model_type == "yologeneric":
                result = post_process_yolo(outputs, width, height)
            elif model_type == "dfine":
                result = post_process_dfine(outputs, width, height)
            elif model_type == "rfdetr":
                result = post_process_rfdetr(outputs)
            else:
                result = np.zeros((20, 6), dtype=np.float32)

            if logger.isEnabledFor(logging.DEBUG):
                t_after_post = time.perf_counter()
                onnx_ms = (t_after_onnx - t_start) * 1000.0
                post_ms = (t_after_post - t_after_onnx) * 1000.0
                total_ms = (t_after_post - t_start) * 1000.0
                logger.debug(
                    f"Worker {self.worker_id} timing: onnx={onnx_ms:.2f}ms, post={post_ms:.2f}ms, total={total_ms:.2f}ms"
                )

            return result.astype(np.float32)
        except Exception as e:
            logger.error(f"Worker {self.worker_id} ONNX inference failed: {e}")
            return self.zero_result

    # --- Message handlers ---
    def _handle_model_request(self, header: dict) -> List[bytes]:
        model_name = header.get("model_name")
        if not model_name:
            return self._build_error_response("Model request missing model_name")
        if self._check_model_exists(model_name):
            # Ensure session exists
            if self._get_or_create_session(model_name) is not None:
                response_header = {
                    "model_available": True,
                    "model_loaded": True,
                    "model_name": model_name,
                    "message": f"Model {model_name} loaded successfully",
                }
            else:
                response_header = {
                    "model_available": True,
                    "model_loaded": False,
                    "model_name": model_name,
                    "message": f"Model {model_name} exists but failed to load",
                }
        else:
            response_header = {
                "model_available": False,
                "model_name": model_name,
                "message": f"Model {model_name} not found, please send model data",
            }
        return [json.dumps(response_header).encode("utf-8")]

    def _handle_model_data(self, header: dict, model_data: bytes) -> List[bytes]:
        model_name = header.get("model_name")
        if not model_name:
            return self._build_error_response("Model data missing model_name")
        if self._save_model(model_name, model_data):
            # Ensure session is created
            if self._get_or_create_session(model_name) is not None:
                response_header = {
                    "model_saved": True,
                    "model_loaded": True,
                    "model_name": model_name,
                    "message": f"Model {model_name} saved and loaded successfully",
                }
            else:
                response_header = {
                    "model_saved": True,
                    "model_loaded": False,
                    "model_name": model_name,
                    "message": f"Model {model_name} saved but failed to load",
                }
        else:
            response_header = {
                "model_saved": False,
                "model_loaded": False,
                "model_name": model_name,
                "message": f"Failed to save model {model_name}",
            }
        return [json.dumps(response_header).encode("utf-8")]

    # --- Thread run ---
    def run(self) -> None:  # pragma: no cover - runtime loop
        try:
            self.socket = self._create_socket()
            logger.info(
                f"Worker {self.worker_id} connected REP to endpoint: {self.endpoint}"
            )
            while True:
                try:
                    frames = self.socket.recv_multipart()
                    tensor, header = self._decode_request(frames)

                    if "model_request" in header:
                        response = self._handle_model_request(header)
                        self.socket.send_multipart(response)
                        continue
                    if "model_data" in header and len(frames) >= 2:
                        model_data = frames[1]
                        response = self._handle_model_data(header, model_data)
                        self.socket.send_multipart(response)
                        continue
                    if tensor is not None:
                        model_name = header.get("model_name")
                        session = None
                        if model_name:
                            session = self._get_or_create_session(model_name)
                        if session is None:
                            result = self.zero_result
                        else:
                            result = self._run_inference(session, tensor, header)
                        self.socket.send_multipart(self._build_response(result))
                        continue

                    # Unknown message: reply with zeros
                    self.socket.send_multipart(self._build_response(self.zero_result))
                except zmq.Again:
                    continue
                except zmq.ZMQError as e:
                    logger.error(f"Worker {self.worker_id} ZMQ error: {e}")
                    # Recreate socket on transient errors
                    try:
                        if self.socket:
                            self.socket.close(linger=0)
                    finally:
                        self.socket = self._create_socket()
                except Exception as e:
                    logger.error(f"Worker {self.worker_id} unexpected error: {e}")
                    try:
                        self.socket.send_multipart(self._build_error_response(str(e)))
                    except Exception:
                        pass
        finally:
            try:
                if self.socket:
                    self.socket.close(linger=0)
            except Exception:
                pass

class ZmqOnnxClient:
    """
    ZMQ TCP client that runs ONNX inference on received tensors.
    """

    def __init__(
        self,
        endpoint: str = "tcp://*:5555",
        model_path: Optional[str] = "AUTO",
        providers: Optional[List[str]] = None,
        session_options: Optional[ort.SessionOptions] = None,
    ):
        """
        Initialize the ZMQ ONNX client.

        Args:
            endpoint: ZMQ IPC endpoint to bind to
            model_path: Path to ONNX model file or "AUTO" for automatic model management
            providers: ONNX Runtime execution providers
            session_options: ONNX Runtime session options
        """
        self.endpoint = endpoint
        self.model_path = model_path
        self.current_model = None
        self.model_ready = False
        self.models_dir = os.path.join(
            os.path.dirname(os.path.dirname(__file__)), "models"
        )

        # Shared ZMQ context and shared session map across workers
        self.context = zmq.Context()
        self.model_sessions: Dict[str, ort.InferenceSession] = {}
        self.model_lock = threading.Lock()
        self.providers = providers

        # Preallocate zero result for error cases
        self.zero_result = np.zeros((20, 6), dtype=np.float32)

        logger.info(f"ZMQ ONNX client will start workers on endpoint: {endpoint}")

    def start_server(self, num_workers: int = 4) -> None:
        workers: list[ZmqOnnxWorker] = []
        for i in range(num_workers):
            w = ZmqOnnxWorker(
                worker_id=i,
                context=self.context,
                endpoint=self.endpoint,
                models_dir=self.models_dir,
                model_sessions=self.model_sessions,
                model_lock=self.model_lock,
                providers=self.providers,
                zero_result=self.zero_result,
            )
            w.start()
            workers.append(w)
        logger.info(f"Started {num_workers} ZMQ REP workers on backend {self.endpoint}")
        try:
            for w in workers:
                w.join()
        except KeyboardInterrupt:
            logger.info("Shutting down workers...")

    def _check_model_exists(self, model_name: str) -> bool:
        """
        Check if a model exists in the models directory.

        Args:
            model_name: Name of the model file to check

        Returns:
            True if model exists, False otherwise
        """
        model_path = os.path.join(self.models_dir, model_name)
        return os.path.exists(model_path)

    # These methods remain for compatibility but are unused in worker mode

    def _save_model(self, model_name: str, model_data: bytes) -> bool:
        """
        Save model data to the models directory.

        Args:
            model_name: Name of the model file to save
            model_data: Binary model data

        Returns:
            True if model saved successfully, False otherwise
        """
        try:
            # Ensure models directory exists
            os.makedirs(self.models_dir, exist_ok=True)

            model_path = os.path.join(self.models_dir, model_name)
            logger.info(f"Saving model to: {model_path}")

            with open(model_path, "wb") as f:
                f.write(model_data)

            logger.info(f"Model saved successfully: {model_name}")
            return True

        except Exception as e:
            logger.error(f"Failed to save model {model_name}: {e}")
            return False

    def _decode_request(self, frames: List[bytes]) -> Tuple[np.ndarray, dict]:
        """
        Decode the incoming request frames.

        Args:
            frames: List of message frames

        Returns:
            Tuple of (tensor, header_dict)
        """
        try:
            if len(frames) < 1:
                raise ValueError(f"Expected at least 1 frame, got {len(frames)}")

            # Parse header
            header_bytes = frames[0]
            header = json.loads(header_bytes.decode("utf-8"))

            if "model_request" in header:
                return None, header

            if "model_data" in header:
                if len(frames) < 2:
                    raise ValueError(
                        f"Model data request expected 2 frames, got {len(frames)}"
                    )
                return None, header

            if len(frames) < 2:
                raise ValueError(f"Tensor request expected 2 frames, got {len(frames)}")

            tensor_bytes = frames[1]
            shape = tuple(header.get("shape", []))
            dtype_str = header.get("dtype", "uint8")

            dtype = np.dtype(dtype_str)
            tensor = np.frombuffer(tensor_bytes, dtype=dtype).reshape(shape)
            return tensor, header

        except Exception as e:
            logger.error(f"Failed to decode request: {e}")
            raise

    def _run_inference(self, tensor: np.ndarray, header: dict) -> np.ndarray:
        """
        Run ONNX inference on the input tensor.

        Args:
            tensor: Input tensor
            header: Request header containing metadata (e.g., shape, layout)

        Returns:
            Detection results as numpy array

        Raises:
            RuntimeError: If no ONNX session is available or inference fails
        """
        if self.session is None:
            logger.warning("No ONNX session available, returning zero results")
            return self.zero_result

        try:
            # Prepare input for ONNX Runtime
            # Determine input spatial size (W, H) from header/shape/layout
            model_type = header.get("model_type")
            width, height = self._extract_input_hw(header)

            if model_type == "dfine":
                # DFine model requires both images and orig_target_sizes inputs
                input_data = {
                    "images": tensor.astype(np.float32),
                    "orig_target_sizes": np.array([[height, width]], dtype=np.int64),
                }
            else:
                # Other models use single input
                input_name = self.session.get_inputs()[0].name
                input_data = {input_name: tensor}

            # Run inference
            if logger.isEnabledFor(logging.DEBUG):
                t_start = time.perf_counter()

            outputs = self.session.run(None, input_data)

            if logger.isEnabledFor(logging.DEBUG):
                t_after_onnx = time.perf_counter()

            if model_type == "yolo-generic" or model_type == "yologeneric":
                result = post_process_yolo(outputs, width, height)
            elif model_type == "dfine":
                result = post_process_dfine(outputs, width, height)
            elif model_type == "rfdetr":
                result = post_process_rfdetr(outputs)

            if logger.isEnabledFor(logging.DEBUG):
                t_after_post = time.perf_counter()
                onnx_ms = (t_after_onnx - t_start) * 1000.0
                post_ms = (t_after_post - t_after_onnx) * 1000.0
                total_ms = (t_after_post - t_start) * 1000.0
                logger.debug(
                    f"Inference timing: onnx={onnx_ms:.2f}ms, post={post_ms:.2f}ms, total={total_ms:.2f}ms"
                )

            # Ensure float32 dtype
            result = result.astype(np.float32)

            return result

        except Exception as e:
            logger.error(f"ONNX inference failed: {e}")
            return self.zero_result

    def _extract_input_hw(self, header: dict) -> Tuple[int, int]:
        """
        Extract (width, height) from the header and/or tensor shape, supporting
        NHWC/NCHW as well as 3D/4D inputs. Falls back to 320x320 if unknown.

        Preference order:
        1) Explicit header keys: width/height
        2) Use provided layout to interpret shape
        3) Heuristics on shape
        """
        try:
            if "width" in header and "height" in header:
                return int(header["width"]), int(header["height"])

            shape = tuple(header.get("shape", []))
            layout = header.get("layout") or header.get("order")

            if layout and shape:
                layout = str(layout).upper()
                if len(shape) == 4:
                    if layout == "NCHW":
                        return int(shape[3]), int(shape[2])
                    if layout == "NHWC":
                        return int(shape[2]), int(shape[1])
                if len(shape) == 3:
                    if layout == "CHW":
                        return int(shape[2]), int(shape[1])
                    if layout == "HWC":
                        return int(shape[1]), int(shape[0])

            if shape:
                if len(shape) == 4:
                    _, d1, d2, d3 = shape
                    if d1 in (1, 3):
                        return int(d3), int(d2)
                    if d3 in (1, 3):
                        return int(d2), int(d1)
                    return int(d2), int(d1)
                if len(shape) == 3:
                    d0, d1, d2 = shape
                    if d0 in (1, 3):
                        return int(d2), int(d1)
                    if d2 in (1, 3):
                        return int(d1), int(d0)
                    return int(d1), int(d0)
                if len(shape) == 2:
                    h, w = shape
                    return int(w), int(h)
        except Exception as e:
            logger.debug(f"Failed to extract input size from header: {e}")

        logger.debug("Falling back to default input size (320x320)")
        return 320, 320

    def _build_response(self, result: np.ndarray) -> List[bytes]:
        """
        Build the response message.

        Args:
            result: Detection results

        Returns:
            List of response frames
        """
        try:
            # Build header
            header = {
                "shape": list(result.shape),
                "dtype": str(result.dtype.name),
                "timestamp": time.time(),
            }
            header_bytes = json.dumps(header).encode("utf-8")

            # Convert result to bytes
            result_bytes = result.tobytes(order="C")

            return [header_bytes, result_bytes]

        except Exception as e:
            logger.error(f"Failed to build response: {e}")
            # Return zero result as fallback
            header = {
                "shape": [20, 6],
                "dtype": "float32",
                "error": "Failed to build response",
            }
            header_bytes = json.dumps(header).encode("utf-8")
            result_bytes = self.zero_result.tobytes(order="C")
            return [header_bytes, result_bytes]

    def _handle_model_request(self, header: dict) -> List[bytes]:
        """
        Handle model availability request.

        Args:
            header: Request header containing model information

        Returns:
            Response message indicating model availability
        """
        model_name = header.get("model_name")

        if not model_name:
            logger.error("Model request missing model_name")
            return self._build_error_response("Model request missing model_name")

        logger.info(f"Model availability request for: {model_name}")

        if self._check_model_exists(model_name):
            logger.info(f"Model {model_name} exists locally")
            # Try to load the model
            if self._load_model(model_name):
                response_header = {
                    "model_available": True,
                    "model_loaded": True,
                    "model_name": model_name,
                    "message": f"Model {model_name} loaded successfully",
                }
            else:
                response_header = {
                    "model_available": True,
                    "model_loaded": False,
                    "model_name": model_name,
                    "message": f"Model {model_name} exists but failed to load",
                }
        else:
            logger.info(f"Model {model_name} not found, requesting transfer")
            response_header = {
                "model_available": False,
                "model_name": model_name,
                "message": f"Model {model_name} not found, please send model data",
            }

        return [json.dumps(response_header).encode("utf-8")]

    def _handle_model_data(self, header: dict, model_data: bytes) -> List[bytes]:
        """
        Handle model data transfer.

        Args:
            header: Request header containing model information
            model_data: Binary model data

        Returns:
            Response message indicating save success/failure
        """
        model_name = header.get("model_name")

        if not model_name:
            logger.error("Model data missing model_name")
            return self._build_error_response("Model data missing model_name")

        logger.info(f"Received model data for: {model_name}")

        if self._save_model(model_name, model_data):
            # Try to load the model
            if self._load_model(model_name):
                response_header = {
                    "model_saved": True,
                    "model_loaded": True,
                    "model_name": model_name,
                    "message": f"Model {model_name} saved and loaded successfully",
                }
            else:
                response_header = {
                    "model_saved": True,
                    "model_loaded": False,
                    "model_name": model_name,
                    "message": f"Model {model_name} saved but failed to load",
                }
        else:
            response_header = {
                "model_saved": False,
                "model_loaded": False,
                "model_name": model_name,
                "message": f"Failed to save model {model_name}",
            }

        return [json.dumps(response_header).encode("utf-8")]

    def _build_error_response(self, error_msg: str) -> List[bytes]:
        """Build an error response message."""
        error_header = {"error": error_msg}
        return [json.dumps(error_header).encode("utf-8")]

    # Removed legacy single-thread start_server implementation in favor of worker pool

    def _send_error_response(self, error_msg: str):
        """Send an error response to the client."""
        try:
            error_header = {"shape": [20, 6], "dtype": "float32", "error": error_msg}
            error_response = [
                json.dumps(error_header).encode("utf-8"),
                self.zero_result.tobytes(order="C"),
            ]
            self.socket.send_multipart(error_response)
        except Exception as send_error:
            logger.error(f"Failed to send error response: {send_error}")

    def cleanup(self):
        """Clean up resources."""
        try:
            if self.socket:
                self.socket.close()
                self.socket = None
            if self.context:
                self.context.term()
                self.context = None
            logger.info("Cleanup completed")
        except Exception as e:
            logger.error(f"Cleanup error: {e}")

def main():
    """Main function to run the ZMQ ONNX client."""
    import argparse

    parser = argparse.ArgumentParser(description="ZMQ TCP ONNX Runtime Client")
    parser.add_argument(
        "--endpoint",
        default="tcp://*:5555",
        help="ZMQ TCP endpoint (default: tcp://*:5555)",
    )
    parser.add_argument(
        "--model",
        default="AUTO",
        help="Path to ONNX model file or AUTO for automatic model management",
    )
    parser.add_argument(
        "--providers",
        nargs="+",
        default=["CoreMLExecutionProvider"],
        help="ONNX Runtime execution providers",
    )
    parser.add_argument(
        "--workers",
        type=int,
        default=4,
        help="Number of REP worker threads",
    )
    parser.add_argument(
        "--verbose", "-v", action="store_true", help="Enable verbose logging"
    )

    args = parser.parse_args()

    if args.verbose:
        logging.getLogger().setLevel(logging.DEBUG)

    # Create and start client
    client = ZmqOnnxClient(
        endpoint=args.endpoint, model_path=args.model, providers=args.providers
    )

    try:
        client.start_server(num_workers=args.workers)
    except KeyboardInterrupt:
        logger.info("Interrupted by user")
    finally:
        client.cleanup()


if __name__ == "__main__":
    main()
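For reference, given the argparse defaults above, the worker pool would be launched with something like the following (the provider choice is illustrative; the script defaults to CoreMLExecutionProvider):

python zmq_client.py --endpoint "tcp://*:5555" --workers 4 --providers CPUExecutionProvider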