Files
ai-training/train.py
T
Alex Bezdieniezhnykh 6e5153ffb7 prepare train.py to automated training
set cryptography lib to the certain version
add manual_run for manual operations, right now it is onnx conversion and upload
2025-05-31 14:50:58 +03:00

183 lines
5.7 KiB
Python

import concurrent.futures
import os
import random
import shutil
import subprocess
from datetime import datetime
from os import path, replace, listdir, makedirs, scandir
from os.path import abspath
from pathlib import Path
from time import sleep
import yaml
from ultralytics import YOLO
import constants
from api_client import ApiCredentials, ApiClient
from cdn_manager import CDNCredentials, CDNManager
from constants import (processed_images_dir,
processed_labels_dir,
prefix, date_format,
datasets_dir, models_dir,
corrupted_images_dir, corrupted_labels_dir, sample_dir)
from dto.annotationClass import AnnotationClass
from inference.onnx_engine import OnnxEngine
from security import Security
from utils import Dotdict
from exports import export_tensorrt, upload_model, export_onnx
# Folder name for today's run, built from the shared prefix and today's date
# (exact layout depends on `date_format` from constants).
today_folder = f'{prefix}{datetime.now():{date_format}}'
# Root directory of today's dataset; train/valid/test subfolders live under it.
today_dataset = path.join(datasets_dir, today_folder)
# Dataset split percentages; test receives whatever remains after train/valid.
train_set = 70
valid_set = 20
test_set = 10
# NOTE(review): not referenced anywhere in this file — confirm it is used elsewhere
# before removing.
old_images_percentage = 75
# Number of class slots written to data.yaml; ids without a known name get a
# numbered placeholder in create_yaml().
DEFAULT_CLASS_NUM = 80
# Shared progress counter mutated by copy_annotations() worker threads.
total_files_copied = 0
def form_dataset():
    """Rebuild today's dataset folder and split images into train/valid/test.

    Wipes any existing dataset for today, shuffles all processed images, and
    copies them (with labels) into the three splits according to the
    train_set/valid_set percentages; the remainder becomes the test split.
    """
    # Always start from a clean directory for today's run.
    shutil.rmtree(today_dataset, ignore_errors=True)
    makedirs(today_dataset)
    with scandir(processed_images_dir) as entries:
        images = [entry for entry in entries if entry.is_file()]
    print(f'Got {len(images)} images. Start shuffling...')
    random.shuffle(images)
    n_train = int(len(images) * train_set / 100.0)
    n_valid = int(len(images) * valid_set / 100.0)
    print(f'Start copying...')
    copy_annotations(images[:n_train], 'train')
    copy_annotations(images[n_train:n_train + n_valid], 'valid')
    copy_annotations(images[n_train + n_valid:], 'test')
def copy_annotations(images, folder):
    """Copy images and their YOLO label files into <today_dataset>/<folder>.

    Pairs with a valid label go to the split's images/ and labels/ folders;
    images whose label is missing or corrupted go to the corrupted dirs.
    Updates the module-level `total_files_copied` counter as it goes.

    Args:
        images: iterable of os.DirEntry objects for the image files.
        folder: split name ('train', 'valid' or 'test').
    """
    from threading import Lock

    global total_files_copied
    total_files_copied = 0
    # '+=' on the shared counter is not atomic; guard it since copy_image
    # runs concurrently on ThreadPoolExecutor workers.
    counter_lock = Lock()

    def copy_image(image):
        global total_files_copied
        label_name = f'{Path(image.path).stem}.txt'
        label_path = path.join(processed_labels_dir, label_name)
        if check_label(label_path):
            shutil.copy(image.path, path.join(destination_images, image.name))
            shutil.copy(label_path, path.join(destination_labels, label_name))
        else:
            shutil.copy(image.path, path.join(corrupted_images_dir, image.name))
            # check_label also fails when the label file is absent — only
            # copy it if it actually exists, otherwise shutil.copy would raise.
            if path.exists(label_path):
                shutil.copy(label_path, path.join(corrupted_labels_dir, label_name))
            print(f'Label {label_path} is corrupted! Copy with its image to the corrupted directory ({corrupted_labels_dir})')
        with counter_lock:
            total_files_copied += 1
            copied_so_far = total_files_copied
        if copied_so_far % 1000 == 0:
            print(f'{copied_so_far} copied...')

    destination_images = path.join(today_dataset, folder, 'images')
    makedirs(destination_images, exist_ok=True)
    destination_labels = path.join(today_dataset, folder, 'labels')
    makedirs(destination_labels, exist_ok=True)
    makedirs(corrupted_images_dir, exist_ok=True)
    makedirs(corrupted_labels_dir, exist_ok=True)
    print(f'Copying annotations to {destination_images} and {destination_labels} folders:')
    with concurrent.futures.ThreadPoolExecutor() as executor:
        # Drain the result iterator so exceptions raised inside workers
        # propagate here instead of being silently discarded.
        for _ in executor.map(copy_image, images):
            pass
    # Bug fix: the original printed a local `copied` that was never
    # incremented and therefore always reported 0.
    print(f'Copied all {total_files_copied} annotations to {destination_images} and {destination_labels} folders')
def check_label(label_path):
    """Return True if the YOLO label file exists and all coordinates are valid.

    Each label line is '<class_id> <x> <y> <w> <h>' with coordinates normalized
    to [0, 1]. A missing file, a non-numeric token, or any coordinate outside
    [0, 1] marks the label as corrupted.

    Args:
        label_path: filesystem path to the .txt label file.

    Returns:
        bool: True when the label is usable, False when it is corrupted/missing.
    """
    if not path.exists(label_path):
        return False
    with open(label_path, 'r') as f:
        for line in f:
            # split() without an argument tolerates tabs and repeated spaces,
            # which would crash float('') with the original split(' ').
            for token in line.split()[1:]:
                try:
                    value = float(token)
                except ValueError:
                    # Non-numeric garbage in a coordinate field -> corrupted.
                    return False
                # Normalized YOLO coordinates must lie in [0, 1]; negative
                # values are just as invalid as values above 1.
                if value > 1 or value < 0:
                    return False
    return True
def create_yaml():
    """Write data.yaml for today's dataset: class names, class count, splits."""
    print('creating yaml...')
    annotation_classes = AnnotationClass.read_json()
    # Known class ids keep their configured name; gaps get a numbered placeholder.
    name_entries = []
    for idx in range(DEFAULT_CLASS_NUM):
        if idx in annotation_classes:
            name_entries.append(f'- {annotation_classes[idx].name}')
        else:
            name_entries.append(f'- Class-{idx + 1}')
    content = ['names:']
    content.extend(name_entries)
    content.append(f'nc: {DEFAULT_CLASS_NUM}')
    content.append(f'test: test/images')
    content.append(f'train: train/images')
    content.append(f'val: valid/images')
    content.append('')
    today_yaml = abspath(path.join(today_dataset, 'data.yaml'))
    with open(today_yaml, 'w', encoding='utf-8') as f:
        f.write(''.join(f'{line}\n' for line in content))
def resume_training(last_pt_path):
    """Resume an interrupted training run from a saved checkpoint.

    Args:
        last_pt_path: path to the 'last.pt' checkpoint written by a previous
            run (save_period=1 keeps one per epoch).
    """
    model = YOLO(last_pt_path)
    # Bug fix: the original passed the imported `yaml` *module* as `data=`.
    # Point at today's dataset config, consistent with train_dataset().
    model.train(data=abspath(path.join(today_dataset, 'data.yaml')),
                resume=True,
                epochs=120,
                batch=11,
                imgsz=1280,
                save_period=1,
                workers=24)
def train_dataset():
    """Build today's dataset, train a fresh YOLO model on it, and stage weights.

    Returns:
        Path to the promoted 'best.pt' weights file under models_dir.
    """
    form_dataset()
    create_yaml()
    data_yaml = abspath(path.join(today_dataset, 'data.yaml'))
    model = YOLO('yolo11m.yaml')
    results = model.train(data=data_yaml,
                          epochs=120,
                          batch=11,
                          imgsz=1280,
                          save_period=1,
                          workers=24)
    # Archive the whole run directory under models/<today>, then promote
    # its best checkpoint to the well-known '<prefix>.pt' location.
    run_archive = path.join(models_dir, today_folder)
    shutil.copytree(results.save_dir, run_archive)
    best_weights = path.join(models_dir, f'{prefix[:-1]}.pt')
    shutil.copy(path.join(run_archive, 'weights', 'best.pt'), best_weights)
    return best_weights
def validate(model_path):
    """Run YOLO validation on the given weights file and print the metrics."""
    metrics = YOLO(model_path).val()
    print(metrics)
if __name__ == '__main__':
    # Full automated pipeline: train, validate, export to ONNX, upload encrypted.
    model_path = train_dataset()
    # Bug fix: validation previously targeted a stale hard-coded run
    # ('runs/detect/train7/...') instead of the model just trained.
    validate(model_path)
    onnx_path = export_onnx(model_path)
    api_client = ApiClient()
    with open(onnx_path, 'rb') as binary_file:
        onnx_bytes = binary_file.read()
    key = Security.get_model_encryption_key()
    api_client.upload_big_small_resource(onnx_bytes, onnx_path, constants.MODELS_FOLDER, key)