mirror of
https://github.com/azaion/ai-training.git
synced 2026-04-22 07:06:36 +00:00
182 lines
6.1 KiB
Python
182 lines
6.1 KiB
Python
import concurrent.futures
|
|
import glob
|
|
import os
|
|
import random
|
|
import shutil
|
|
import subprocess
|
|
from datetime import datetime
|
|
from os import path, replace, listdir, makedirs, scandir
|
|
from os.path import abspath
|
|
from pathlib import Path
|
|
from time import sleep
|
|
|
|
import yaml
|
|
from ultralytics import YOLO
|
|
|
|
import constants
|
|
from api_client import ApiCredentials, ApiClient
|
|
from cdn_manager import CDNCredentials, CDNManager
|
|
from constants import (processed_images_dir,
|
|
processed_labels_dir,
|
|
prefix, date_format,
|
|
datasets_dir, models_dir,
|
|
corrupted_images_dir, corrupted_labels_dir, sample_dir)
|
|
from dto.annotationClass import AnnotationClass
|
|
from inference.onnx_engine import OnnxEngine
|
|
|
|
from security import Security
|
|
from utils import Dotdict
|
|
from exports import export_tensorrt, upload_model, export_onnx
|
|
|
|
# Folder name for today's run, e.g. '<prefix><formatted date>' (see date_format).
today_folder = f'{prefix}{datetime.now():{date_format}}'
# Full path of today's dataset root under datasets_dir.
today_dataset = path.join(datasets_dir, today_folder)

# Split percentages. Only train_set and valid_set are used to slice the
# shuffled image list; the test split receives the remainder (test_set itself
# is not referenced by form_dataset).
train_set = 70
valid_set = 20
test_set = 10
# NOTE(review): appears unused in this file — presumably the share of older
# images to mix into a new dataset; confirm against the other modules.
old_images_percentage = 75

# Number of classes in the generated data.yaml (80 — presumably matching the
# COCO class count; confirm against the model config).
DEFAULT_CLASS_NUM = 80
# Progress counter shared by the copy worker threads in copy_annotations.
total_files_copied = 0
def form_dataset():
    """Rebuild today's dataset directory from the processed image pool.

    Removes any previous dataset for today, shuffles all processed image
    files and hands them to copy_annotations() in train/valid/test portions
    according to the configured split percentages (test gets the remainder).
    """
    shutil.rmtree(today_dataset, ignore_errors=True)
    makedirs(today_dataset)

    # Collect regular files only; directories and other entries are skipped.
    with scandir(processed_images_dir) as entries:
        images = [entry for entry in entries if entry.is_file()]

    print(f'Got {len(images)} images. Start shuffling...')
    random.shuffle(images)

    n_train = int(len(images) * train_set / 100.0)
    n_valid = int(len(images) * valid_set / 100.0)

    print('Start copying...')
    copy_annotations(images[:n_train], 'train')
    copy_annotations(images[n_train:n_train + n_valid], 'valid')
    copy_annotations(images[n_train + n_valid:], 'test')
def copy_annotations(images, folder):
    """Copy image/label pairs into `<today_dataset>/<folder>/{images,labels}`.

    Labels that fail validation (see check_label) are diverted, together with
    their image, into the corrupted directories instead of the dataset.
    Copying is fanned out over a thread pool.

    :param images: iterable of os.DirEntry objects for the source images
    :param folder: dataset split name ('train', 'valid' or 'test')
    """
    global total_files_copied
    total_files_copied = 0

    def copy_image(image):
        global total_files_copied
        # NOTE(review): += on a shared global is not atomic across threads,
        # so this counter (and the progress print) is approximate.
        total_files_copied += 1

        label_name = f'{Path(image.path).stem}.txt'
        label_path = path.join(processed_labels_dir, label_name)
        if check_label(label_path):
            shutil.copy(image.path, path.join(destination_images, image.name))
            shutil.copy(label_path, path.join(destination_labels, label_name))
        else:
            shutil.copy(image.path, path.join(corrupted_images_dir, image.name))
            # BUG FIX: check_label also returns False when the label file is
            # missing entirely — only copy the label when it actually exists,
            # otherwise shutil.copy would raise FileNotFoundError.
            if path.exists(label_path):
                shutil.copy(label_path, path.join(corrupted_labels_dir, label_name))
            print(f'Label {label_path} is corrupted! Copy with its image to the corrupted directory ({corrupted_labels_dir})')

        if total_files_copied % 1000 == 0:
            print(f'{total_files_copied} copied...')

    destination_images = path.join(today_dataset, folder, 'images')
    makedirs(destination_images, exist_ok=True)

    destination_labels = path.join(today_dataset, folder, 'labels')
    makedirs(destination_labels, exist_ok=True)

    makedirs(corrupted_images_dir, exist_ok=True)
    makedirs(corrupted_labels_dir, exist_ok=True)

    print(f'Copying annotations to {destination_images} and {destination_labels} folders:')
    with concurrent.futures.ThreadPoolExecutor() as executor:
        # BUG FIX: consume the map iterator so exceptions raised inside the
        # workers are re-raised here instead of being silently discarded.
        list(executor.map(copy_image, images))

    # BUG FIX: the original printed a local `copied` variable that was never
    # incremented, so the summary always reported 0.
    print(f'Copied all {total_files_copied} annotations to {destination_images} and {destination_labels} folders')
def check_label(label_path):
    """Validate a YOLO-format label file.

    Each line is `<class> <x> <y> <w> <h>` with coordinates normalized to
    [0, 1]; the class id (first token) is not checked.

    :param label_path: path to the label .txt file
    :return: False when the file is missing, any coordinate falls outside
        [0, 1], or a token is not a valid float; True otherwise.
    """
    if not path.exists(label_path):
        return False
    with open(label_path, 'r') as f:
        for line in f:
            for val in line.split(' ')[1:]:
                try:
                    number = float(val)
                except ValueError:
                    # BUG FIX: a malformed token previously crashed the
                    # caller with ValueError; treat the label as corrupted.
                    return False
                # BUG FIX: negative coordinates are just as invalid as
                # values greater than 1 for normalized YOLO labels.
                if number > 1 or number < 0:
                    return False
    return True
def create_yaml():
    """Write today's dataset config (`data.yaml`) for YOLO training.

    Emits the class name list (falling back to 'Class-<n>' for ids missing
    from the annotation metadata), the class count and the relative
    train/valid/test image directories.
    """
    print('creating yaml...')
    annotation_classes = AnnotationClass.read_json()

    content = ['names:']
    for class_id in range(DEFAULT_CLASS_NUM):
        class_name = (annotation_classes[class_id].name
                      if class_id in annotation_classes
                      else f'Class-{class_id + 1}')
        content.append(f'- {class_name}')
    content.append(f'nc: {DEFAULT_CLASS_NUM}')

    content.extend(['test: test/images',
                    'train: train/images',
                    'val: valid/images',
                    ''])

    today_yaml = abspath(path.join(today_dataset, 'data.yaml'))
    with open(today_yaml, 'w', encoding='utf-8') as f:
        f.write('\n'.join(content) + '\n')
def resume_training(last_pt_path):
    """Resume an interrupted YOLO training run from a saved checkpoint.

    :param last_pt_path: path to the checkpoint (e.g. `last.pt`) written by
        a previous run (save_period=1 in train_dataset enables this).
    """
    model = YOLO(last_pt_path)
    # BUG FIX: the original passed `data=yaml` — the imported PyYAML *module
    # object*, not a dataset config path. Use the same data.yaml path that
    # train_dataset() trains against. (With resume=True ultralytics restores
    # most arguments from the checkpoint, but `data` must still be a path.)
    results = model.train(data=abspath(path.join(today_dataset, 'data.yaml')),
                          resume=True,
                          epochs=120,
                          batch=11,
                          imgsz=1280,
                          save_period=1,
                          workers=24)
def train_dataset():
    """Build today's dataset, train a fresh YOLO11m model on it and install
    the best weights as the current model."""
    form_dataset()
    create_yaml()

    data_yaml = abspath(path.join(today_dataset, 'data.yaml'))
    model = YOLO('yolo11m.yaml')

    results = model.train(data=data_yaml,
                          epochs=120,  # Empirically set for good performance and relatively not so long training
                          # (360k of annotations on 1 RTX4090 takes 11.5 days of training :( )
                          batch=11,  # reflects current GPU memory, 24Gb (batch 11 gets ~22Gb, batch 12 fails on 24.2Gb)
                          imgsz=1280,  # 1280p is a tradeoff between quality and speed
                          save_period=1,  # for resuming in case of power outages / other issues
                          workers=24)  # loading data workers. Bound to cpus count

    model_dir = path.join(models_dir, today_folder)
    shutil.copytree(results.save_dir, model_dir)

    # Drop the intermediate per-epoch checkpoints; only best/last are kept.
    for epoch_checkpoint in glob.glob(path.join(model_dir, 'weights', 'epoch*')):
        os.remove(epoch_checkpoint)

    # Publish the best weights as the current production .pt model.
    shutil.copy(path.join(model_dir, 'weights', 'best.pt'), constants.CURRENT_PT_MODEL)
def export_current_model():
    """Export the current .pt model to ONNX and upload it, encrypted, via the API."""
    export_onnx(constants.CURRENT_PT_MODEL)

    with open(constants.CURRENT_ONNX_MODEL, 'rb') as model_file:
        onnx_payload = model_file.read()

    encryption_key = Security.get_model_encryption_key()
    client = ApiClient()
    client.upload_big_small_resource(onnx_payload, 'azaion.onnx', constants.MODELS_FOLDER, encryption_key)
if __name__ == '__main__':
    # Full pipeline: build today's dataset and train, then export the best
    # model to ONNX and upload it.
    train_dataset()
    export_current_model()
    print('success!')