mirror of
https://github.com/azaion/annotations.git
synced 2026-04-22 22:36:31 +00:00
add ramdisk, load AI model to ramdisk and start recognition from it
rewrite zmq to DEALER and ROUTER add GET_USER command to get CurrentUser from Python all auth is on the python side inference run and validate annotations on python
This commit is contained in:
+1
-1
@@ -50,7 +50,7 @@ This is crucial for the build because build needs Python.h header and other file
|
||||
pip install ultralytics
|
||||
|
||||
pip uninstall -y opencv-python
|
||||
pip install opencv-python cython msgpack cryptography rstream pika zmq
|
||||
pip install opencv-python cython msgpack cryptography rstream pika zmq pyjwt
|
||||
```
|
||||
In case of fbgemm.dll error (Windows specific):
|
||||
|
||||
|
||||
@@ -0,0 +1,10 @@
|
||||
cdef class AIRecognitionConfig:
|
||||
cdef public double frame_recognition_seconds
|
||||
cdef public double tracking_distance_confidence
|
||||
cdef public double tracking_probability_increase
|
||||
cdef public double tracking_intersection_threshold
|
||||
cdef public int frame_period_recognition
|
||||
cdef public bytes file_data
|
||||
|
||||
@staticmethod
|
||||
cdef from_msgpack(bytes data)
|
||||
@@ -0,0 +1,32 @@
|
||||
from msgpack import unpackb
|
||||
|
||||
cdef class AIRecognitionConfig:
|
||||
def __init__(self,
|
||||
frame_recognition_seconds,
|
||||
tracking_distance_confidence,
|
||||
tracking_probability_increase,
|
||||
tracking_intersection_threshold,
|
||||
frame_period_recognition,
|
||||
file_data
|
||||
):
|
||||
self.frame_recognition_seconds = frame_recognition_seconds
|
||||
self.tracking_distance_confidence = tracking_distance_confidence
|
||||
self.tracking_probability_increase = tracking_probability_increase
|
||||
self.tracking_intersection_threshold = tracking_intersection_threshold
|
||||
self.frame_period_recognition = frame_period_recognition
|
||||
self.file_data = file_data
|
||||
|
||||
def __str__(self):
|
||||
return (f'frame_seconds : {self.frame_recognition_seconds}, distance_confidence : {self.tracking_distance_confidence}, '
|
||||
f'probability_increase : {self.tracking_probability_increase}, intersection_threshold : {self.tracking_intersection_threshold}, frame_period_recognition : {self.frame_period_recognition}')
|
||||
|
||||
@staticmethod
|
||||
cdef from_msgpack(bytes data):
|
||||
unpacked = unpackb(data, strict_map_key=False)
|
||||
return AIRecognitionConfig(
|
||||
unpacked.get("FrameRecognitionSeconds", 0.0),
|
||||
unpacked.get("TrackingDistanceConfidence", 0.0),
|
||||
unpacked.get("TrackingProbabilityIncrease", 0.0),
|
||||
unpacked.get("TrackingIntersectionThreshold", 0.0),
|
||||
unpacked.get("FramePeriodRecognition", 0),
|
||||
unpacked.get("Data", b''))
|
||||
@@ -1,8 +1,10 @@
|
||||
cdef class Detection:
|
||||
cdef double x, y, w, h
|
||||
cdef int cls
|
||||
cdef public double x, y, w, h, confidence
|
||||
cdef public int cls
|
||||
|
||||
cdef class Annotation:
|
||||
cdef bytes image
|
||||
cdef float time
|
||||
cdef list[Detection] detections
|
||||
cdef long time
|
||||
cdef public list[Detection] detections
|
||||
cdef bytes serialize(self)
|
||||
|
||||
|
||||
@@ -1,13 +1,35 @@
|
||||
import msgpack
|
||||
|
||||
cdef class Detection:
|
||||
def __init__(self, double x, double y, double w, double h, int cls):
|
||||
def __init__(self, double x, double y, double w, double h, int cls, double confidence):
|
||||
self.x = x
|
||||
self.y = y
|
||||
self.w = w
|
||||
self.h = h
|
||||
self.cls = cls
|
||||
self.confidence = confidence
|
||||
|
||||
def __str__(self):
|
||||
return f'{self.cls}: {self.x:.2f} {self.y:.2f} {self.w:.2f} {self.h:.2f}, prob: {(self.confidence*100):.1f}%'
|
||||
|
||||
cdef class Annotation:
|
||||
def __init__(self, bytes image_bytes, float time, list[Detection] detections):
|
||||
def __init__(self, bytes image_bytes, long time, list[Detection] detections):
|
||||
self.image = image_bytes
|
||||
self.time = time
|
||||
self.detections = detections
|
||||
self.detections = detections if detections is not None else []
|
||||
|
||||
cdef bytes serialize(self):
|
||||
return msgpack.packb({
|
||||
"i": self.image, # "i" = image
|
||||
"t": self.time, # "t" = time
|
||||
"d": [ # "d" = detections
|
||||
{
|
||||
"x": det.x,
|
||||
"y": det.y,
|
||||
"w": det.w,
|
||||
"h": det.h,
|
||||
"c": det.cls,
|
||||
"p": det.confidence
|
||||
} for det in self.detections
|
||||
]
|
||||
})
|
||||
|
||||
@@ -1,8 +1,14 @@
|
||||
from user cimport User
|
||||
|
||||
cdef class ApiClient:
|
||||
cdef str email, password, token, folder, token_file, api_url
|
||||
cdef User user
|
||||
|
||||
cdef get_encryption_key(self, str hardware_hash)
|
||||
cdef login(self, str email, str password)
|
||||
cdef login(self)
|
||||
cdef set_token(self, str token)
|
||||
cdef get_user(self)
|
||||
|
||||
cdef load_bytes(self, str filename)
|
||||
cdef load_ai_model(self)
|
||||
cdef load_queue_config(self)
|
||||
|
||||
+47
-11
@@ -1,13 +1,14 @@
|
||||
import io
|
||||
import json
|
||||
import os
|
||||
from http import HTTPStatus
|
||||
|
||||
from uuid import UUID
|
||||
import jwt
|
||||
import requests
|
||||
cimport constants
|
||||
from hardware_service cimport HardwareService, HardwareInfo
|
||||
from security cimport Security
|
||||
from io import BytesIO
|
||||
from user cimport User, RoleEnum
|
||||
|
||||
cdef class ApiClient:
|
||||
"""Handles API authentication and downloading of the AI model."""
|
||||
@@ -15,9 +16,11 @@ cdef class ApiClient:
|
||||
self.email = email
|
||||
self.password = password
|
||||
self.folder = folder
|
||||
self.user = None
|
||||
|
||||
if os.path.exists(<str>constants.TOKEN_FILE):
|
||||
with open(<str>constants.TOKEN_FILE, "r") as file:
|
||||
self.token = file.read().strip()
|
||||
self.set_token(<str>file.read().strip())
|
||||
else:
|
||||
self.token = None
|
||||
|
||||
@@ -25,21 +28,52 @@ cdef class ApiClient:
|
||||
cdef str key = f'{self.email}-{self.password}-{hardware_hash}-#%@AzaionKey@%#---'
|
||||
return Security.calc_hash(key)
|
||||
|
||||
cdef login(self, str email, str password):
|
||||
response = requests.post(f"{constants.API_URL}/login", json={"email": email, "password": password})
|
||||
cdef login(self):
|
||||
response = requests.post(f"{constants.API_URL}/login",
|
||||
json={"email": self.email, "password": self.password})
|
||||
response.raise_for_status()
|
||||
self.token = response.json()["token"]
|
||||
|
||||
token = response.json()["token"]
|
||||
self.set_token(token)
|
||||
with open(<str>constants.TOKEN_FILE, 'w') as file:
|
||||
file.write(self.token)
|
||||
file.write(token)
|
||||
|
||||
cdef set_token(self, str token):
|
||||
self.token = token
|
||||
claims = jwt.decode(token, options={"verify_signature": False})
|
||||
|
||||
try:
|
||||
id = str(UUID(claims.get("nameid", "")))
|
||||
except ValueError:
|
||||
raise ValueError("Invalid GUID format in claims")
|
||||
|
||||
email = claims.get("unique_name", "")
|
||||
|
||||
role_str = claims.get("role", "")
|
||||
if role_str == "ApiAdmin":
|
||||
role = RoleEnum.ApiAdmin
|
||||
elif role_str == "Admin":
|
||||
role = RoleEnum.Admin
|
||||
elif role_str == "ResourceUploader":
|
||||
role = RoleEnum.ResourceUploader
|
||||
elif role_str == "Validator":
|
||||
role = RoleEnum.Validator
|
||||
elif role_str == "Operator":
|
||||
role = RoleEnum.Operator
|
||||
else:
|
||||
role = RoleEnum.NONE
|
||||
self.user = User(id, email, role)
|
||||
|
||||
cdef get_user(self):
|
||||
if self.user is None:
|
||||
self.login()
|
||||
return self.user
|
||||
|
||||
cdef load_bytes(self, str filename):
|
||||
hardware_service = HardwareService()
|
||||
cdef HardwareInfo hardware = hardware_service.get_hardware_info()
|
||||
|
||||
if self.token is None:
|
||||
self.login(self.email, self.password)
|
||||
self.login()
|
||||
|
||||
url = f"{constants.API_URL}/resources/get/{self.folder}"
|
||||
headers = {
|
||||
@@ -56,7 +90,7 @@ cdef class ApiClient:
|
||||
response = requests.post(url, data=payload, headers=headers, stream=True)
|
||||
|
||||
if response.status_code == HTTPStatus.UNAUTHORIZED or response.status_code == HTTPStatus.FORBIDDEN:
|
||||
self.login(self.email, self.password)
|
||||
self.login()
|
||||
headers = {
|
||||
"Authorization": f"Bearer {self.token}",
|
||||
"Content-Type": "application/json"
|
||||
@@ -69,7 +103,9 @@ cdef class ApiClient:
|
||||
key = self.get_encryption_key(hardware.hash)
|
||||
|
||||
stream = BytesIO(response.raw.read())
|
||||
return Security.decrypt_to(stream, key)
|
||||
data = Security.decrypt_to(stream, key)
|
||||
print(f'loaded file: {filename}, {len(data)} bytes')
|
||||
return data
|
||||
|
||||
cdef load_ai_model(self):
|
||||
return self.load_bytes(constants.AI_MODEL_FILE)
|
||||
|
||||
@@ -1,6 +1,4 @@
|
||||
cdef str SOCKET_HOST # Host for the socket server
|
||||
cdef int SOCKET_PORT # Port for the socket server
|
||||
cdef int SOCKET_BUFFER_SIZE # Buffer size for socket communication
|
||||
cdef int ZMQ_PORT = 5127 # Port for the zmq
|
||||
|
||||
cdef int QUEUE_MAXSIZE # Maximum size of the command queue
|
||||
cdef str COMMANDS_QUEUE # Name of the commands queue in rabbit
|
||||
@@ -10,3 +8,5 @@ cdef str API_URL # Base URL for the external API
|
||||
cdef str TOKEN_FILE # Name of the token file where temporary token would be stored
|
||||
cdef str QUEUE_CONFIG_FILENAME # queue config filename to load from api
|
||||
cdef str AI_MODEL_FILE # AI Model file
|
||||
|
||||
cdef bytes DONE_SIGNAL
|
||||
@@ -1,6 +1,4 @@
|
||||
cdef str SOCKET_HOST = "127.0.0.1" # Host for the socket server
|
||||
cdef int SOCKET_PORT = 9127 # Port for the socket server
|
||||
cdef int SOCKET_BUFFER_SIZE = 4096 # Buffer size for socket communication
|
||||
cdef int ZMQ_PORT = 5127 # Port for the zmq
|
||||
|
||||
cdef int QUEUE_MAXSIZE = 1000 # Maximum size of the command queue
|
||||
cdef str COMMANDS_QUEUE = "azaion-commands"
|
||||
@@ -10,3 +8,5 @@ cdef str API_URL = "https://api.azaion.com" # Base URL for the external API
|
||||
cdef str TOKEN_FILE = "token"
|
||||
cdef str QUEUE_CONFIG_FILENAME = "secured-config.json"
|
||||
cdef str AI_MODEL_FILE = "azaion.pt"
|
||||
|
||||
cdef bytes DONE_SIGNAL = b"DONE"
|
||||
@@ -10,5 +10,8 @@ def start_server():
|
||||
except Exception as e:
|
||||
processor.stop()
|
||||
|
||||
def on_annotation(self, cmd, annotation):
|
||||
print('on_annotation hit!')
|
||||
|
||||
if __name__ == "__main__":
|
||||
start_server()
|
||||
@@ -0,0 +1,17 @@
|
||||
from remote_command cimport RemoteCommand
|
||||
from annotation cimport Annotation
|
||||
from ai_config cimport AIRecognitionConfig
|
||||
|
||||
cdef class Inference:
|
||||
cdef object model
|
||||
cdef object on_annotation
|
||||
cdef Annotation _previous_annotation
|
||||
cdef AIRecognitionConfig ai_config
|
||||
|
||||
cdef bint is_video(self, str filepath)
|
||||
cdef run_inference(self, RemoteCommand cmd, int batch_size=?)
|
||||
cdef _process_video(self, RemoteCommand cmd, int batch_size)
|
||||
cdef _process_image(self, RemoteCommand cmd)
|
||||
|
||||
cdef frame_to_annotation(self, long time, frame, boxes: object)
|
||||
cdef bint is_valid_annotation(self, Annotation annotation)
|
||||
+76
-17
@@ -1,30 +1,38 @@
|
||||
import ai_config
|
||||
import msgpack
|
||||
from ultralytics import YOLO
|
||||
import mimetypes
|
||||
import cv2
|
||||
from ultralytics.engine.results import Boxes
|
||||
from remote_command cimport RemoteCommand
|
||||
from annotation cimport Detection, Annotation
|
||||
from secure_model cimport SecureModelLoader
|
||||
from ai_config cimport AIRecognitionConfig
|
||||
|
||||
cdef class Inference:
|
||||
def __init__(self, model_bytes, on_annotations):
|
||||
self.model = YOLO(model_bytes)
|
||||
self.on_annotations = on_annotations
|
||||
def __init__(self, model_bytes, on_annotation):
|
||||
loader = SecureModelLoader()
|
||||
model_path = loader.load_model(model_bytes)
|
||||
self.model = YOLO(<str>model_path)
|
||||
self.on_annotation = on_annotation
|
||||
|
||||
cdef bint is_video(self, str filepath):
|
||||
mime_type, _ = mimetypes.guess_type(<str>filepath)
|
||||
return mime_type and mime_type.startswith("video")
|
||||
|
||||
cdef run_inference(self, RemoteCommand cmd, int batch_size=8, int frame_skip=4):
|
||||
cdef run_inference(self, RemoteCommand cmd, int batch_size=8):
|
||||
print('run inference..')
|
||||
|
||||
if self.is_video(cmd.filename):
|
||||
return self._process_video(cmd, batch_size, frame_skip)
|
||||
return self._process_video(cmd, batch_size)
|
||||
else:
|
||||
return self._process_image(cmd)
|
||||
|
||||
cdef _process_video(self, RemoteCommand cmd, int batch_size, int frame_skip):
|
||||
cdef _process_video(self, RemoteCommand cmd, int batch_size):
|
||||
frame_count = 0
|
||||
batch_frame = []
|
||||
annotations = []
|
||||
v_input = cv2.VideoCapture(<str>cmd.filename)
|
||||
self.ai_config = AIRecognitionConfig.from_msgpack(cmd.data)
|
||||
|
||||
while v_input.isOpened():
|
||||
ret, frame = v_input.read()
|
||||
@@ -33,7 +41,7 @@ cdef class Inference:
|
||||
break
|
||||
|
||||
frame_count += 1
|
||||
if frame_count % frame_skip == 0:
|
||||
if frame_count % self.ai_config.frame_period_recognition == 0:
|
||||
batch_frame.append((frame, ms))
|
||||
|
||||
if len(batch_frame) == batch_size:
|
||||
@@ -41,10 +49,11 @@ cdef class Inference:
|
||||
results = self.model.track(frames, persist=True)
|
||||
|
||||
for frame, res in zip(batch_frame, results):
|
||||
annotation = self.process_detections(int(frame[1]), frame[0], res.boxes)
|
||||
if len(annotation.detections) > 0:
|
||||
annotations.append(annotation)
|
||||
self.on_annotations(cmd, annotations)
|
||||
annotation = self.frame_to_annotation(int(frame[1]), frame[0], res.boxes)
|
||||
|
||||
if self.is_valid_annotation(<Annotation>annotation):
|
||||
self._previous_annotation = annotation
|
||||
self.on_annotation(cmd, annotation)
|
||||
batch_frame.clear()
|
||||
|
||||
v_input.release()
|
||||
@@ -52,15 +61,65 @@ cdef class Inference:
|
||||
cdef _process_image(self, RemoteCommand cmd):
|
||||
frame = cv2.imread(<str>cmd.filename)
|
||||
res = self.model.track(frame)
|
||||
annotation = self.process_detections(0, frame, res[0].boxes)
|
||||
self.on_annotations(cmd, [annotation])
|
||||
annotation = self.frame_to_annotation(0, frame, res[0].boxes)
|
||||
self.on_annotation(cmd, annotation)
|
||||
|
||||
cdef process_detections(self, float time, frame, boxes: Boxes):
|
||||
cdef frame_to_annotation(self, long time, frame, boxes: Boxes):
|
||||
detections = []
|
||||
for box in boxes:
|
||||
b = box.xywhn[0].cpu().numpy()
|
||||
cls = int(box.cls[0].cpu().numpy().item())
|
||||
detections.append(Detection(<double>b[0], <double>b[1], <double>b[2], <double>b[3], cls))
|
||||
_, encoded_image = cv2.imencode('.jpg', frame[0])
|
||||
confidence = box.conf[0].cpu().numpy().item()
|
||||
det = Detection(<double> b[0], <double> b[1], <double> b[2], <double> b[3], cls, confidence)
|
||||
detections.append(det)
|
||||
_, encoded_image = cv2.imencode('.jpg', frame)
|
||||
image_bytes = encoded_image.tobytes()
|
||||
return Annotation(image_bytes, time, detections)
|
||||
|
||||
cdef bint is_valid_annotation(self, Annotation annotation):
|
||||
# No detections, invalid
|
||||
if not annotation.detections:
|
||||
return False
|
||||
|
||||
# First valid annotation, always accept
|
||||
if self._previous_annotation is None:
|
||||
return True
|
||||
|
||||
# Enough time has passed since last annotation
|
||||
if annotation.time >= self._previous_annotation.time + <long>(self.ai_config.frame_recognition_seconds * 1000):
|
||||
return True
|
||||
|
||||
# More objects detected than before
|
||||
if len(annotation.detections) > len(self._previous_annotation.detections):
|
||||
return True
|
||||
|
||||
cdef:
|
||||
Detection current_det, prev_det
|
||||
double dx, dy, distance_sq, min_distance_sq
|
||||
Detection closest_det
|
||||
|
||||
# Check each detection against previous frame
|
||||
for current_det in annotation.detections:
|
||||
min_distance_sq = 1e18 # Initialize with large value
|
||||
closest_det = None
|
||||
|
||||
# Find closest detection in previous frame
|
||||
for prev_det in self._previous_annotation.detections:
|
||||
dx = current_det.x - prev_det.x
|
||||
dy = current_det.y - prev_det.y
|
||||
distance_sq = dx * dx + dy * dy
|
||||
|
||||
if distance_sq < min_distance_sq:
|
||||
min_distance_sq = distance_sq
|
||||
closest_det = prev_det
|
||||
|
||||
# Check if beyond tracking distance
|
||||
if min_distance_sq > self.ai_config.tracking_distance_confidence:
|
||||
return True
|
||||
|
||||
# Check probability increase
|
||||
if current_det.confidence >= closest_det.confidence + self.ai_config.tracking_probability_increase:
|
||||
return True
|
||||
|
||||
# No validation criteria met
|
||||
return False
|
||||
|
||||
+16
-13
@@ -1,12 +1,13 @@
|
||||
import traceback
|
||||
from queue import Queue
|
||||
cimport constants
|
||||
import msgpack
|
||||
|
||||
from api_client cimport ApiClient
|
||||
from annotation cimport Annotation
|
||||
from inference import Inference
|
||||
from inference cimport Inference
|
||||
from remote_command cimport RemoteCommand, CommandType
|
||||
from remote_command_handler cimport RemoteCommandHandler
|
||||
from user cimport User
|
||||
import argparse
|
||||
|
||||
cdef class ParsedArguments:
|
||||
@@ -36,11 +37,10 @@ cdef class CommandProcessor:
|
||||
while self.running:
|
||||
try:
|
||||
command = self.command_queue.get()
|
||||
print(f'command is : {command}')
|
||||
model = self.api_client.load_ai_model()
|
||||
Inference(model, self.on_annotations).run_inference(command)
|
||||
Inference(model, self.on_annotation).run_inference(command)
|
||||
except Exception as e:
|
||||
print(f"Error processing queue: {e}")
|
||||
traceback.print_exc()
|
||||
|
||||
cdef on_command(self, RemoteCommand command):
|
||||
try:
|
||||
@@ -48,17 +48,20 @@ cdef class CommandProcessor:
|
||||
self.command_queue.put(command)
|
||||
elif command.command_type == CommandType.LOAD:
|
||||
response = self.api_client.load_bytes(command.filename)
|
||||
print(f'loaded file: {command.filename}, {len(response)} bytes')
|
||||
self.remote_handler.send(response)
|
||||
print(f'{len(response)} bytes was sent.')
|
||||
|
||||
self.remote_handler.send(command.client_id, response)
|
||||
elif command.command_type == CommandType.GET_USER:
|
||||
self.get_user(command, self.api_client.get_user())
|
||||
else:
|
||||
pass
|
||||
except Exception as e:
|
||||
print(f"Error handling client: {e}")
|
||||
|
||||
cdef on_annotations(self, RemoteCommand cmd, annotations: [Annotation]):
|
||||
data = msgpack.packb(annotations)
|
||||
self.remote_handler.send(data)
|
||||
print(f'{len(data)} bytes was sent.')
|
||||
cdef get_user(self, RemoteCommand command, User user):
|
||||
self.remote_handler.send(command.client_id, user.serialize())
|
||||
|
||||
cdef on_annotation(self, RemoteCommand cmd, Annotation annotation):
|
||||
data = annotation.serialize()
|
||||
self.remote_handler.send(cmd.client_id, data)
|
||||
|
||||
def stop(self):
|
||||
self.running = False
|
||||
|
||||
@@ -1,8 +1,10 @@
|
||||
cdef enum CommandType:
|
||||
INFERENCE = 1
|
||||
LOAD = 2
|
||||
GET_USER = 3
|
||||
|
||||
cdef class RemoteCommand:
|
||||
cdef public bytes client_id
|
||||
cdef CommandType command_type
|
||||
cdef str filename
|
||||
cdef bytes data
|
||||
|
||||
@@ -10,8 +10,10 @@ cdef class RemoteCommand:
|
||||
command_type_names = {
|
||||
1: "INFERENCE",
|
||||
2: "LOAD",
|
||||
3: "GET_USER"
|
||||
}
|
||||
return f'{command_type_names[self.command_type]}: {self.filename}'
|
||||
data_str = f'. Data: {len(self.data)} bytes' if self.data else ''
|
||||
return f'{command_type_names[self.command_type]}: {self.filename}{data_str}'
|
||||
|
||||
@staticmethod
|
||||
cdef from_msgpack(bytes data):
|
||||
|
||||
@@ -1,16 +1,15 @@
|
||||
cdef class RemoteCommandHandler:
|
||||
cdef object _on_command
|
||||
cdef object _context
|
||||
cdef object _socket
|
||||
cdef object _router
|
||||
cdef object _dealer
|
||||
cdef object _shutdown_event
|
||||
cdef object _pull_socket
|
||||
cdef object _pull_thread
|
||||
cdef object _push_socket
|
||||
cdef object _push_queue
|
||||
cdef object _push_thread
|
||||
cdef object _on_command
|
||||
|
||||
cdef object _proxy_thread
|
||||
cdef object _workers
|
||||
|
||||
cdef start(self)
|
||||
cdef _pull_loop(self)
|
||||
cdef _push_loop(self)
|
||||
cdef send(self, bytes message_bytes)
|
||||
cdef _proxy_loop(self)
|
||||
cdef _worker_loop(self)
|
||||
cdef send(self, bytes client_id, bytes data)
|
||||
cdef close(self)
|
||||
|
||||
@@ -1,9 +1,7 @@
|
||||
from queue import Queue
|
||||
|
||||
import zmq
|
||||
import json
|
||||
from threading import Thread, Event
|
||||
from remote_command cimport RemoteCommand
|
||||
cimport constants
|
||||
|
||||
cdef class RemoteCommandHandler:
|
||||
def __init__(self, object on_command):
|
||||
@@ -11,68 +9,58 @@ cdef class RemoteCommandHandler:
|
||||
self._context = zmq.Context.instance()
|
||||
self._shutdown_event = Event()
|
||||
|
||||
self._pull_socket = self._context.socket(zmq.PULL)
|
||||
self._pull_socket.setsockopt(zmq.LINGER, 0)
|
||||
self._pull_socket.bind("tcp://*:5127")
|
||||
self._pull_thread = Thread(target=self._pull_loop, daemon=True)
|
||||
self._router = self._context.socket(zmq.ROUTER)
|
||||
self._router.setsockopt(zmq.LINGER, 0)
|
||||
self._router.bind(f'tcp://*:{constants.ZMQ_PORT}')
|
||||
|
||||
self._push_queue = Queue()
|
||||
self._dealer = self._context.socket(zmq.DEALER)
|
||||
self._dealer.setsockopt(zmq.LINGER, 0)
|
||||
self._dealer.bind("inproc://backend")
|
||||
|
||||
self._push_socket = self._context.socket(zmq.PUSH)
|
||||
self._push_socket.setsockopt(zmq.LINGER, 0)
|
||||
self._push_socket.bind("tcp://*:5128")
|
||||
self._push_thread = Thread(target=self._push_loop, daemon=True)
|
||||
self._proxy_thread = Thread(target=self._proxy_loop, daemon=True)
|
||||
|
||||
self._workers = []
|
||||
for _ in range(4): # 4 worker threads
|
||||
worker = Thread(target=self._worker_loop, daemon=True)
|
||||
self._workers.append(worker)
|
||||
|
||||
cdef start(self):
|
||||
self._pull_thread.start()
|
||||
self._push_thread.start()
|
||||
self._proxy_thread.start()
|
||||
for worker in self._workers:
|
||||
worker.start()
|
||||
|
||||
cdef _pull_loop(self):
|
||||
while not self._shutdown_event.is_set():
|
||||
print('wait for the command...')
|
||||
message = self._pull_socket.recv()
|
||||
cmd = RemoteCommand.from_msgpack(<bytes>message)
|
||||
print(f'received: {cmd}')
|
||||
self._on_command(cmd)
|
||||
cdef _proxy_loop(self):
|
||||
zmq.proxy(self._router, self._dealer)
|
||||
|
||||
cdef _push_loop(self):
|
||||
cdef _worker_loop(self):
|
||||
worker_socket = self._context.socket(zmq.DEALER)
|
||||
worker_socket.setsockopt(zmq.LINGER, 0)
|
||||
worker_socket.connect("inproc://backend")
|
||||
poller = zmq.Poller()
|
||||
poller.register(worker_socket, zmq.POLLIN)
|
||||
print('started receiver loop...')
|
||||
while not self._shutdown_event.is_set():
|
||||
try:
|
||||
response = self._push_queue.get(timeout=1) # Timeout to check shutdown flag
|
||||
self._push_socket.send(response)
|
||||
except:
|
||||
continue
|
||||
socks = dict(poller.poll(500))
|
||||
if worker_socket in socks:
|
||||
client_id, message = worker_socket.recv_multipart()
|
||||
cmd = RemoteCommand.from_msgpack(<bytes> message)
|
||||
cmd.client_id = client_id
|
||||
print(f'Received [{cmd}] from the client {client_id}')
|
||||
self._on_command(cmd)
|
||||
except Exception as e:
|
||||
print(f"Worker error: {e}")
|
||||
import traceback
|
||||
traceback.print_exc()
|
||||
|
||||
cdef send(self, bytes message_bytes):
|
||||
print(f'about to send {len(message_bytes)}')
|
||||
try:
|
||||
self._push_queue.put(message_bytes)
|
||||
except Exception as e:
|
||||
print(e)
|
||||
cdef send(self, bytes client_id, bytes data):
|
||||
with self._context.socket(zmq.DEALER) as socket:
|
||||
socket.connect("inproc://backend")
|
||||
socket.send_multipart([client_id, data])
|
||||
print(f'{len(data)} bytes was sent to client {client_id}')
|
||||
|
||||
cdef close(self):
|
||||
self._shutdown_event.set()
|
||||
self._pull_socket.close()
|
||||
self._push_socket.close()
|
||||
self._context.term()
|
||||
|
||||
|
||||
cdef class QueueConfig:
|
||||
cdef str host,
|
||||
cdef int port, command_port
|
||||
cdef str producer_user, producer_pw, consumer_user, consumer_pw
|
||||
|
||||
@staticmethod
|
||||
cdef QueueConfig from_json(str json_string):
|
||||
s = str(json_string).strip()
|
||||
cdef dict config_dict = json.loads(s)["QueueConfig"]
|
||||
cdef QueueConfig config = QueueConfig()
|
||||
|
||||
config.host = config_dict["Host"]
|
||||
config.port = config_dict["Port"]
|
||||
config.command_port = config_dict["CommandsPort"]
|
||||
config.producer_user = config_dict["ProducerUsername"]
|
||||
config.producer_pw = config_dict["ProducerPassword"]
|
||||
config.consumer_user = config_dict["ConsumerUsername"]
|
||||
config.consumer_pw = config_dict["ConsumerPassword"]
|
||||
return config
|
||||
self._router.close()
|
||||
self._dealer.close()
|
||||
self._context.term()
|
||||
@@ -0,0 +1,12 @@
|
||||
cdef class SecureModelLoader:
|
||||
cdef:
|
||||
bytes _model_bytes
|
||||
str _ramdisk_path
|
||||
str _temp_file_path
|
||||
int _disk_size_mb
|
||||
|
||||
cpdef str load_model(self, bytes model_bytes)
|
||||
cdef str _get_ramdisk_path(self)
|
||||
cdef void _create_ramdisk(self)
|
||||
cdef void _store_model(self)
|
||||
cdef void _cleanup(self)
|
||||
@@ -0,0 +1,104 @@
|
||||
import os
|
||||
import platform
|
||||
import tempfile
|
||||
from pathlib import Path
|
||||
from libc.stdio cimport FILE, fopen, fclose, remove
|
||||
from libc.stdlib cimport free
|
||||
from libc.string cimport strdup
|
||||
|
||||
cdef class SecureModelLoader:
|
||||
def __cinit__(self, int disk_size_mb=512):
|
||||
self._disk_size_mb = disk_size_mb
|
||||
self._ramdisk_path = None
|
||||
self._temp_file_path = None
|
||||
|
||||
cpdef str load_model(self, bytes model_bytes):
|
||||
"""Public method to load YOLO model securely."""
|
||||
self._model_bytes = model_bytes
|
||||
self._create_ramdisk()
|
||||
self._store_model()
|
||||
return self._temp_file_path
|
||||
|
||||
cdef str _get_ramdisk_path(self):
|
||||
"""Determine the RAM disk path based on the OS."""
|
||||
if platform.system() == "Windows":
|
||||
return "R:\\"
|
||||
elif platform.system() == "Linux":
|
||||
return "/mnt/ramdisk"
|
||||
elif platform.system() == "Darwin":
|
||||
return "/Volumes/RAMDisk"
|
||||
else:
|
||||
raise RuntimeError("Unsupported OS for RAM disk")
|
||||
|
||||
cdef void _create_ramdisk(self):
|
||||
"""Create a RAM disk securely based on the OS."""
|
||||
system = platform.system()
|
||||
|
||||
if system == "Windows":
|
||||
# Create RAM disk via PowerShell
|
||||
command = f'powershell -Command "subst R: {tempfile.gettempdir()}"'
|
||||
if os.system(command) != 0:
|
||||
raise RuntimeError("Failed to create RAM disk on Windows")
|
||||
self._ramdisk_path = "R:\\"
|
||||
|
||||
elif system == "Linux":
|
||||
# Use tmpfs for RAM disk
|
||||
self._ramdisk_path = "/mnt/ramdisk"
|
||||
if not Path(self._ramdisk_path).exists():
|
||||
os.mkdir(self._ramdisk_path)
|
||||
if os.system(f"mount -t tmpfs -o size={self._disk_size_mb}M tmpfs {self._ramdisk_path}") != 0:
|
||||
raise RuntimeError("Failed to create RAM disk on Linux")
|
||||
|
||||
elif system == "Darwin":
|
||||
# Use hdiutil for macOS RAM disk
|
||||
block_size = 2048 # 512-byte blocks * 2048 = 1MB
|
||||
num_blocks = self._disk_size_mb * block_size
|
||||
result = os.popen(f"hdiutil attach -nomount ram://{num_blocks}").read().strip()
|
||||
if result:
|
||||
self._ramdisk_path = "/Volumes/RAMDisk"
|
||||
os.system(f"diskutil eraseVolume HFS+ RAMDisk {result}")
|
||||
else:
|
||||
raise RuntimeError("Failed to create RAM disk on macOS")
|
||||
|
||||
cdef void _store_model(self):
|
||||
"""Write model securely to the RAM disk."""
|
||||
cdef char* temp_path
|
||||
cdef FILE* cfile
|
||||
|
||||
with tempfile.NamedTemporaryFile(
|
||||
dir=self._ramdisk_path, suffix='.pt', delete=False
|
||||
) as tmp_file:
|
||||
tmp_file.write(self._model_bytes)
|
||||
self._temp_file_path = tmp_file.name
|
||||
|
||||
encoded_path = self._temp_file_path.encode('utf-8')
|
||||
temp_path = strdup(encoded_path)
|
||||
with nogil:
|
||||
cfile = fopen(temp_path, "rb")
|
||||
if cfile == NULL:
|
||||
raise IOError(f"Could not open {self._temp_file_path}")
|
||||
fclose(cfile)
|
||||
|
||||
cdef void _cleanup(self):
|
||||
"""Remove the model file and unmount RAM disk securely."""
|
||||
cdef char* c_path
|
||||
if self._temp_file_path:
|
||||
c_path = strdup(os.fsencode(self._temp_file_path))
|
||||
with nogil:
|
||||
remove(c_path)
|
||||
free(c_path)
|
||||
self._temp_file_path = None
|
||||
|
||||
# Unmount RAM disk based on OS
|
||||
if self._ramdisk_path:
|
||||
if platform.system() == "Windows":
|
||||
os.system("subst R: /D")
|
||||
elif platform.system() == "Linux":
|
||||
os.system(f"umount {self._ramdisk_path}")
|
||||
elif platform.system() == "Darwin":
|
||||
os.system("hdiutil detach /Volumes/RAMDisk")
|
||||
self._ramdisk_path = None
|
||||
|
||||
def __dealloc__(self):
|
||||
"""Ensure cleanup when the object is deleted."""
|
||||
self._cleanup()
|
||||
+10
-1
@@ -8,7 +8,10 @@ extensions = [
|
||||
Extension('hardware_service', ['hardware_service.pyx'], extra_compile_args=["-g"], extra_link_args=["-g"]),
|
||||
Extension('remote_command', ['remote_command.pyx']),
|
||||
Extension('remote_command_handler', ['remote_command_handler.pyx']),
|
||||
Extension('user', ['user.pyx']),
|
||||
Extension('api_client', ['api_client.pyx']),
|
||||
Extension('secure_model', ['secure_model.pyx']),
|
||||
Extension('ai_config', ['ai_config.pyx']),
|
||||
Extension('inference', ['inference.pyx']),
|
||||
|
||||
Extension('main', ['main.pyx']),
|
||||
@@ -21,8 +24,14 @@ setup(
|
||||
compiler_directives={
|
||||
"language_level": 3,
|
||||
"emit_code_comments" : False,
|
||||
"binding": True
|
||||
"binding": True,
|
||||
'boundscheck': False,
|
||||
'wraparound': False
|
||||
}
|
||||
),
|
||||
install_requires=[
|
||||
'ultralytics>=8.0.0',
|
||||
'pywin32; platform_system=="Windows"'
|
||||
],
|
||||
zip_safe=False
|
||||
)
|
||||
@@ -0,0 +1 @@
|
||||
eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJuYW1laWQiOiJkOTBhMzZjYS1lMjM3LTRmYmQtOWM3Yy0xMjcwNDBhYzg1NTYiLCJ1bmlxdWVfbmFtZSI6ImFkbWluQGF6YWlvbi5jb20iLCJyb2xlIjoiQXBpQWRtaW4iLCJuYmYiOjE3MzgxNjM2MzYsImV4cCI6MTczODE3ODAzNiwiaWF0IjoxNzM4MTYzNjM2LCJpc3MiOiJBemFpb25BcGkiLCJhdWQiOiJBbm5vdGF0b3JzL09yYW5nZVBpL0FkbWlucyJ9.7VVws5mwGqx--sGopOuZE9iu3dzt1UdVPXeje2KZTYk
|
||||
@@ -0,0 +1,15 @@
|
||||
cdef enum RoleEnum:
|
||||
NONE = 0
|
||||
Operator = 10
|
||||
Validator = 20
|
||||
CompanionPC = 30
|
||||
Admin = 40
|
||||
ResourceUploader = 50
|
||||
ApiAdmin = 1000
|
||||
|
||||
cdef class User:
|
||||
cdef public str id
|
||||
cdef public str email
|
||||
cdef public RoleEnum role
|
||||
|
||||
cdef bytes serialize(self)
|
||||
@@ -0,0 +1,15 @@
|
||||
import msgpack
|
||||
|
||||
cdef class User:
|
||||
|
||||
def __init__(self, str id, str email, RoleEnum role):
|
||||
self.id = id
|
||||
self.email = email
|
||||
self.role = role
|
||||
|
||||
cdef bytes serialize(self):
|
||||
return msgpack.packb({
|
||||
"i": self.id,
|
||||
"e": self.email,
|
||||
"r": self.role
|
||||
})
|
||||
Reference in New Issue
Block a user