rewrite to zmq push and pull patterns.

file load works, suite can start up
This commit is contained in:
Alex Bezdieniezhnykh
2025-01-23 14:37:13 +02:00
parent ce25ef38b0
commit 82b3b526a7
20 changed files with 243 additions and 208 deletions
+3 -1
View File
@@ -35,6 +35,8 @@ Linux
<h3>Install dependencies</h3>
1. Install python with max version 3.11. Pytorch for now supports 3.11 max
Make sure that your virtual env is installed with links to the global python packages and headers, like this:
```
python -m venv --system-site-packages venv
@@ -48,7 +50,7 @@ This is crucial for the build because build needs Python.h header and other file
pip install ultralytics
pip uninstall -y opencv-python
pip install opencv-python cython msgpack cryptography rstream
pip install opencv-python cython msgpack cryptography rstream pika zmq
```
In case of fbgemm.dll error (Windows specific):
-2
View File
@@ -1,5 +1,3 @@
from io import BytesIO
import main
from main import ParsedArguments
+4 -6
View File
@@ -2,12 +2,10 @@ from ultralytics import YOLO
import mimetypes
import cv2
from ultralytics.engine.results import Boxes
from processor_command import FileCommand
from remote_command cimport RemoteCommand
from annotation cimport Detection, Annotation
cdef class Inference:
"""Handles YOLO inference using the AI model."""
def __init__(self, model_bytes, on_annotations):
self.model = YOLO(model_bytes)
self.on_annotations = on_annotations
@@ -16,13 +14,13 @@ cdef class Inference:
mime_type, _ = mimetypes.guess_type(<str>filepath)
return mime_type and mime_type.startswith("video")
cdef run_inference(self, cmd: FileCommand, int batch_size=8, int frame_skip=4):
cdef run_inference(self, RemoteCommand cmd, int batch_size=8, int frame_skip=4):
if self.is_video(cmd.filename):
return self._process_video(cmd, batch_size, frame_skip)
else:
return self._process_image(cmd)
cdef _process_video(self, cmd: FileCommand, int batch_size, int frame_skip):
cdef _process_video(self, RemoteCommand cmd, int batch_size, int frame_skip):
frame_count = 0
batch_frame = []
annotations = []
@@ -51,7 +49,7 @@ cdef class Inference:
v_input.release()
cdef _process_image(self, cmd: FileCommand):
cdef _process_image(self, RemoteCommand cmd):
frame = cv2.imread(<str>cmd.filename)
res = self.model.track(frame)
annotation = self.process_detections(0, frame, res[0].boxes)
+22 -26
View File
@@ -1,11 +1,12 @@
import queue
import threading
from queue import Queue
cimport constants
import msgpack
from api_client cimport ApiClient
from annotation cimport Annotation
from inference import Inference
from processor_command cimport FileCommand, CommandType, ProcessorType
from remote_handlers cimport SocketHandler, RabbitHandler
from remote_command cimport RemoteCommand, CommandType
from remote_command_handler cimport RemoteCommandHandler
import argparse
cdef class ParsedArguments:
@@ -20,18 +21,15 @@ cdef class ParsedArguments:
cdef class CommandProcessor:
cdef ApiClient api_client
cdef SocketHandler socket_handler
cdef RabbitHandler rabbit_handler
cdef RemoteCommandHandler remote_handler
cdef object command_queue
cdef bint running
def __init__(self, args: ParsedArguments):
self.api_client = ApiClient(args.email, args.password, args.folder)
self.socket_handler = SocketHandler(self.on_message)
self.socket_handler.start()
self.rabbit_handler = RabbitHandler(self.api_client, self.on_message)
self.rabbit_handler.start()
self.command_queue = queue.Queue(maxsize=constants.QUEUE_MAXSIZE)
self.remote_handler = RemoteCommandHandler(self.on_command)
self.command_queue = Queue(maxsize=constants.QUEUE_MAXSIZE)
self.remote_handler.start()
self.running = True
def start(self):
@@ -44,25 +42,23 @@ cdef class CommandProcessor:
except Exception as e:
print(f"Error processing queue: {e}")
cdef on_message(self, FileCommand cmd):
cdef on_command(self, RemoteCommand command):
try:
if cmd.command_type == CommandType.INFERENCE:
self.command_queue.put(cmd)
elif cmd.command_type == CommandType.LOAD:
threading.Thread(target=self.process_load, args=[cmd], daemon=True).start()
if command.command_type == CommandType.INFERENCE:
self.command_queue.put(command)
elif command.command_type == CommandType.LOAD:
response = self.api_client.load_bytes(command.filename)
print(f'loaded file: {command.filename}, {len(response)} bytes')
self.remote_handler.send(response)
print(f'{len(response)} bytes was sent.')
except Exception as e:
print(f"Error handling client: {e}")
cdef on_annotations(self, cmd: FileCommand, annotations: [Annotation]):
handler = self.socket_handler if cmd.processor_type == ProcessorType.SOCKET else self.rabbit_handler
handler.send(annotations)
cdef process_load(self, FileCommand command):
response = self.api_client.load_bytes(command.filename)
handler = self.socket_handler if command.processor_type == ProcessorType.SOCKET else self.rabbit_handler
handler.send(response)
cdef on_annotations(self, RemoteCommand cmd, annotations: [Annotation]):
data = msgpack.packb(annotations)
self.remote_handler.send(data)
print(f'{len(data)} bytes was sent.')
def stop(self):
self.running = False
-15
View File
@@ -1,15 +0,0 @@
cdef enum ProcessorType:
SOCKET = 1,
RABBIT = 2
cdef enum CommandType:
INFERENCE = 1
LOAD = 2
cdef class FileCommand:
cdef CommandType command_type
cdef ProcessorType processor_type
cdef str filename
@staticmethod
cdef from_msgpack(bytes data, ProcessorType processor_type)
-13
View File
@@ -1,13 +0,0 @@
import msgpack
cdef class FileCommand:
def __init__(self, command_type: CommandType, ProcessorType processor_type, str filename):
self.command_type = command_type
self.processor_type = processor_type
self.filename = filename
@staticmethod
cdef from_msgpack(bytes data, ProcessorType processor_type):
unpacked = msgpack.unpackb(data, strict_map_key=False)
return FileCommand(unpacked.get("CommandType"), processor_type, unpacked.get("Filename")
)
+11
View File
@@ -0,0 +1,11 @@
cdef enum CommandType:
INFERENCE = 1
LOAD = 2
cdef class RemoteCommand:
cdef CommandType command_type
cdef str filename
cdef bytes data
@staticmethod
cdef from_msgpack(bytes data)
+19
View File
@@ -0,0 +1,19 @@
import msgpack
cdef class RemoteCommand:
def __init__(self, CommandType command_type, str filename, bytes data):
self.command_type = command_type
self.filename = filename
self.data = data
def __str__(self):
command_type_names = {
1: "INFERENCE",
2: "LOAD",
}
return f'{command_type_names[self.command_type]}: {self.filename}'
@staticmethod
cdef from_msgpack(bytes data):
unpacked = msgpack.unpackb(data, strict_map_key=False)
return RemoteCommand(unpacked.get("CommandType"), unpacked.get("Filename"), unpacked.get("Data"))
+16
View File
@@ -0,0 +1,16 @@
cdef class RemoteCommandHandler:
cdef object _on_command
cdef object _context
cdef object _socket
cdef object _shutdown_event
cdef object _pull_socket
cdef object _pull_thread
cdef object _push_socket
cdef object _push_queue
cdef object _push_thread
cdef start(self)
cdef _pull_loop(self)
cdef _push_loop(self)
cdef send(self, bytes message_bytes)
cdef close(self)
+78
View File
@@ -0,0 +1,78 @@
from queue import Queue
import zmq
import json
from threading import Thread, Event
from remote_command cimport RemoteCommand
cdef class RemoteCommandHandler:
def __init__(self, object on_command):
self._on_command = on_command
self._context = zmq.Context.instance()
self._shutdown_event = Event()
self._pull_socket = self._context.socket(zmq.PULL)
self._pull_socket.setsockopt(zmq.LINGER, 0)
self._pull_socket.bind("tcp://*:5127")
self._pull_thread = Thread(target=self._pull_loop, daemon=True)
self._push_queue = Queue()
self._push_socket = self._context.socket(zmq.PUSH)
self._push_socket.setsockopt(zmq.LINGER, 0)
self._push_socket.bind("tcp://*:5128")
self._push_thread = Thread(target=self._push_loop, daemon=True)
cdef start(self):
self._pull_thread.start()
self._push_thread.start()
cdef _pull_loop(self):
while not self._shutdown_event.is_set():
print('wait for the command...')
message = self._pull_socket.recv()
cmd = RemoteCommand.from_msgpack(<bytes>message)
print(f'received: {cmd}')
self._on_command(cmd)
cdef _push_loop(self):
while not self._shutdown_event.is_set():
try:
response = self._push_queue.get(timeout=1) # Timeout to check shutdown flag
self._push_socket.send(response)
except:
continue
cdef send(self, bytes message_bytes):
print(f'about to send {len(message_bytes)}')
try:
self._push_queue.put(message_bytes)
except Exception as e:
print(e)
cdef close(self):
self._shutdown_event.set()
self._pull_socket.close()
self._push_socket.close()
self._context.term()
cdef class QueueConfig:
cdef str host,
cdef int port, command_port
cdef str producer_user, producer_pw, consumer_user, consumer_pw
@staticmethod
cdef QueueConfig from_json(str json_string):
s = str(json_string).strip()
cdef dict config_dict = json.loads(s)["QueueConfig"]
cdef QueueConfig config = QueueConfig()
config.host = config_dict["Host"]
config.port = config_dict["Port"]
config.command_port = config_dict["CommandsPort"]
config.producer_user = config_dict["ProducerUsername"]
config.producer_pw = config_dict["ProducerPassword"]
config.consumer_user = config_dict["ConsumerUsername"]
config.consumer_pw = config_dict["ConsumerPassword"]
return config
-20
View File
@@ -1,20 +0,0 @@
from annotation cimport Annotation
cdef class SocketHandler:
cdef object on_message
cdef object _socket
cdef object _connection
cdef start(self)
cdef start_inner(self)
cdef send(self, list[Annotation] message)
cdef close(self)
cdef class RabbitHandler:
cdef object on_message
cdef object annotation_producer
cdef object command_consumer
cdef send(self, object message)
cdef start(self)
cdef close(self)
-116
View File
@@ -1,116 +0,0 @@
import asyncio
import json
import socket
import struct
import threading
import msgpack
from msgpack import packb
from rstream import Producer, Consumer, AMQPMessage, ConsumerOffsetSpecification, OffsetType, MessageContext
cimport constants
from api_client cimport ApiClient
from processor_command cimport FileCommand, ProcessorType
from annotation cimport Annotation
cdef class QueueConfig:
cdef str host,
cdef int port
cdef str producer_user, producer_pw, consumer_user, consumer_pw
@staticmethod
cdef QueueConfig from_json(str json_string):
cdef dict config_dict = json.loads(<str>json_string)["QueueConfig"]
cdef QueueConfig config = QueueConfig()
config.host = config_dict["Host"]
config.port = config_dict["Port"]
config.producer_user = config_dict["ProducerUsername"]
config.producer_pw = config_dict["ProducerPassword"]
config.consumer_user = config_dict["ConsumerUsername"]
config.consumer_pw = config_dict["ConsumerPassword"]
return config
cdef class SocketHandler:
"""Handles socket communication with size-prefixed messages."""
def __init__(self, object on_message):
self.on_message = on_message
self._socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
self._connection = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
self._socket.bind((constants.SOCKET_HOST, constants.SOCKET_PORT))
self._socket.listen(1)
cdef start(self):
threading.Thread(target=self.start_inner, daemon=True).start()
cdef start_inner(self):
while True:
self._connection, client_address = self._socket.accept()
size_data = self._connection.recv(4)
if not size_data:
raise ConnectionError("Connection closed while reading size prefix.")
data_size = struct.unpack('>I', size_data)[0]
# Read the full message
data = b""
while len(data) < data_size:
packet = self._socket.recv(data_size - len(data))
if not packet:
raise ConnectionError("Connection closed while reading data.")
data += packet
cmd = FileCommand.from_msgpack(data, ProcessorType.SOCKET)
self.on_message(cmd)
cdef send(self, list[Annotation] message):
data = msgpack.packb(message)
size_prefix = len(data).to_bytes(4, 'big')
self._connection.sendall(size_prefix + data)
cdef close(self):
if self._socket:
self._socket.close()
self._socket = None
cdef class RabbitHandler:
def __init__(self, ApiClient api_client, object on_message):
self.on_message = on_message
cdef str config_str = api_client.load_queue_config()
queue_config = QueueConfig.from_json(config_str)
self.annotation_producer = Producer(
host=<str>queue_config.host,
port=queue_config.port,
username=<str>queue_config.producer_user,
password=<str>queue_config.producer_pw
)
self.command_consumer = Consumer(
host=<str>queue_config.host,
port=queue_config.port,
username=<str>queue_config.consumer_user,
password=<str>queue_config.consumer_pw
)
cdef start(self):
threading.Thread(target=self._run_async, daemon=True).start()
def _run_async(self):
asyncio.run(self.start_inner())
async def start_inner(self):
await self.command_consumer.start()
await self.command_consumer.subscribe(stream=<str>constants.COMMANDS_QUEUE, callback=self.on_message_inner,
offset_specification=ConsumerOffsetSpecification(OffsetType.FIRST, None)) # put real offset
def on_message_inner(self, message: AMQPMessage, message_context: MessageContext):
cdef bytes body = message.body
cmd = FileCommand.from_msgpack(body, ProcessorType.RABBIT)
self.on_message(cmd)
cdef send(self, object message):
packed_message = AMQPMessage(body=packb(message))
self.annotation_producer.send(<str>constants.ANNOTATIONS_QUEUE, packed_message)
cdef close(self):
if self.annotation_producer:
self.annotation_producer.close()
if self.command_consumer:
self.command_consumer.close()
+8 -5
View File
@@ -1,10 +1,11 @@
import base64
import hashlib
import os
from cryptography.hazmat.primitives.ciphers import Cipher, algorithms, modes
from cryptography.hazmat.backends import default_backend
from cryptography.hazmat.primitives.hashes import Hash, SHA256
from hashlib import sha384
import base64
from cryptography.hazmat.backends import default_backend
from cryptography.hazmat.primitives import padding
from cryptography.hazmat.primitives.ciphers import Cipher, algorithms, modes
BUFFER_SIZE = 64 * 1024 # 64 KB
@@ -38,7 +39,9 @@ cdef class Security:
decrypted_chunk = decryptor.update(chunk)
res.extend(decrypted_chunk)
res.extend(decryptor.finalize())
return res
unpadder = padding.PKCS7(128).unpadder() # AES block size is 128 bits (16 bytes)
return unpadder.update(res) + unpadder.finalize()
@staticmethod
cdef calc_hash(str key):
+3 -2
View File
@@ -6,10 +6,11 @@ extensions = [
Extension('annotation', ['annotation.pyx']),
Extension('security', ['security.pyx']),
Extension('hardware_service', ['hardware_service.pyx'], extra_compile_args=["-g"], extra_link_args=["-g"]),
Extension('processor_command', ['processor_command.pyx']),
Extension('remote_command', ['remote_command.pyx']),
Extension('remote_command_handler', ['remote_command_handler.pyx']),
Extension('api_client', ['api_client.pyx']),
Extension('inference', ['inference.pyx']),
Extension('remote_handlers', ['remote_handlers.pyx']),
Extension('main', ['main.pyx']),
]
+1 -1
View File
@@ -15,7 +15,7 @@
<PackageReference Include="Microsoft.Extensions.Options.ConfigurationExtensions" Version="9.0.0" />
<PackageReference Include="Newtonsoft.Json" Version="13.0.3" />
<PackageReference Include="RabbitMQ.Stream.Client" Version="1.8.9" />
<PackageReference Include="System.Drawing.Common" Version="4.7.3" />
<PackageReference Include="System.Drawing.Common" Version="5.0.3" />
</ItemGroup>
<ItemGroup>
@@ -7,6 +7,9 @@
</PropertyGroup>
<ItemGroup>
<PackageReference Include="MessagePack" Version="3.1.0" />
<PackageReference Include="MessagePack.Annotations" Version="3.1.0" />
<PackageReference Include="NetMQ" Version="4.0.1.13" />
<PackageReference Include="Newtonsoft.Json" Version="13.0.3" />
<PackageReference Include="System.IdentityModel.Tokens.Jwt" Version="8.3.0" />
</ItemGroup>
@@ -0,0 +1,24 @@
using MessagePack;
namespace Azaion.CommonSecurity.DTO.Commands;
[MessagePackObject]
public class FileCommand
{
[Key("CommandType")]
public CommandType CommandType { get; set; }
[Key("Filename")]
public string Filename { get; set; }
[Key("Data")]
public byte[] Data { get; set; }
}
public enum CommandType
{
None = 0,
Inference = 1,
Load = 2
}
@@ -5,6 +5,7 @@ public class SecurityConstants
public const string CONFIG_PATH = "config.json";
public const string DUMMY_DIR = "dummy";
#region ApiConfig
public const string DEFAULT_API_URL = "https://api.azaion.com/";
@@ -16,4 +17,11 @@ public class SecurityConstants
public const string CLAIM_ROLE = "role";
#endregion ApiConfig
#region SocketClient
public const string SOCKET_HOST = "127.0.0.1";
public const int SOCKET_SEND_PORT = 5127;
public const int SOCKET_RECEIVE_PORT = 5128;
#endregion SocketClient
}
@@ -1,4 +1,8 @@
using Azaion.CommonSecurity.DTO;
using Azaion.CommonSecurity.DTO.Commands;
using MessagePack;
using NetMQ;
using NetMQ.Sockets;
namespace Azaion.CommonSecurity.Services;
@@ -7,6 +11,40 @@ public interface IResourceLoader
Task<MemoryStream> Load(string fileName, CancellationToken cancellationToken = default);
}
public class PythonResourceLoader : IResourceLoader
{
private readonly PushSocket _pushSocket = new();
private readonly PullSocket _pullSocket = new();
public PythonResourceLoader(ApiCredentials credentials)
{
//Run python by credentials
_pushSocket.Connect($"tcp://{SecurityConstants.SOCKET_HOST}:{SecurityConstants.SOCKET_SEND_PORT}");
_pullSocket.Connect($"tcp://{SecurityConstants.SOCKET_HOST}:{SecurityConstants.SOCKET_RECEIVE_PORT}");
}
public async Task<MemoryStream> Load(string fileName, CancellationToken cancellationToken = default)
{
try
{
var b = MessagePackSerializer.Serialize(new FileCommand
{
CommandType = CommandType.Load,
Filename = fileName
});
_pushSocket.SendFrame(b);
var bytes = _pullSocket.ReceiveFrameBytes(out bool more);
return new MemoryStream(bytes);
}
catch (Exception ex)
{
throw new Exception($"Failed to load fil0e '{fileName}': {ex.Message}", ex);
}
}
}
public class ResourceLoader(AzaionApiClient api, ApiCredentials credentials) : IResourceLoader
{
public async Task<MemoryStream> Load(string fileName, CancellationToken cancellationToken = default)
+5 -1
View File
@@ -1,5 +1,7 @@
using System.IO;
using System.Reflection;
using System.Text;
using System.Text.Json;
using System.Windows;
using System.Windows.Threading;
using Azaion.Annotator;
@@ -21,6 +23,7 @@ using Microsoft.Extensions.DependencyInjection;
using Microsoft.Extensions.Hosting;
using Microsoft.Extensions.Logging;
using Microsoft.Extensions.Options;
using Newtonsoft.Json;
using Serilog;
using KeyEventArgs = System.Windows.Input.KeyEventArgs;
@@ -62,8 +65,9 @@ public partial class App
login.CredentialsEntered += async (s, args) =>
{
_apiClient = AzaionApiClient.Create(args);
_resourceLoader = new ResourceLoader(_apiClient, args);
_resourceLoader = new PythonResourceLoader(args);
_securedConfig = await _resourceLoader.Load("secured-config.json");
AppDomain.CurrentDomain.AssemblyResolve += (_, a) =>
{
var assemblyName = a.Name.Split(',').First();