Mirror of https://github.com/blakeblackshear/frigate.git (synced 2025-10-04 15:13:22 +08:00)

Compare commits: triggers-d ... zmq-model- (10 commits)

7a02a448cb
4ab8de91a9
fbcf64d7bd
de960285f6
8b78c85bda
bdb7a18602
318457113b
e4d5f1f94e
0e61d3f153
cd519ed1ad
@@ -85,13 +85,13 @@ semantic_search:
  enabled: True
  model_size: large
  # Optional, if using the 'large' model in a multi-GPU installation
  device: 0
```

:::info

If the correct build is used for your GPU / NPU and the `large` model is configured, then the GPU / NPU will be detected and used automatically.
Specify the `device` option to target a specific GPU in a multi-GPU system (see [onnxruntime's provider options](https://onnxruntime.ai/docs/execution-providers/)).
If you do not specify a device, the first available GPU will be used.

See the [Hardware Accelerated Enrichments](/configuration/hardware_acceleration_enrichments.md) documentation.

@@ -144,3 +144,11 @@ When a trigger fires, the UI highlights the trigger with a blue outline for 3 seconds.

- Triggers rely on the same Jina AI CLIP models (V1 or V2) used for semantic search. Ensure `semantic_search` is enabled and properly configured.
- Reindexing embeddings (via the UI or `reindex: True`) does not affect trigger configurations but may update the embeddings used for matching.
- For optimal performance, use a system with sufficient RAM (8GB minimum, 16GB recommended) and a GPU for `large` model configurations, as described in the Semantic Search requirements.

### FAQ

#### Why can't I create a trigger on thumbnails for some text, like "person with a blue shirt", and have it trigger when a person with a blue shirt is detected?

TL;DR: Text-to-image triggers aren't supported because CLIP can confuse similar images and give inconsistent scores, making automation unreliable.

Text-to-image triggers are not supported due to fundamental limitations of CLIP-based similarity search. While CLIP works well for exploratory, manual queries, it is unreliable for automated triggers based on a threshold. Issues include embedding drift (the same text–image pair can yield different cosine distances over time), lack of true semantic grounding (visually similar but incorrect matches), and unstable thresholding (distance distributions are dataset-dependent and often too tightly clustered to separate relevant from irrelevant results). Instead, it is recommended to set up a workflow with thumbnail triggers: first use text search to manually select 3–5 representative reference tracked objects, then configure thumbnail triggers based on that visual similarity. This provides robust automation without the semantic ambiguity of text-to-image matching.
@@ -28,6 +28,9 @@ from frigate.comms.object_detector_signaler import DetectorProxy
from frigate.comms.webpush import WebPushClient
from frigate.comms.ws import WebSocketClient
from frigate.comms.zmq_proxy import ZmqProxy
from frigate.comms.zmq_req_router_broker import (
    ZmqReqRouterBroker,
)
from frigate.config.camera.updater import CameraConfigUpdatePublisher
from frigate.config.config import FrigateConfig
from frigate.const import (

@@ -307,6 +310,14 @@ class FrigateApp:
        self.event_metadata_updater = EventMetadataPublisher()
        self.inter_zmq_proxy = ZmqProxy()
        self.detection_proxy = DetectorProxy()
        self.zmq_router_broker: ZmqReqRouterBroker | None = None

        zmq_detectors = [
            det for det in self.config.detectors.values() if det.type == "zmq"
        ]
        if any(zmq_detectors):
            backend_endpoint = zmq_detectors[0].endpoint
            self.zmq_router_broker = ZmqReqRouterBroker(backend_endpoint)

    def init_onvif(self) -> None:
        self.onvif_controller = OnvifController(self.config, self.ptz_metrics)

@@ -644,6 +655,9 @@ class FrigateApp:
        self.inter_zmq_proxy.stop()
        self.detection_proxy.stop()

        if self.zmq_router_broker:
            self.zmq_router_broker.stop()

        while len(self.detection_shms) > 0:
            shm = self.detection_shms.pop()
            shm.close()
frigate/comms/zmq_req_router_broker.py (new file, 61 lines)
@@ -0,0 +1,61 @@
"""ZMQ REQ/ROUTER front-end to DEALER/REP back-end broker.

This module provides a small proxy that:
- Binds a ROUTER socket on a fixed local endpoint for REQ clients
- Connects a DEALER socket to the user-configured backend endpoint (REP servers)

Pattern: REQ -> ROUTER === proxy === DEALER -> REP

The goal is to allow multiple REQ clients and/or multiple backend workers
to share a single configured connection, enabling multiple models/runners
behind the same broker while keeping local clients stable via constants.
"""

from __future__ import annotations

import threading

import zmq

REQ_ROUTER_ENDPOINT = "ipc:///tmp/cache/zmq_detector_router"


class _RouterDealerRunner(threading.Thread):
    def __init__(self, context: zmq.Context[zmq.Socket], backend_endpoint: str) -> None:
        super().__init__(name="zmq_router_dealer_broker", daemon=True)
        self.context = context
        self.backend_endpoint = backend_endpoint

    def run(self) -> None:
        frontend = self.context.socket(zmq.ROUTER)
        frontend.bind(REQ_ROUTER_ENDPOINT)

        backend = self.context.socket(zmq.DEALER)
        backend.bind(self.backend_endpoint)

        try:
            zmq.proxy(frontend, backend)
        except zmq.ZMQError:
            # Unblocked when context is destroyed in the controller
            pass


class ZmqReqRouterBroker:
    """Starts a ROUTER/DEALER proxy bridging local REQ clients to backend REP.

    - ROUTER binds to REQ_ROUTER_ENDPOINT (constant, local)
    - DEALER connects to the provided backend_endpoint (user-configured)
    """

    def __init__(self, backend_endpoint: str) -> None:
        self.backend_endpoint = backend_endpoint
        self.context = zmq.Context()
        self.runner = _RouterDealerRunner(self.context, backend_endpoint)
        self.runner.start()

    def stop(self) -> None:
        # Destroying the context signals the proxy to stop
        try:
            self.context.destroy()
        finally:
            self.runner.join()
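For orientation, here is a minimal sketch (not part of the change set) that exercises the pattern the module docstring describes: a REQ client talks only to the fixed local ROUTER endpoint while a stand-in REP worker sits behind the backend endpoint. The backend address, the fake worker, the zero-filled reply, and the short sleep are illustrative assumptions; the sketch also assumes the Frigate source tree is importable and that the IPC path under `/tmp/cache` is usable.

```python
import json
import os
import threading
import time

import numpy as np
import zmq

from frigate.comms.zmq_req_router_broker import REQ_ROUTER_ENDPOINT, ZmqReqRouterBroker

BACKEND = "tcp://127.0.0.1:5555"  # hypothetical user-configured detector endpoint


def fake_worker() -> None:
    # Stand-in for a real inference worker: a REP socket connected to the broker's
    # DEALER side that always replies with a zero (20, 6) float32 buffer (480 bytes).
    sock = zmq.Context.instance().socket(zmq.REP)
    sock.connect(BACKEND)
    while True:
        sock.recv_multipart()  # [header_json, tensor_bytes]
        sock.send_multipart([np.zeros((20, 6), np.float32).tobytes(order="C")])


os.makedirs("/tmp/cache", exist_ok=True)  # the fixed IPC endpoint lives under /tmp/cache
broker = ZmqReqRouterBroker(BACKEND)      # ROUTER binds REQ_ROUTER_ENDPOINT, DEALER binds BACKEND
threading.Thread(target=fake_worker, daemon=True).start()
time.sleep(0.2)                           # give the sockets a moment to connect

req = zmq.Context.instance().socket(zmq.REQ)
req.connect(REQ_ROUTER_ENDPOINT)          # clients only ever talk to the fixed local endpoint
header = {"shape": [1, 3, 320, 320], "dtype": "uint8", "model_type": "yolo-generic", "model_name": "model.onnx"}
req.send_multipart([json.dumps(header).encode("utf-8"), np.zeros((1, 3, 320, 320), np.uint8).tobytes()])
print(len(req.recv_multipart()[0]))       # 480 == 20 * 6 * 4
broker.stop()
```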
@@ -1,13 +1,17 @@
"""Base runner implementation for ONNX models."""

import json
import logging
import os
import threading
from abc import ABC, abstractmethod
from typing import Any

import numpy as np
import onnxruntime as ort
import zmq

from frigate.comms.zmq_req_router_broker import REQ_ROUTER_ENDPOINT
from frigate.util.model import get_ort_providers
from frigate.util.rknn_converter import auto_convert_model, is_rknn_compatible

@@ -112,6 +116,7 @@ class CudaGraphRunner(BaseModelRunner):
        EnrichmentModelTypeEnum.paddleocr.value,
        EnrichmentModelTypeEnum.jina_v1.value,
        EnrichmentModelTypeEnum.jina_v2.value,
        EnrichmentModelTypeEnum.yolov9_license_plate.value,
    ]

    def __init__(self, session: ort.InferenceSession, cuda_device_id: int):

@@ -194,6 +199,9 @@ class OpenVINOModelRunner(BaseModelRunner):
        # Apply performance optimization
        self.ov_core.set_property(device, {"PERF_COUNT": "NO"})

        if device in ["GPU", "AUTO"]:
            self.ov_core.set_property(device, {"PERFORMANCE_HINT": "LATENCY"})

        # Compile model
        self.compiled_model = self.ov_core.compile_model(
            model=model_path, device_name=device

@@ -297,6 +305,187 @@ class OpenVINOModelRunner(BaseModelRunner):
        return outputs


class ZmqIpcRunner(BaseModelRunner):
    """Runner that forwards inference over ZMQ REQ/ROUTER to backend workers.

    This allows reusing the same interface as local runners while delegating
    inference to the external ZMQ workers.
    """

    def __init__(
        self,
        model_path: str,
        model_type: str,
        request_timeout_ms: int = 200,
        linger_ms: int = 0,
        endpoint: str = REQ_ROUTER_ENDPOINT,
    ):
        self.model_type = model_type
        self.model_path = model_path
        self.model_name = os.path.basename(model_path)
        self._endpoint = endpoint
        self._context = zmq.Context()
        self._socket = self._context.socket(zmq.REQ)
        self._socket.setsockopt(zmq.RCVTIMEO, request_timeout_ms)
        self._socket.setsockopt(zmq.SNDTIMEO, request_timeout_ms)
        self._socket.setsockopt(zmq.LINGER, linger_ms)
        self._socket.connect(self._endpoint)
        self._model_ready = False
        self._io_lock = threading.Lock()

    @staticmethod
    def is_complex_model(model_type: str) -> bool:
        # Import here to avoid circular imports
        from frigate.detectors.detector_config import ModelTypeEnum
        from frigate.embeddings.types import EnrichmentModelTypeEnum

        return model_type in [
            ModelTypeEnum.yolonas.value,
            EnrichmentModelTypeEnum.paddleocr.value,
            EnrichmentModelTypeEnum.jina_v1.value,
            EnrichmentModelTypeEnum.jina_v2.value,
        ]

    def get_input_names(self) -> list[str]:
        if "vision" in self.model_name:
            return ["pixel_values"]
        elif "arcface" in self.model_name:
            return ["data"]
        else:
            return ["input"]

    def get_input_width(self) -> int:
        # Not known/required for ZMQ forwarding
        return -1

    def _build_header(self, tensor_input: np.ndarray) -> bytes:
        header: dict[str, object] = {
            "shape": list(tensor_input.shape),
            "dtype": str(tensor_input.dtype.name),
            "model_type": self.model_type,
            "model_name": self.model_name,
        }
        return json.dumps(header).encode("utf-8")

    def _decode_response(self, frames: list[bytes]) -> np.ndarray:
        if len(frames) == 1:
            buf = frames[0]
            if len(buf) != 20 * 6 * 4:
                raise ValueError(f"Unexpected payload size: {len(buf)}")
            return np.frombuffer(buf, dtype=np.float32).reshape((20, 6))

        if len(frames) >= 2:
            header = json.loads(frames[0].decode("utf-8"))
            shape = tuple(header.get("shape", []))
            dtype = np.dtype(header.get("dtype", "float32"))
            return np.frombuffer(frames[1], dtype=dtype).reshape(shape)

        raise ValueError("Empty or malformed reply from ZMQ detector")

    def run(self, input: dict[str, np.ndarray]) -> np.ndarray | None:
        if not self._model_ready:
            if not self.ensure_model_ready(self.model_path):
                raise TimeoutError("ZMQ detector model is not ready after transfer")
            self._model_ready = True

        input_name = next(iter(input))
        tensor = input[input_name]
        header = self._build_header(tensor)
        payload = memoryview(tensor.tobytes(order="C"))
        try:
            with self._io_lock:
                self._socket.send_multipart([header, payload])
                frames = self._socket.recv_multipart()
        except zmq.Again as e:
            raise TimeoutError("ZMQ detector request timed out") from e
        except zmq.ZMQError as e:
            raise RuntimeError(f"ZMQ error: {e}") from e

        return self._decode_response(frames)

    def ensure_model_ready(self, model_path: str) -> bool:
        """Ensure the remote has the model and it is loaded.

        1) Send model_request with model_name
        2) If not available, send model_data with the file contents
        3) Wait for loaded confirmation
        Returns True on success.
        """
        # Check model availability
        req = {"model_request": True, "model_name": self.model_name}
        with self._io_lock:
            self._socket.send_multipart([json.dumps(req).encode("utf-8")])

        # Temporarily extend timeout for model ops
        original_rcv = self._socket.getsockopt(zmq.RCVTIMEO)
        try:
            self._socket.setsockopt(zmq.RCVTIMEO, max(30000, int(original_rcv or 0)))
            resp_frames = self._socket.recv_multipart()
        except zmq.Again:
            self._socket.setsockopt(zmq.RCVTIMEO, original_rcv)
            return False
        finally:
            try:
                self._socket.setsockopt(zmq.RCVTIMEO, original_rcv)
            except Exception:
                pass

        try:
            if len(resp_frames) != 1:
                return False
            resp = json.loads(resp_frames[0].decode("utf-8"))
            if resp.get("model_available") and resp.get("model_loaded"):
                logger.info(f"ZMQ detector model {self.model_name} is ready")
                return True
        except Exception:
            return False

        try:
            with open(model_path, "rb") as f:
                model_bytes = f.read()
        except Exception:
            return False

        header = {"model_data": True, "model_name": self.model_name}
        with self._io_lock:
            self._socket.send_multipart(
                [json.dumps(header).encode("utf-8"), model_bytes]
            )

        original_rcv2 = self._socket.getsockopt(zmq.RCVTIMEO)
        try:
            self._socket.setsockopt(zmq.RCVTIMEO, max(30000, int(original_rcv2 or 0)))
            resp2 = self._socket.recv_multipart()
        except zmq.Again:
            self._socket.setsockopt(zmq.RCVTIMEO, original_rcv2)
            return False
        finally:
            try:
                self._socket.setsockopt(zmq.RCVTIMEO, original_rcv2)
            except Exception:
                pass

        try:
            if len(resp2) != 1:
                return False
            j = json.loads(resp2[0].decode("utf-8"))
            return bool(j.get("model_saved") and j.get("model_loaded"))
        except Exception:
            return False

    def __del__(self) -> None:
        try:
            if self._socket is not None:
                self._socket.close()
        except Exception:
            pass
        try:
            if self._context is not None:
                self._context.term()
        except Exception:
            pass


class RKNNModelRunner(BaseModelRunner):
    """Run RKNN models for embeddings."""

@@ -411,6 +600,10 @@ def get_optimized_runner(
    if rknn_path:
        return RKNNModelRunner(rknn_path)

    if device == "ZMQ" and not ZmqIpcRunner.is_complex_model(model_type):
        logger.info(f"Using ZMQ detector model {model_path}")
        return ZmqIpcRunner(model_path, model_type, **kwargs)

    providers, options = get_ort_providers(device == "CPU", device, **kwargs)

    if providers[0] == "CPUExecutionProvider":
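To make the framing concrete, here is a small standalone sketch of the request and the two reply formats that `ZmqIpcRunner._build_header` and `_decode_response` exchange with the backend workers. The tensor shape, model type, and model name below are illustrative values, not defaults taken from the code.

```python
import json

import numpy as np

# Request: two frames, a JSON header describing the tensor plus the raw C-order bytes.
tensor = np.zeros((1, 3, 320, 320), dtype=np.uint8)  # illustrative input tensor
request = [
    json.dumps(
        {
            "shape": list(tensor.shape),
            "dtype": tensor.dtype.name,
            "model_type": "yolo-generic",  # illustrative model type
            "model_name": "yolo.onnx",     # illustrative file name
        }
    ).encode("utf-8"),
    tensor.tobytes(order="C"),
]

# Reply form (a): a single 480-byte frame holding a (20, 6) float32 detection array.
reply_a = [np.zeros((20, 6), np.float32).tobytes(order="C")]
detections = np.frombuffer(reply_a[0], dtype=np.float32).reshape((20, 6))

# Reply form (b): a JSON header frame plus a payload frame, for arbitrary shapes/dtypes
# (this is the path enrichment-style outputs would take).
payload = np.zeros((1, 512), np.float32)
reply_b = [
    json.dumps({"shape": list(payload.shape), "dtype": payload.dtype.name}).encode("utf-8"),
    payload.tobytes(order="C"),
]
header = json.loads(reply_b[0].decode("utf-8"))
decoded = np.frombuffer(reply_b[1], dtype=np.dtype(header["dtype"])).reshape(tuple(header["shape"]))
```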
@@ -1,5 +1,6 @@
import json
import logging
import os
from typing import Any, List

import numpy as np

@@ -7,7 +8,9 @@ import zmq
from pydantic import Field
from typing_extensions import Literal

from frigate.comms.zmq_req_router_broker import REQ_ROUTER_ENDPOINT
from frigate.detectors.detection_api import DetectionApi
from frigate.detectors.detection_runners import ZmqIpcRunner
from frigate.detectors.detector_config import BaseDetectorConfig

logger = logging.getLogger(__name__)

@@ -46,6 +49,9 @@ class ZmqIpcDetector(DetectionApi):
       b) Single frame tensor_bytes of length 20*6*4 bytes (float32).

    On any error or timeout, this detector returns a zero array of shape (20, 6).

    Model Management:
    - Model transfer/availability is handled by the runner automatically
    """

    type_key = DETECTOR_KEY

@@ -54,14 +60,23 @@ class ZmqIpcDetector(DetectionApi):
        super().__init__(detector_config)

        self._context = zmq.Context()
        self._endpoint = detector_config.endpoint
        self._endpoint = REQ_ROUTER_ENDPOINT
        self._request_timeout_ms = detector_config.request_timeout_ms
        self._linger_ms = detector_config.linger_ms
        self._socket = None
        self._create_socket()

        self._model_name = self._get_model_name()

        # Preallocate zero result for error paths
        self._zero_result = np.zeros((20, 6), np.float32)
        self._runner = ZmqIpcRunner(
            model_path=self.detector_config.model.path,
            model_type=str(self.detector_config.model.model_type.value),
            request_timeout_ms=self._request_timeout_ms,
            linger_ms=self._linger_ms,
            endpoint=self._endpoint,
        )

    def _create_socket(self) -> None:
        if self._socket is not None:

@@ -78,11 +93,17 @@ class ZmqIpcDetector(DetectionApi):
        logger.debug(f"ZMQ detector connecting to {self._endpoint}")
        self._socket.connect(self._endpoint)

    def _get_model_name(self) -> str:
        """Get the model filename from the detector config."""
        model_path = self.detector_config.model.path
        return os.path.basename(model_path)

    def _build_header(self, tensor_input: np.ndarray) -> bytes:
        header: dict[str, Any] = {
            "shape": list(tensor_input.shape),
            "dtype": str(tensor_input.dtype.name),
            "model_type": str(self.detector_config.model.model_type.name),
            "model_name": self._model_name,
        }
        return json.dumps(header).encode("utf-8")

@@ -112,36 +133,10 @@ class ZmqIpcDetector(DetectionApi):

    def detect_raw(self, tensor_input: np.ndarray) -> np.ndarray:
        try:
            header_bytes = self._build_header(tensor_input)
            payload_bytes = memoryview(tensor_input.tobytes(order="C"))

            # Send request
            self._socket.send_multipart([header_bytes, payload_bytes])

            # Receive reply
            reply_frames = self._socket.recv_multipart()
            detections = self._decode_response(reply_frames)

            # Ensure output shape and dtype are exactly as expected

            return detections
        except zmq.Again:
            # Timeout
            logger.debug("ZMQ detector request timed out; resetting socket")
            try:
                self._create_socket()
            except Exception:
                pass
            return self._zero_result
        except zmq.ZMQError as exc:
            logger.error(f"ZMQ detector ZMQError: {exc}; resetting socket")
            try:
                self._create_socket()
            except Exception:
                pass
            return self._zero_result
            result = self._runner.run({"input": tensor_input})
            return result if isinstance(result, np.ndarray) else self._zero_result
        except Exception as exc: # noqa: BLE001
            logger.error(f"ZMQ detector unexpected error: {exc}")
            logger.error(f"ZMQ IPC runner error: {exc}")
            return self._zero_result

    def __del__(self) -> None: # pragma: no cover - best-effort cleanup
frigate/detectors/zmq_client.py (new file, 838 lines)
@@ -0,0 +1,838 @@
#!/usr/bin/env python3
"""
ZMQ TCP ONNX Runtime Client

This client connects to the ZMQ TCP proxy, accepts tensor inputs,
runs inference via ONNX Runtime, and returns detection results.

Protocol:
- Receives multipart messages: [header_json_bytes, tensor_bytes]
- Header contains shape and dtype information
- Runs ONNX inference on the tensor
- Returns results in the expected format: [20, 6] float32 array

Note: Timeouts are normal when Frigate has no motion to detect.
The server will continue running and waiting for requests.
"""

import json
import logging
import os
import threading
import time
from typing import Dict, List, Optional, Tuple

import numpy as np
import onnxruntime as ort
import zmq
from model_util import post_process_dfine, post_process_rfdetr, post_process_yolo

# Configure logging
logging.basicConfig(
    level=logging.INFO, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s"
)
logger = logging.getLogger(__name__)


class ZmqOnnxWorker(threading.Thread):
|
||||
"""
|
||||
A worker thread that connects a REP socket to the endpoint and processes
|
||||
requests using a shared model session map. This mirrors the single-runner
|
||||
logic, but the ONNX Runtime session is fetched from the shared map, and
|
||||
created on-demand if missing.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
worker_id: int,
|
||||
context: zmq.Context,
|
||||
endpoint: str,
|
||||
models_dir: str,
|
||||
model_sessions: Dict[str, ort.InferenceSession],
|
||||
model_lock: threading.Lock,
|
||||
providers: Optional[List[str]],
|
||||
zero_result: np.ndarray,
|
||||
) -> None:
|
||||
super().__init__(name=f"onnx_worker_{worker_id}", daemon=True)
|
||||
self.worker_id = worker_id
|
||||
self.context = context
|
||||
self.endpoint = self._normalize_endpoint(endpoint)
|
||||
self.models_dir = models_dir
|
||||
self.model_sessions = model_sessions
|
||||
self.model_lock = model_lock
|
||||
self.providers = providers
|
||||
self.zero_result = zero_result
|
||||
self.socket: Optional[zmq.Socket] = None
|
||||
|
||||
def _normalize_endpoint(self, endpoint: str) -> str:
|
||||
if endpoint.startswith("tcp://*:"):
|
||||
port = endpoint.split(":", 2)[-1]
|
||||
return f"tcp://127.0.0.1:{port}"
|
||||
return endpoint
|
||||
|
||||
# --- ZMQ helpers ---
|
||||
def _create_socket(self) -> zmq.Socket:
|
||||
sock = self.context.socket(zmq.REP)
|
||||
sock.setsockopt(zmq.RCVTIMEO, 5000)
|
||||
sock.setsockopt(zmq.SNDTIMEO, 5000)
|
||||
sock.setsockopt(zmq.LINGER, 0)
|
||||
sock.connect(self.endpoint)
|
||||
return sock
|
||||
|
||||
def _decode_request(self, frames: List[bytes]) -> Tuple[Optional[np.ndarray], dict]:
|
||||
if len(frames) < 1:
|
||||
raise ValueError(f"Expected at least 1 frame, got {len(frames)}")
|
||||
|
||||
header_bytes = frames[0]
|
||||
header = json.loads(header_bytes.decode("utf-8"))
|
||||
|
||||
if "model_request" in header:
|
||||
return None, header
|
||||
if "model_data" in header:
|
||||
return None, header
|
||||
if len(frames) < 2:
|
||||
raise ValueError(f"Tensor request expected 2 frames, got {len(frames)}")
|
||||
|
||||
tensor_bytes = frames[1]
|
||||
shape = tuple(header.get("shape", []))
|
||||
dtype_str = header.get("dtype", "uint8")
|
||||
|
||||
dtype = np.dtype(dtype_str)
|
||||
tensor = np.frombuffer(tensor_bytes, dtype=dtype).reshape(shape)
|
||||
return tensor, header
|
||||
|
||||
def _build_response(self, result: np.ndarray) -> List[bytes]:
|
||||
header = {
|
||||
"shape": list(result.shape),
|
||||
"dtype": str(result.dtype.name),
|
||||
"timestamp": time.time(),
|
||||
}
|
||||
return [json.dumps(header).encode("utf-8"), result.tobytes(order="C")]
|
||||
|
||||
def _build_error_response(self, error_msg: str) -> List[bytes]:
|
||||
error_header = {"shape": [20, 6], "dtype": "float32", "error": error_msg}
|
||||
return [
|
||||
json.dumps(error_header).encode("utf-8"),
|
||||
self.zero_result.tobytes(order="C"),
|
||||
]
|
||||
|
||||
# --- Model/session helpers ---
|
||||
def _check_model_exists(self, model_name: str) -> bool:
|
||||
return os.path.exists(os.path.join(self.models_dir, model_name))
|
||||
|
||||
def _save_model(self, model_name: str, model_data: bytes) -> bool:
|
||||
try:
|
||||
os.makedirs(self.models_dir, exist_ok=True)
|
||||
with open(os.path.join(self.models_dir, model_name), "wb") as f:
|
||||
f.write(model_data)
|
||||
return True
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Worker {self.worker_id} failed to save model {model_name}: {e}"
|
||||
)
|
||||
return False
|
||||
|
||||
def _get_or_create_session(self, model_name: str) -> Optional[ort.InferenceSession]:
|
||||
with self.model_lock:
|
||||
session = self.model_sessions.get(model_name)
|
||||
if session is not None:
|
||||
return session
|
||||
try:
|
||||
providers = self.providers or ["CoreMLExecutionProvider"]
|
||||
session = ort.InferenceSession(
|
||||
os.path.join(self.models_dir, model_name), providers=providers
|
||||
)
|
||||
self.model_sessions[model_name] = session
|
||||
return session
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Worker {self.worker_id} failed to load model {model_name}: {e}"
|
||||
)
|
||||
return None
|
||||
|
||||
# --- Inference helpers ---
|
||||
def _extract_input_hw(self, header: dict) -> Tuple[int, int]:
|
||||
try:
|
||||
if "width" in header and "height" in header:
|
||||
return int(header["width"]), int(header["height"])
|
||||
shape = tuple(header.get("shape", []))
|
||||
layout = header.get("layout") or header.get("order")
|
||||
if layout and shape:
|
||||
layout = str(layout).upper()
|
||||
if len(shape) == 4:
|
||||
if layout == "NCHW":
|
||||
return int(shape[3]), int(shape[2])
|
||||
if layout == "NHWC":
|
||||
return int(shape[2]), int(shape[1])
|
||||
if len(shape) == 3:
|
||||
if layout == "CHW":
|
||||
return int(shape[2]), int(shape[1])
|
||||
if layout == "HWC":
|
||||
return int(shape[1]), int(shape[0])
|
||||
if shape:
|
||||
if len(shape) == 4:
|
||||
_, d1, d2, d3 = shape
|
||||
if d1 in (1, 3):
|
||||
return int(d3), int(d2)
|
||||
if d3 in (1, 3):
|
||||
return int(d2), int(d1)
|
||||
return int(d2), int(d1)
|
||||
if len(shape) == 3:
|
||||
d0, d1, d2 = shape
|
||||
if d0 in (1, 3):
|
||||
return int(d2), int(d1)
|
||||
if d2 in (1, 3):
|
||||
return int(d1), int(d0)
|
||||
return int(d1), int(d0)
|
||||
if len(shape) == 2:
|
||||
h, w = shape
|
||||
return int(w), int(h)
|
||||
except Exception:
|
||||
pass
|
||||
return 320, 320
|
||||
|
||||
def _run_inference(
|
||||
self, session: ort.InferenceSession, tensor: np.ndarray, header: dict
|
||||
) -> np.ndarray:
|
||||
try:
|
||||
model_type = header.get("model_type")
|
||||
width, height = self._extract_input_hw(header)
|
||||
|
||||
if model_type == "dfine":
|
||||
input_data = {
|
||||
"images": tensor.astype(np.float32),
|
||||
"orig_target_sizes": np.array([[height, width]], dtype=np.int64),
|
||||
}
|
||||
else:
|
||||
input_name = session.get_inputs()[0].name
|
||||
input_data = {input_name: tensor}
|
||||
|
||||
if logger.isEnabledFor(logging.DEBUG):
|
||||
t_start = time.perf_counter()
|
||||
|
||||
outputs = session.run(None, input_data)
|
||||
|
||||
if logger.isEnabledFor(logging.DEBUG):
|
||||
t_after_onnx = time.perf_counter()
|
||||
|
||||
if model_type == "yolo-generic" or model_type == "yologeneric":
|
||||
result = post_process_yolo(outputs, width, height)
|
||||
elif model_type == "dfine":
|
||||
result = post_process_dfine(outputs, width, height)
|
||||
elif model_type == "rfdetr":
|
||||
result = post_process_rfdetr(outputs)
|
||||
else:
|
||||
result = np.zeros((20, 6), dtype=np.float32)
|
||||
|
||||
if logger.isEnabledFor(logging.DEBUG):
|
||||
t_after_post = time.perf_counter()
|
||||
onnx_ms = (t_after_onnx - t_start) * 1000.0
|
||||
post_ms = (t_after_post - t_after_onnx) * 1000.0
|
||||
total_ms = (t_after_post - t_start) * 1000.0
|
||||
logger.debug(
|
||||
f"Worker {self.worker_id} timing: onnx={onnx_ms:.2f}ms, post={post_ms:.2f}ms, total={total_ms:.2f}ms"
|
||||
)
|
||||
|
||||
return result.astype(np.float32)
|
||||
except Exception as e:
|
||||
logger.error(f"Worker {self.worker_id} ONNX inference failed: {e}")
|
||||
return self.zero_result
|
||||
|
||||
# --- Message handlers ---
|
||||
def _handle_model_request(self, header: dict) -> List[bytes]:
|
||||
model_name = header.get("model_name")
|
||||
if not model_name:
|
||||
return self._build_error_response("Model request missing model_name")
|
||||
if self._check_model_exists(model_name):
|
||||
# Ensure session exists
|
||||
if self._get_or_create_session(model_name) is not None:
|
||||
response_header = {
|
||||
"model_available": True,
|
||||
"model_loaded": True,
|
||||
"model_name": model_name,
|
||||
"message": f"Model {model_name} loaded successfully",
|
||||
}
|
||||
else:
|
||||
response_header = {
|
||||
"model_available": True,
|
||||
"model_loaded": False,
|
||||
"model_name": model_name,
|
||||
"message": f"Model {model_name} exists but failed to load",
|
||||
}
|
||||
else:
|
||||
response_header = {
|
||||
"model_available": False,
|
||||
"model_name": model_name,
|
||||
"message": f"Model {model_name} not found, please send model data",
|
||||
}
|
||||
return [json.dumps(response_header).encode("utf-8")]
|
||||
|
||||
def _handle_model_data(self, header: dict, model_data: bytes) -> List[bytes]:
|
||||
model_name = header.get("model_name")
|
||||
if not model_name:
|
||||
return self._build_error_response("Model data missing model_name")
|
||||
if self._save_model(model_name, model_data):
|
||||
# Ensure session is created
|
||||
if self._get_or_create_session(model_name) is not None:
|
||||
response_header = {
|
||||
"model_saved": True,
|
||||
"model_loaded": True,
|
||||
"model_name": model_name,
|
||||
"message": f"Model {model_name} saved and loaded successfully",
|
||||
}
|
||||
else:
|
||||
response_header = {
|
||||
"model_saved": True,
|
||||
"model_loaded": False,
|
||||
"model_name": model_name,
|
||||
"message": f"Model {model_name} saved but failed to load",
|
||||
}
|
||||
else:
|
||||
response_header = {
|
||||
"model_saved": False,
|
||||
"model_loaded": False,
|
||||
"model_name": model_name,
|
||||
"message": f"Failed to save model {model_name}",
|
||||
}
|
||||
return [json.dumps(response_header).encode("utf-8")]
|
||||
|
||||
# --- Thread run ---
|
||||
def run(self) -> None: # pragma: no cover - runtime loop
|
||||
try:
|
||||
self.socket = self._create_socket()
|
||||
logger.info(
|
||||
f"Worker {self.worker_id} connected REP to endpoint: {self.endpoint}"
|
||||
)
|
||||
while True:
|
||||
try:
|
||||
frames = self.socket.recv_multipart()
|
||||
tensor, header = self._decode_request(frames)
|
||||
|
||||
if "model_request" in header:
|
||||
response = self._handle_model_request(header)
|
||||
self.socket.send_multipart(response)
|
||||
continue
|
||||
if "model_data" in header and len(frames) >= 2:
|
||||
model_data = frames[1]
|
||||
response = self._handle_model_data(header, model_data)
|
||||
self.socket.send_multipart(response)
|
||||
continue
|
||||
if tensor is not None:
|
||||
model_name = header.get("model_name")
|
||||
session = None
|
||||
if model_name:
|
||||
session = self._get_or_create_session(model_name)
|
||||
if session is None:
|
||||
result = self.zero_result
|
||||
else:
|
||||
result = self._run_inference(session, tensor, header)
|
||||
self.socket.send_multipart(self._build_response(result))
|
||||
continue
|
||||
|
||||
# Unknown message: reply with zeros
|
||||
self.socket.send_multipart(self._build_response(self.zero_result))
|
||||
except zmq.Again:
|
||||
continue
|
||||
except zmq.ZMQError as e:
|
||||
logger.error(f"Worker {self.worker_id} ZMQ error: {e}")
|
||||
# Recreate socket on transient errors
|
||||
try:
|
||||
if self.socket:
|
||||
self.socket.close(linger=0)
|
||||
finally:
|
||||
self.socket = self._create_socket()
|
||||
except Exception as e:
|
||||
logger.error(f"Worker {self.worker_id} unexpected error: {e}")
|
||||
try:
|
||||
self.socket.send_multipart(self._build_error_response(str(e)))
|
||||
except Exception:
|
||||
pass
|
||||
finally:
|
||||
try:
|
||||
if self.socket:
|
||||
self.socket.close(linger=0)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
|
||||
class ZmqOnnxClient:
|
||||
"""
|
||||
ZMQ TCP client that runs ONNX inference on received tensors.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
endpoint: str = "tcp://*:5555",
|
||||
model_path: Optional[str] = "AUTO",
|
||||
providers: Optional[List[str]] = None,
|
||||
session_options: Optional[ort.SessionOptions] = None,
|
||||
):
|
||||
"""
|
||||
Initialize the ZMQ ONNX client.
|
||||
|
||||
Args:
|
||||
endpoint: ZMQ IPC endpoint to bind to
|
||||
model_path: Path to ONNX model file or "AUTO" for automatic model management
|
||||
providers: ONNX Runtime execution providers
|
||||
session_options: ONNX Runtime session options
|
||||
"""
|
||||
self.endpoint = endpoint
|
||||
self.model_path = model_path
|
||||
self.current_model = None
|
||||
self.model_ready = False
|
||||
self.models_dir = os.path.join(
|
||||
os.path.dirname(os.path.dirname(__file__)), "models"
|
||||
)
|
||||
|
||||
# Shared ZMQ context and shared session map across workers
|
||||
self.context = zmq.Context()
|
||||
self.model_sessions: Dict[str, ort.InferenceSession] = {}
|
||||
self.model_lock = threading.Lock()
|
||||
self.providers = providers
|
||||
|
||||
# Preallocate zero result for error cases
|
||||
self.zero_result = np.zeros((20, 6), dtype=np.float32)
|
||||
|
||||
logger.info(f"ZMQ ONNX client will start workers on endpoint: {endpoint}")
|
||||
|
||||
def start_server(self, num_workers: int = 4) -> None:
|
||||
workers: list[ZmqOnnxWorker] = []
|
||||
for i in range(num_workers):
|
||||
w = ZmqOnnxWorker(
|
||||
worker_id=i,
|
||||
context=self.context,
|
||||
endpoint=self.endpoint,
|
||||
models_dir=self.models_dir,
|
||||
model_sessions=self.model_sessions,
|
||||
model_lock=self.model_lock,
|
||||
providers=self.providers,
|
||||
zero_result=self.zero_result,
|
||||
)
|
||||
w.start()
|
||||
workers.append(w)
|
||||
logger.info(f"Started {num_workers} ZMQ REP workers on backend {self.endpoint}")
|
||||
try:
|
||||
for w in workers:
|
||||
w.join()
|
||||
except KeyboardInterrupt:
|
||||
logger.info("Shutting down workers...")
|
||||
|
||||
def _check_model_exists(self, model_name: str) -> bool:
|
||||
"""
|
||||
Check if a model exists in the models directory.
|
||||
|
||||
Args:
|
||||
model_name: Name of the model file to check
|
||||
|
||||
Returns:
|
||||
True if model exists, False otherwise
|
||||
"""
|
||||
model_path = os.path.join(self.models_dir, model_name)
|
||||
return os.path.exists(model_path)
|
||||
|
||||
# These methods remain for compatibility but are unused in worker mode
|
||||
|
||||
def _save_model(self, model_name: str, model_data: bytes) -> bool:
|
||||
"""
|
||||
Save model data to the models directory.
|
||||
|
||||
Args:
|
||||
model_name: Name of the model file to save
|
||||
model_data: Binary model data
|
||||
|
||||
Returns:
|
||||
True if model saved successfully, False otherwise
|
||||
"""
|
||||
try:
|
||||
# Ensure models directory exists
|
||||
os.makedirs(self.models_dir, exist_ok=True)
|
||||
|
||||
model_path = os.path.join(self.models_dir, model_name)
|
||||
logger.info(f"Saving model to: {model_path}")
|
||||
|
||||
with open(model_path, "wb") as f:
|
||||
f.write(model_data)
|
||||
|
||||
logger.info(f"Model saved successfully: {model_name}")
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to save model {model_name}: {e}")
|
||||
return False
|
||||
|
||||
def _decode_request(self, frames: List[bytes]) -> Tuple[np.ndarray, dict]:
|
||||
"""
|
||||
Decode the incoming request frames.
|
||||
|
||||
Args:
|
||||
frames: List of message frames
|
||||
|
||||
Returns:
|
||||
Tuple of (tensor, header_dict)
|
||||
"""
|
||||
try:
|
||||
if len(frames) < 1:
|
||||
raise ValueError(f"Expected at least 1 frame, got {len(frames)}")
|
||||
|
||||
# Parse header
|
||||
header_bytes = frames[0]
|
||||
header = json.loads(header_bytes.decode("utf-8"))
|
||||
|
||||
if "model_request" in header:
|
||||
return None, header
|
||||
|
||||
if "model_data" in header:
|
||||
if len(frames) < 2:
|
||||
raise ValueError(
|
||||
f"Model data request expected 2 frames, got {len(frames)}"
|
||||
)
|
||||
return None, header
|
||||
|
||||
if len(frames) < 2:
|
||||
raise ValueError(f"Tensor request expected 2 frames, got {len(frames)}")
|
||||
|
||||
tensor_bytes = frames[1]
|
||||
shape = tuple(header.get("shape", []))
|
||||
dtype_str = header.get("dtype", "uint8")
|
||||
|
||||
dtype = np.dtype(dtype_str)
|
||||
tensor = np.frombuffer(tensor_bytes, dtype=dtype).reshape(shape)
|
||||
return tensor, header
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to decode request: {e}")
|
||||
raise
|
||||
|
||||
def _run_inference(self, tensor: np.ndarray, header: dict) -> np.ndarray:
|
||||
"""
|
||||
Run ONNX inference on the input tensor.
|
||||
|
||||
Args:
|
||||
tensor: Input tensor
|
||||
header: Request header containing metadata (e.g., shape, layout)
|
||||
|
||||
Returns:
|
||||
Detection results as numpy array
|
||||
|
||||
Raises:
|
||||
RuntimeError: If no ONNX session is available or inference fails
|
||||
"""
|
||||
if self.session is None:
|
||||
logger.warning("No ONNX session available, returning zero results")
|
||||
return self.zero_result
|
||||
|
||||
try:
|
||||
# Prepare input for ONNX Runtime
|
||||
# Determine input spatial size (W, H) from header/shape/layout
|
||||
model_type = header.get("model_type")
|
||||
width, height = self._extract_input_hw(header)
|
||||
|
||||
if model_type == "dfine":
|
||||
# DFine model requires both images and orig_target_sizes inputs
|
||||
input_data = {
|
||||
"images": tensor.astype(np.float32),
|
||||
"orig_target_sizes": np.array([[height, width]], dtype=np.int64),
|
||||
}
|
||||
else:
|
||||
# Other models use single input
|
||||
input_name = self.session.get_inputs()[0].name
|
||||
input_data = {input_name: tensor}
|
||||
|
||||
# Run inference
|
||||
if logger.isEnabledFor(logging.DEBUG):
|
||||
t_start = time.perf_counter()
|
||||
|
||||
outputs = self.session.run(None, input_data)
|
||||
|
||||
if logger.isEnabledFor(logging.DEBUG):
|
||||
t_after_onnx = time.perf_counter()
|
||||
|
||||
if model_type == "yolo-generic" or model_type == "yologeneric":
|
||||
result = post_process_yolo(outputs, width, height)
|
||||
elif model_type == "dfine":
|
||||
result = post_process_dfine(outputs, width, height)
|
||||
elif model_type == "rfdetr":
|
||||
result = post_process_rfdetr(outputs)
|
||||
|
||||
if logger.isEnabledFor(logging.DEBUG):
|
||||
t_after_post = time.perf_counter()
|
||||
onnx_ms = (t_after_onnx - t_start) * 1000.0
|
||||
post_ms = (t_after_post - t_after_onnx) * 1000.0
|
||||
total_ms = (t_after_post - t_start) * 1000.0
|
||||
logger.debug(
|
||||
f"Inference timing: onnx={onnx_ms:.2f}ms, post={post_ms:.2f}ms, total={total_ms:.2f}ms"
|
||||
)
|
||||
|
||||
# Ensure float32 dtype
|
||||
result = result.astype(np.float32)
|
||||
|
||||
return result
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"ONNX inference failed: {e}")
|
||||
return self.zero_result
|
||||
|
||||
def _extract_input_hw(self, header: dict) -> Tuple[int, int]:
|
||||
"""
|
||||
Extract (width, height) from the header and/or tensor shape, supporting
|
||||
NHWC/NCHW as well as 3D/4D inputs. Falls back to 320x320 if unknown.
|
||||
|
||||
Preference order:
|
||||
1) Explicit header keys: width/height
|
||||
2) Use provided layout to interpret shape
|
||||
3) Heuristics on shape
|
||||
"""
|
||||
try:
|
||||
if "width" in header and "height" in header:
|
||||
return int(header["width"]), int(header["height"])
|
||||
|
||||
shape = tuple(header.get("shape", []))
|
||||
layout = header.get("layout") or header.get("order")
|
||||
|
||||
if layout and shape:
|
||||
layout = str(layout).upper()
|
||||
if len(shape) == 4:
|
||||
if layout == "NCHW":
|
||||
return int(shape[3]), int(shape[2])
|
||||
if layout == "NHWC":
|
||||
return int(shape[2]), int(shape[1])
|
||||
if len(shape) == 3:
|
||||
if layout == "CHW":
|
||||
return int(shape[2]), int(shape[1])
|
||||
if layout == "HWC":
|
||||
return int(shape[1]), int(shape[0])
|
||||
|
||||
if shape:
|
||||
if len(shape) == 4:
|
||||
_, d1, d2, d3 = shape
|
||||
if d1 in (1, 3):
|
||||
return int(d3), int(d2)
|
||||
if d3 in (1, 3):
|
||||
return int(d2), int(d1)
|
||||
return int(d2), int(d1)
|
||||
if len(shape) == 3:
|
||||
d0, d1, d2 = shape
|
||||
if d0 in (1, 3):
|
||||
return int(d2), int(d1)
|
||||
if d2 in (1, 3):
|
||||
return int(d1), int(d0)
|
||||
return int(d1), int(d0)
|
||||
if len(shape) == 2:
|
||||
h, w = shape
|
||||
return int(w), int(h)
|
||||
except Exception as e:
|
||||
logger.debug(f"Failed to extract input size from header: {e}")
|
||||
|
||||
logger.debug("Falling back to default input size (320x320)")
|
||||
return 320, 320
|
||||
|
||||
def _build_response(self, result: np.ndarray) -> List[bytes]:
|
||||
"""
|
||||
Build the response message.
|
||||
|
||||
Args:
|
||||
result: Detection results
|
||||
|
||||
Returns:
|
||||
List of response frames
|
||||
"""
|
||||
try:
|
||||
# Build header
|
||||
header = {
|
||||
"shape": list(result.shape),
|
||||
"dtype": str(result.dtype.name),
|
||||
"timestamp": time.time(),
|
||||
}
|
||||
header_bytes = json.dumps(header).encode("utf-8")
|
||||
|
||||
# Convert result to bytes
|
||||
result_bytes = result.tobytes(order="C")
|
||||
|
||||
return [header_bytes, result_bytes]
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to build response: {e}")
|
||||
# Return zero result as fallback
|
||||
header = {
|
||||
"shape": [20, 6],
|
||||
"dtype": "float32",
|
||||
"error": "Failed to build response",
|
||||
}
|
||||
header_bytes = json.dumps(header).encode("utf-8")
|
||||
result_bytes = self.zero_result.tobytes(order="C")
|
||||
return [header_bytes, result_bytes]
|
||||
|
||||
def _handle_model_request(self, header: dict) -> List[bytes]:
|
||||
"""
|
||||
Handle model availability request.
|
||||
|
||||
Args:
|
||||
header: Request header containing model information
|
||||
|
||||
Returns:
|
||||
Response message indicating model availability
|
||||
"""
|
||||
model_name = header.get("model_name")
|
||||
|
||||
if not model_name:
|
||||
logger.error("Model request missing model_name")
|
||||
return self._build_error_response("Model request missing model_name")
|
||||
|
||||
logger.info(f"Model availability request for: {model_name}")
|
||||
|
||||
if self._check_model_exists(model_name):
|
||||
logger.info(f"Model {model_name} exists locally")
|
||||
# Try to load the model
|
||||
if self._load_model(model_name):
|
||||
response_header = {
|
||||
"model_available": True,
|
||||
"model_loaded": True,
|
||||
"model_name": model_name,
|
||||
"message": f"Model {model_name} loaded successfully",
|
||||
}
|
||||
else:
|
||||
response_header = {
|
||||
"model_available": True,
|
||||
"model_loaded": False,
|
||||
"model_name": model_name,
|
||||
"message": f"Model {model_name} exists but failed to load",
|
||||
}
|
||||
else:
|
||||
logger.info(f"Model {model_name} not found, requesting transfer")
|
||||
response_header = {
|
||||
"model_available": False,
|
||||
"model_name": model_name,
|
||||
"message": f"Model {model_name} not found, please send model data",
|
||||
}
|
||||
|
||||
return [json.dumps(response_header).encode("utf-8")]
|
||||
|
||||
def _handle_model_data(self, header: dict, model_data: bytes) -> List[bytes]:
|
||||
"""
|
||||
Handle model data transfer.
|
||||
|
||||
Args:
|
||||
header: Request header containing model information
|
||||
model_data: Binary model data
|
||||
|
||||
Returns:
|
||||
Response message indicating save success/failure
|
||||
"""
|
||||
model_name = header.get("model_name")
|
||||
|
||||
if not model_name:
|
||||
logger.error("Model data missing model_name")
|
||||
return self._build_error_response("Model data missing model_name")
|
||||
|
||||
logger.info(f"Received model data for: {model_name}")
|
||||
|
||||
if self._save_model(model_name, model_data):
|
||||
# Try to load the model
|
||||
if self._load_model(model_name):
|
||||
response_header = {
|
||||
"model_saved": True,
|
||||
"model_loaded": True,
|
||||
"model_name": model_name,
|
||||
"message": f"Model {model_name} saved and loaded successfully",
|
||||
}
|
||||
else:
|
||||
response_header = {
|
||||
"model_saved": True,
|
||||
"model_loaded": False,
|
||||
"model_name": model_name,
|
||||
"message": f"Model {model_name} saved but failed to load",
|
||||
}
|
||||
else:
|
||||
response_header = {
|
||||
"model_saved": False,
|
||||
"model_loaded": False,
|
||||
"model_name": model_name,
|
||||
"message": f"Failed to save model {model_name}",
|
||||
}
|
||||
|
||||
return [json.dumps(response_header).encode("utf-8")]
|
||||
|
||||
def _build_error_response(self, error_msg: str) -> List[bytes]:
|
||||
"""Build an error response message."""
|
||||
error_header = {"error": error_msg}
|
||||
return [json.dumps(error_header).encode("utf-8")]
|
||||
|
||||
# Removed legacy single-thread start_server implementation in favor of worker pool
|
||||
|
||||
def _send_error_response(self, error_msg: str):
|
||||
"""Send an error response to the client."""
|
||||
try:
|
||||
error_header = {"shape": [20, 6], "dtype": "float32", "error": error_msg}
|
||||
error_response = [
|
||||
json.dumps(error_header).encode("utf-8"),
|
||||
self.zero_result.tobytes(order="C"),
|
||||
]
|
||||
self.socket.send_multipart(error_response)
|
||||
except Exception as send_error:
|
||||
logger.error(f"Failed to send error response: {send_error}")
|
||||
|
||||
def cleanup(self):
|
||||
"""Clean up resources."""
|
||||
try:
|
||||
if self.socket:
|
||||
self.socket.close()
|
||||
self.socket = None
|
||||
if self.context:
|
||||
self.context.term()
|
||||
self.context = None
|
||||
logger.info("Cleanup completed")
|
||||
except Exception as e:
|
||||
logger.error(f"Cleanup error: {e}")
|
||||
|
||||
|
||||
def main():
    """Main function to run the ZMQ ONNX client."""
    import argparse

    parser = argparse.ArgumentParser(description="ZMQ TCP ONNX Runtime Client")
    parser.add_argument(
        "--endpoint",
        default="tcp://*:5555",
        help="ZMQ TCP endpoint (default: tcp://*:5555)",
    )
    parser.add_argument(
        "--model",
        default="AUTO",
        help="Path to ONNX model file or AUTO for automatic model management",
    )
    parser.add_argument(
        "--providers",
        nargs="+",
        default=["CoreMLExecutionProvider"],
        help="ONNX Runtime execution providers",
    )
    parser.add_argument(
        "--workers",
        type=int,
        default=4,
        help="Number of REP worker threads",
    )
    parser.add_argument(
        "--verbose", "-v", action="store_true", help="Enable verbose logging"
    )

    args = parser.parse_args()

    if args.verbose:
        logging.getLogger().setLevel(logging.DEBUG)

    # Create and start client
    client = ZmqOnnxClient(
        endpoint=args.endpoint, model_path=args.model, providers=args.providers
    )

    try:
        client.start_server(num_workers=args.workers)
    except KeyboardInterrupt:
        logger.info("Interrupted by user")
    finally:
        client.cleanup()


if __name__ == "__main__":
    main()
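For reference, a minimal sketch of starting the same worker pool programmatically rather than via the CLI above. The endpoint, provider list, and worker count are illustrative, and the import assumes the script is run from its own directory (the same assumption the file makes for `model_util`).

```python
# Programmatic equivalent of: python3 zmq_client.py --endpoint tcp://*:5555 --workers 4
from zmq_client import ZmqOnnxClient

client = ZmqOnnxClient(
    endpoint="tcp://*:5555",             # point this at the endpoint the Frigate-side broker binds
    model_path="AUTO",                   # let Frigate transfer the model over the wire on first use
    providers=["CPUExecutionProvider"],  # e.g. CoreMLExecutionProvider on Apple Silicon
)
try:
    client.start_server(num_workers=4)   # blocks; starts this many REP worker threads
finally:
    client.cleanup()
```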
@@ -262,7 +262,7 @@ class LicensePlateDetector(BaseEmbedding):
        self.runner = get_optimized_runner(
            os.path.join(self.download_path, self.model_file),
            self.device,
            model_type="yolov9",
            model_type=EnrichmentModelTypeEnum.yolov9_license_plate.value,
        )

    def _preprocess_inputs(self, raw_inputs):
@@ -12,3 +12,4 @@ class EnrichmentModelTypeEnum(str, Enum):
    jina_v1 = "jina_v1"
    jina_v2 = "jina_v2"
    paddleocr = "paddleocr"
    yolov9_license_plate = "yolov9_license_plate"
@@ -26,6 +26,15 @@ import {
import { useTranslation } from "react-i18next";
import { FrigateConfig } from "@/types/frigateConfig";
import { CameraNameLabel } from "../camera/CameraNameLabel";
import { isDesktop, isMobile } from "react-device-detect";
import { cn } from "@/lib/utils";
import {
  MobilePage,
  MobilePageContent,
  MobilePageDescription,
  MobilePageHeader,
  MobilePageTitle,
} from "../mobile/MobilePage";

type CreateRoleOverlayProps = {
  show: boolean;

@@ -100,15 +109,27 @@ export default function CreateRoleDialog({
    onCancel();
  };

  const Overlay = isDesktop ? Dialog : MobilePage;
  const Content = isDesktop ? DialogContent : MobilePageContent;
  const Header = isDesktop ? DialogHeader : MobilePageHeader;
  const Description = isDesktop ? DialogDescription : MobilePageDescription;
  const Title = isDesktop ? DialogTitle : MobilePageTitle;

  return (
    <Dialog open={show} onOpenChange={onCancel}>
      <DialogContent className="sm:max-w-[425px]">
        <DialogHeader>
          <DialogTitle>{t("roles.dialog.createRole.title")}</DialogTitle>
          <DialogDescription>
    <Overlay open={show} onOpenChange={onCancel}>
      <Content
        className={cn(
          "scrollbar-container overflow-y-auto",
          isDesktop && "my-4 flex max-h-dvh flex-col sm:max-w-[425px]",
          isMobile && "px-4",
        )}
      >
        <Header className="mt-2" onClose={onCancel}>
          <Title>{t("roles.dialog.createRole.title")}</Title>
          <Description className={cn(!isDesktop && "sr-only")}>
            {t("roles.dialog.createRole.desc")}
          </DialogDescription>
        </DialogHeader>
          </Description>
        </Header>

        <Form {...form}>
          <form

@@ -222,7 +243,7 @@ export default function CreateRoleDialog({
            </DialogFooter>
          </form>
        </Form>
      </DialogContent>
    </Dialog>
      </Content>
    </Overlay>
  );
}
@@ -38,6 +38,15 @@ import { Trigger, TriggerAction, TriggerType } from "@/types/trigger";
import { Switch } from "@/components/ui/switch";
import { Textarea } from "../ui/textarea";
import { useCameraFriendlyName } from "@/hooks/use-camera-friendly-name";
import { isDesktop, isMobile } from "react-device-detect";
import { cn } from "@/lib/utils";
import {
  MobilePage,
  MobilePageContent,
  MobilePageDescription,
  MobilePageHeader,
  MobilePageTitle,
} from "../mobile/MobilePage";

type CreateTriggerDialogProps = {
  show: boolean;

@@ -164,18 +173,30 @@ export default function CreateTriggerDialog({

  const cameraName = useCameraFriendlyName(selectedCamera);

  const Overlay = isDesktop ? Dialog : MobilePage;
  const Content = isDesktop ? DialogContent : MobilePageContent;
  const Header = isDesktop ? DialogHeader : MobilePageHeader;
  const Description = isDesktop ? DialogDescription : MobilePageDescription;
  const Title = isDesktop ? DialogTitle : MobilePageTitle;

  return (
    <Dialog open={show} onOpenChange={onCancel}>
      <DialogContent className="sm:max-w-[425px]">
        <DialogHeader>
          <DialogTitle>
    <Overlay open={show} onOpenChange={onCancel}>
      <Content
        className={cn(
          "scrollbar-container overflow-y-auto",
          isDesktop && "my-4 flex max-h-dvh flex-col",
          isMobile && "px-4",
        )}
      >
        <Header className="mt-2" onClose={onCancel}>
          <Title>
            {t(
              trigger
                ? "triggers.dialog.editTrigger.title"
                : "triggers.dialog.createTrigger.title",
            )}
          </DialogTitle>
          <DialogDescription>
          </Title>
          <Description className={cn(!isDesktop && "sr-only")}>
            {t(
              trigger
                ? "triggers.dialog.editTrigger.desc"

@@ -184,8 +205,8 @@ export default function CreateTriggerDialog({
                camera: cameraName,
              },
            )}
          </DialogDescription>
        </DialogHeader>
          </Description>
        </Header>

        <Form {...form}>
          <form

@@ -415,7 +436,7 @@ export default function CreateTriggerDialog({
            </DialogFooter>
          </form>
        </Form>
      </DialogContent>
    </Dialog>
      </Content>
    </Overlay>
  );
}
@@ -33,6 +33,15 @@ import {
import { Shield, User } from "lucide-react";
import { LuCheck, LuX } from "react-icons/lu";
import { useTranslation } from "react-i18next";
import { isDesktop, isMobile } from "react-device-detect";
import { cn } from "@/lib/utils";
import {
  MobilePage,
  MobilePageContent,
  MobilePageDescription,
  MobilePageHeader,
  MobilePageTitle,
} from "../mobile/MobilePage";

type CreateUserOverlayProps = {
  show: boolean;

@@ -110,15 +119,27 @@ export default function CreateUserDialog({
    onCancel();
  };

  const Overlay = isDesktop ? Dialog : MobilePage;
  const Content = isDesktop ? DialogContent : MobilePageContent;
  const Header = isDesktop ? DialogHeader : MobilePageHeader;
  const Description = isDesktop ? DialogDescription : MobilePageDescription;
  const Title = isDesktop ? DialogTitle : MobilePageTitle;

  return (
    <Dialog open={show} onOpenChange={onCancel}>
      <DialogContent className="sm:max-w-[425px]">
        <DialogHeader>
          <DialogTitle>{t("users.dialog.createUser.title")}</DialogTitle>
          <DialogDescription>
    <Overlay open={show} onOpenChange={onCancel}>
      <Content
        className={cn(
          "scrollbar-container overflow-y-auto",
          isDesktop && "my-4 flex max-h-dvh flex-col sm:max-w-[425px]",
          isMobile && "px-4",
        )}
      >
        <Header className="mt-2" onClose={onCancel}>
          <Title>{t("users.dialog.createUser.title")}</Title>
          <Description className={cn(!isDesktop && "sr-only")}>
            {t("users.dialog.createUser.desc")}
          </DialogDescription>
        </DialogHeader>
          </Description>
        </Header>

        <Form {...form}>
          <form

@@ -286,7 +307,7 @@ export default function CreateUserDialog({
            </DialogFooter>
          </form>
        </Form>
      </DialogContent>
    </Dialog>
      </Content>
    </Overlay>
  );
}