Files
ai-training/train.py
T
Alex Bezdieniezhnykh 538dc8efa9 train is ready
manual_run reuses train's export
add current_model to constants
2025-05-31 18:41:10 +03:00

178 lines
5.6 KiB
Python

import concurrent.futures
import os
import random
import shutil
import subprocess
from datetime import datetime
from os import path, replace, listdir, makedirs, scandir
from os.path import abspath
from pathlib import Path
from time import sleep
import yaml
from ultralytics import YOLO
import constants
from api_client import ApiCredentials, ApiClient
from cdn_manager import CDNCredentials, CDNManager
from constants import (processed_images_dir,
processed_labels_dir,
prefix, date_format,
datasets_dir, models_dir,
corrupted_images_dir, corrupted_labels_dir, sample_dir)
from dto.annotationClass import AnnotationClass
from inference.onnx_engine import OnnxEngine
from security import Security
from utils import Dotdict
from exports import export_tensorrt, upload_model, export_onnx
# Folder name shared by today's dataset and today's trained model:
# the configured prefix followed by the current date rendered with date_format.
today_folder = f'{prefix}{format(datetime.now(), date_format)}'
today_dataset = path.join(datasets_dir, today_folder)

# Percentage split of the shuffled images. form_dataset() applies train_set
# and valid_set; every remaining image goes to the test subset.
train_set, valid_set, test_set = 70, 20, 10

# Not referenced by the code visible in this file.
old_images_percentage = 75

# Number of class names written to data.yaml (also its `nc` value).
DEFAULT_CLASS_NUM = 80

# Progress counter reset and updated by copy_annotations().
total_files_copied = 0
def form_dataset():
    """Rebuild today's dataset folder from the processed images.

    Any previous content of today's dataset folder is discarded. Every regular
    file found in processed_images_dir is collected, the list is shuffled, and
    the train / valid / test slices are handed to copy_annotations().
    """
    shutil.rmtree(today_dataset, ignore_errors=True)
    makedirs(today_dataset)

    with scandir(processed_images_dir) as entries:
        images = [entry for entry in entries if entry.is_file()]

    print(f'Got {len(images)} images. Start shuffling...')
    random.shuffle(images)

    total = len(images)
    train_end = int(total * train_set / 100.0)
    valid_end = train_end + int(total * valid_set / 100.0)

    print(f'Start copying...')
    # Whatever is left after the train and valid slices becomes the test subset.
    splits = (
        ('train', images[:train_end]),
        ('valid', images[train_end:valid_end]),
        ('test', images[valid_end:]),
    )
    for folder, subset in splits:
        copy_annotations(subset, folder)
def copy_annotations(images, folder):
    """Copy images and their label files into one subset of today's dataset.

    For every image the label ``<image stem>.txt`` is looked up in
    processed_labels_dir. Pairs whose label passes check_label() are copied to
    ``<today_dataset>/<folder>/images`` and ``<today_dataset>/<folder>/labels``;
    all other pairs are copied to the corrupted directories instead.

    Args:
        images: iterable of entries exposing ``.path`` and ``.name``
            (form_dataset() passes os.DirEntry objects).
        folder: name of the subset folder, e.g. 'train', 'valid' or 'test'.
    """
    global total_files_copied
    total_files_copied = 0

    destination_images = path.join(today_dataset, folder, 'images')
    destination_labels = path.join(today_dataset, folder, 'labels')
    for directory in (destination_images, destination_labels,
                      corrupted_images_dir, corrupted_labels_dir):
        makedirs(directory, exist_ok=True)

    def copy_image(image):
        """Copy one image/label pair; return True if it went into the dataset."""
        label_name = f'{Path(image.path).stem}.txt'
        label_path = path.join(processed_labels_dir, label_name)
        try:
            label_ok = check_label(label_path)
        except ValueError:
            # A label that cannot be parsed as numbers counts as corrupted.
            label_ok = False
        try:
            if label_ok:
                shutil.copy(image.path, path.join(destination_images, image.name))
                shutil.copy(label_path, path.join(destination_labels, label_name))
                return True
            shutil.copy(image.path, path.join(corrupted_images_dir, image.name))
            # check_label() also rejects a missing label file; in that case
            # there is no label to copy.
            if path.exists(label_path):
                shutil.copy(label_path, path.join(corrupted_labels_dir, label_name))
            print(f'Label {label_path} is corrupted! Copy with its image to the corrupted directory ({corrupted_labels_dir})')
        except OSError as error:
            # Report the failure instead of losing it inside the worker thread.
            print(f'Failed to copy {image.path}: {error}')
        return False

    copied = 0
    print(f'Copying annotations to {destination_images} and {destination_labels} folders:')
    with concurrent.futures.ThreadPoolExecutor() as executor:
        # Results are consumed in this thread: the counters are updated without
        # a data race, and any unexpected worker exception is re-raised here
        # rather than being silently dropped.
        for succeeded in executor.map(copy_image, images):
            total_files_copied += 1
            if succeeded:
                copied += 1
            if total_files_copied % 1000 == 0:
                print(f'{total_files_copied} copied...')
    print(f'Copied all {copied} annotations to {destination_images} and {destination_labels} folders')
def check_label(label_path):
    """Tell whether a label file exists and holds only values not above 1.

    Each line is split on whitespace; the first field is skipped and every
    remaining field must parse as a number that is not greater than 1.

    Args:
        label_path: path to the label text file.

    Returns:
        False when the file does not exist, when a field is not numeric (or
        the file content cannot be decoded), or when a value exceeds 1.
        True otherwise, including for an empty file.
    """
    if not path.exists(label_path):
        return False
    try:
        with open(label_path, 'r') as f:
            # split() without arguments ignores repeated, leading and trailing
            # whitespace, so stray spaces neither produce unparsable empty
            # fields nor shift the first field into the checked values.
            return not any(
                float(val) > 1
                for line in f
                for val in line.split()[1:]
            )
    except ValueError:
        # Non-numeric content means the label is corrupted, not that the
        # whole dataset build should crash.
        return False
def create_yaml():
    """Write data.yaml into today's dataset folder.

    The file lists DEFAULT_CLASS_NUM class names: the name taken from
    AnnotationClass.read_json() when index i is present there, otherwise the
    placeholder ``Class-<i + 1>``. The names are followed by ``nc`` and the
    relative test / train / val image folders.
    """
    print('creating yaml...')
    annotation_classes = AnnotationClass.read_json()

    def class_name(index):
        """Return the configured name for index, or its placeholder."""
        if index in annotation_classes:
            return annotation_classes[index].name
        return f'Class-{index + 1}'

    content = ['names:']
    content.extend(f'- {class_name(i)}' for i in range(DEFAULT_CLASS_NUM))
    content += [
        f'nc: {DEFAULT_CLASS_NUM}',
        f'test: test/images',
        f'train: train/images',
        f'val: valid/images',
        '',
    ]

    yaml_path = abspath(path.join(today_dataset, 'data.yaml'))
    with open(yaml_path, 'w', encoding='utf-8') as f:
        # Every entry, including the trailing empty one, ends with a newline.
        f.write(''.join(f'{entry}\n' for entry in content))
def resume_training(last_pt_path, data_yaml=None):
    """Resume training from a checkpoint with the same settings as train_dataset().

    Args:
        last_pt_path: path to the checkpoint file loaded with YOLO().
        data_yaml: path to the dataset description file. Defaults to
            ``data.yaml`` inside today's dataset folder, the same file
            train_dataset() trains on.
    """
    if data_yaml is None:
        # The original passed the imported `yaml` module here, which is not a
        # dataset description; use the dataset file path instead.
        data_yaml = abspath(path.join(today_dataset, 'data.yaml'))
    model = YOLO(last_pt_path)
    model.train(data=data_yaml,
                resume=True,
                epochs=120,
                batch=11,
                imgsz=1280,
                save_period=1,
                workers=24)
def train_dataset():
    """Build today's dataset, train a YOLO model on it and publish the best weights.

    Steps: form_dataset() and create_yaml() prepare the data, a model built
    from 'yolo11m.yaml' is trained on today's data.yaml, the training output
    directory is copied to ``<models_dir>/<today_folder>`` and its
    ``weights/best.pt`` is copied to constants.CURRENT_PT_MODEL.
    """
    form_dataset()
    create_yaml()
    model_name = 'yolo11m.yaml'
    model = YOLO(model_name)
    results = model.train(data=abspath(path.join(today_dataset, 'data.yaml')),
                          epochs=120,
                          batch=11,
                          imgsz=1280,
                          save_period=1,
                          workers=24)
    model_dir = path.join(models_dir, today_folder)
    # form_dataset() already tolerates a second run on the same day; without
    # dirs_exist_ok the copy would raise FileExistsError only after the whole
    # training had finished, and the new weights would never be published.
    shutil.copytree(results.save_dir, model_dir, dirs_exist_ok=True)
    shutil.copy(path.join(model_dir, 'weights', 'best.pt'), constants.CURRENT_PT_MODEL)
def export_current_model():
    """Export the current .pt model to ONNX and upload the ONNX file.

    Calls export_onnx() on constants.CURRENT_PT_MODEL, then reads the bytes of
    constants.CURRENT_ONNX_MODEL (presumably the file export_onnx() produces;
    confirm in exports) and uploads them as 'azaion.onnx' to
    constants.MODELS_FOLDER together with the model encryption key.
    """
    export_onnx(constants.CURRENT_PT_MODEL)
    api_client = ApiClient()
    onnx_bytes = Path(constants.CURRENT_ONNX_MODEL).read_bytes()
    encryption_key = Security.get_model_encryption_key()
    api_client.upload_big_small_resource(
        onnx_bytes,
        'azaion.onnx',
        constants.MODELS_FOLDER,
        encryption_key,
    )
# Script entry point: train on a freshly formed dataset, then export and
# upload the resulting model.
if __name__ == '__main__':
    train_dataset()
    export_current_model()
    print('success!')