Files
ai-training/train.py
T
2025-05-22 17:02:24 +03:00

248 lines
8.5 KiB
Python

import concurrent.futures
import os
import random
import shutil
import subprocess
from datetime import datetime
from os import path, replace, listdir, makedirs, scandir
from os.path import abspath
from pathlib import Path
from time import sleep
import yaml
from ultralytics import YOLO
import constants
from api_client import ApiCredentials, ApiClient
from cdn_manager import CDNCredentials, CDNManager
from constants import (processed_images_dir,
processed_labels_dir,
prefix, date_format,
datasets_dir, models_dir,
corrupted_images_dir, corrupted_labels_dir, sample_dir)
from dto.annotationClass import AnnotationClass
from inference.onnx_engine import OnnxEngine
from security import Security
from utils import Dotdict
from exports import export_tensorrt, upload_model, export_onnx
# Name of the dataset/model folder for the current run: `prefix` followed by
# today's date rendered with `date_format`.
today_folder = f'{prefix}{datetime.now():{date_format}}'
# Dataset directory for the current run, located under `datasets_dir`.
today_dataset = path.join(datasets_dir, today_folder)
# Split percentages. form_dataset() uses train_set and valid_set explicitly;
# whatever remains after those two slices becomes the test split, so test_set
# is not read by the code visible in this file.
train_set = 70
valid_set = 20
test_set = 10
# Percentage of "old" images (not newer than from_date) that form_dataset()
# mixes into a new dataset.
old_images_percentage = 75
# Number of class entries written to data.yaml by create_yaml() (also its `nc`).
DEFAULT_CLASS_NUM = 80
# Progress counter that copy_annotations() resets and updates.
total_files_copied = 0
def form_dataset(from_date: datetime):
    """Build today's train/valid/test dataset from the processed images.

    Every regular file in ``processed_images_dir`` is classified by its
    modification day (time of day is zeroed out):

    * ``from_date is None`` or modification day later than ``from_date``
      -> treated as a new image and always included;
    * otherwise -> treated as an old image, of which a random
      ``old_images_percentage`` percent is included.

    The combined list is shuffled and sliced into ``train_set`` percent for
    'train', ``valid_set`` percent for 'valid', and the remainder for 'test'.
    Each slice is handed to ``copy_annotations``.

    Args:
        from_date: Cut-off date separating old images from new ones, or
            ``None`` to treat every image as new.
    """
    makedirs(today_dataset, exist_ok=True)

    fresh_entries = []
    stale_entries = []
    with scandir(processed_images_dir) as directory:
        for entry in directory:
            if not entry.is_file():
                continue
            modified_day = datetime.fromtimestamp(entry.stat().st_mtime).replace(
                hour=0, minute=0, second=0, microsecond=0)
            is_new = from_date is None or modified_day > from_date
            # Old images are gathered too, to avoid overfitting on new data only.
            (fresh_entries if is_new else stale_entries).append(entry)

    random.shuffle(stale_entries)
    stale_quota = int(len(stale_entries) * old_images_percentage / 100.0)
    print(f'Got {len(fresh_entries)} new images and {stale_quota} of old images (to prevent overfitting). Shuffling them...')

    selected = fresh_entries + stale_entries[:stale_quota]
    random.shuffle(selected)

    train_end = int(len(selected) * train_set / 100.0)
    valid_end = train_end + int(len(selected) * valid_set / 100.0)

    copy_annotations(selected[:train_end], 'train')
    copy_annotations(selected[train_end:valid_end], 'valid')
    copy_annotations(selected[valid_end:], 'test')
def copy_annotations(images, folder):
    """Copy images and their label files into one subset of today's dataset.

    For each image the label ``<image stem>.txt`` is looked up in
    ``processed_labels_dir``. If ``check_label`` accepts it, the image and the
    label are copied to ``<today_dataset>/<folder>/images`` and ``.../labels``.
    Otherwise the image (and the label, when it exists) is copied to the
    corrupted directories instead.

    Copying runs on a thread pool. All counters are updated in the calling
    thread while the results are consumed, so no shared state is mutated from
    the workers. The module-level ``total_files_copied`` is still reset and
    updated, as before.

    Args:
        images: Iterable of directory entries exposing ``.path`` and ``.name``
            (e.g. ``os.DirEntry`` objects).
        folder: Subset name used as the sub-directory, e.g. 'train'.
    """
    global total_files_copied
    total_files_copied = 0

    destination_images = path.join(today_dataset, folder, 'images')
    destination_labels = path.join(today_dataset, folder, 'labels')
    for directory in (destination_images, destination_labels,
                      corrupted_images_dir, corrupted_labels_dir):
        makedirs(directory, exist_ok=True)

    def copy_image(image):
        """Copy one image/label pair; return True if it went into the dataset."""
        label_name = f'{Path(image.path).stem}.txt'
        label_path = path.join(processed_labels_dir, label_name)
        try:
            if check_label(label_path):
                shutil.copy(image.path, path.join(destination_images, image.name))
                shutil.copy(label_path, path.join(destination_labels, label_name))
                return True
            shutil.copy(image.path, path.join(corrupted_images_dir, image.name))
            # check_label() also rejects a missing label; there is nothing to
            # copy in that case.
            if path.exists(label_path):
                shutil.copy(label_path, path.join(corrupted_labels_dir, label_name))
            print(f'Label {label_path} is corrupted! Copy with its image to the corrupted directory ({corrupted_labels_dir})')
        except (OSError, ValueError) as err:
            # Keep the batch going on a single bad file, but report it
            # instead of losing the error inside the worker thread.
            print(f'Failed to copy {image.path}: {err}')
        return False

    copied = 0
    print(f'Copying annotations to {destination_images} and {destination_labels} folders:')
    with concurrent.futures.ThreadPoolExecutor() as executor:
        # Consuming the iterator surfaces each worker's outcome in this thread.
        for succeeded in executor.map(copy_image, images):
            total_files_copied += 1
            if succeeded:
                copied += 1
            if total_files_copied % 1000 == 0:
                print(f'{total_files_copied} copied...')
    print(f'Copied all {copied} annotations to {destination_images} and {destination_labels} folders')
def check_label(label_path):
    """Return True when the label file exists and every value is valid.

    Each line is split on whitespace; the first token (the class id) is
    skipped and every remaining token must parse as a float not greater
    than 1.

    Args:
        label_path: Path to the label text file.

    Returns:
        False if the file does not exist, a value cannot be parsed as a
        number, or a value is greater than 1; True otherwise (including for
        an empty file).
    """
    if not path.exists(label_path):
        return False
    with open(label_path, 'r') as f:
        for line in f:
            # split() without an argument tolerates tabs, repeated spaces and
            # the trailing newline, which split(' ') turned into bogus tokens.
            try:
                values = [float(token) for token in line.split()[1:]]
            except ValueError:
                return False
            if any(value > 1 for value in values):
                return False
    return True
def create_yaml():
    """Write ``data.yaml`` for today's dataset.

    The file lists ``DEFAULT_CLASS_NUM`` class names, followed by ``nc`` and
    the relative image directories of the test/train/valid splits. A class
    index found in ``AnnotationClass.read_json()`` uses that entry's ``name``;
    any other index gets the placeholder ``Class-<index + 1>``.
    """
    print('creating yaml...')
    annotation_classes = AnnotationClass.read_json()

    # NOTE(review): presumably a mapping keyed by integer class id - confirm.
    class_entries = [
        f'- {annotation_classes[idx].name}' if idx in annotation_classes else f'- Class-{idx + 1}'
        for idx in range(DEFAULT_CLASS_NUM)
    ]
    document = [
        'names:',
        *class_entries,
        f'nc: {DEFAULT_CLASS_NUM}',
        'test: test/images',
        'train: train/images',
        'val: valid/images',
        '',
    ]

    yaml_path = abspath(path.join(today_dataset, 'data.yaml'))
    with open(yaml_path, 'w', encoding='utf-8') as out:
        out.write(''.join(f'{entry}\n' for entry in document))
def revert_to_processed_data(date):
    """Move a dataset's files back to the processed directories and delete it.

    For the dataset folder ``<datasets_dir>/<prefix><date>``, every file in
    the 'images' and 'labels' directories of the test, train and valid
    subsets is moved (overwriting same-named files) into
    ``processed_images_dir`` / ``processed_labels_dir``. The dataset folder
    is then removed.

    Args:
        date: Date suffix of the dataset folder name, appended to ``prefix``.
    """
    date_dataset = path.join(datasets_dir, f'{prefix}{date}')
    makedirs(processed_images_dir, exist_ok=True)
    makedirs(processed_labels_dir, exist_ok=True)

    targets = (('images', processed_images_dir), ('labels', processed_labels_dir))
    for subset in ('test', 'train', 'valid'):
        for kind, target_dir in targets:
            source_dir = path.join(date_dataset, subset, kind)
            for entry_name in listdir(source_dir):
                replace(path.join(source_dir, entry_name),
                        path.join(target_dir, entry_name))

    shutil.rmtree(date_dataset)
def get_latest_model():
    """Find the most recent trained model under ``models_dir``.

    Only immediate sub-directories named ``<prefix><YYYY-MM-DD>`` are
    considered. Directories without the prefix, or whose suffix is not a
    valid date, are ignored.

    Returns:
        A ``(date, path)`` tuple, where ``date`` is the ``datetime`` parsed
        from the newest directory name and ``path`` points at its
        ``weights/best.pt`` (the file is not checked for existence here).
        ``(None, None)`` when ``models_dir`` is missing or holds no matching
        directory.
    """
    try:
        # os.walk yields nothing for a missing/unreadable directory, which
        # would make a bare next() raise StopIteration.
        dir_names = next(os.walk(models_dir))[1]
    except StopIteration:
        return None, None

    candidates = []
    for d in dir_names:
        if not d.startswith(prefix):
            continue
        try:
            # Strip only the leading prefix; str.replace would also remove
            # later occurrences.
            dir_date = datetime.strptime(d[len(prefix):], '%Y-%m-%d')
        except ValueError:
            continue
        candidates.append({'date': dir_date,
                           'path': path.join(models_dir, d, 'weights', 'best.pt')})

    if not candidates:
        return None, None
    last_model = sorted(candidates, key=lambda item: item['date'])[-1]
    return last_model['date'], last_model['path']
def resume_training(last_pt_path, data_yaml=None):
    """Resume an interrupted training run from a checkpoint.

    Args:
        last_pt_path: Path to the checkpoint weights to load and resume from.
        data_yaml: Optional path to the dataset ``data.yaml``. It is only
            forwarded to ``train`` when given.
    """
    model = YOLO(last_pt_path)
    train_args = {
        'resume': True,
        'epochs': 120,
        'batch': 11,
        'imgsz': 1280,
        'save_period': 1,
        'workers': 24,
    }
    # The module-level name `yaml` is the imported PyYAML module, not a
    # dataset path, so it must not be passed as `data`.
    # NOTE(review): presumably Ultralytics restores the dataset path from the
    # checkpoint when resuming - confirm against the installed version.
    if data_yaml is not None:
        train_args['data'] = data_yaml
    model.train(**train_args)
def train_dataset(existing_date=None, from_scratch=False):
    """Train a model on a dataset and store the resulting weights.

    Training starts from the latest model's ``best.pt`` when one exists and
    ``from_scratch`` is false; otherwise from the ``yolo11m.yaml``
    architecture. After training, the run directory is copied to
    ``<models_dir>/<dataset folder>``, its ``best.pt`` is copied to
    ``<models_dir>/<prefix without its last character>.pt``, and the local
    ``runs`` directory is deleted.

    Args:
        existing_date: Date suffix of an already formed dataset folder to
            train on. ``None`` uses today's dataset folder.
        from_scratch: When true, ignore previously trained models.

    Returns:
        Path of the copied best-weights file.
    """
    # Parenthesised tuple: without the parentheses the conditional binds only
    # to the first element and latest_model would always be None.
    latest_date, latest_model = get_latest_model() if not from_scratch else (None, None)

    if existing_date is not None:
        cur_folder = f'{prefix}{existing_date}'
        cur_dataset = path.join(datasets_dir, cur_folder)
    else:
        # NOTE: dataset formation is currently disabled, so today's dataset
        # and its data.yaml must already exist.
        # if from_scratch and Path(today_dataset).exists():
        #     shutil.rmtree(today_dataset)
        # form_dataset(latest_date)
        # create_yaml()
        cur_folder = today_folder
        cur_dataset = today_dataset

    use_latest = latest_model is not None and path.isfile(latest_model) and not from_scratch
    model_name = latest_model if use_latest else 'yolo11m.yaml'
    print(f'Initial model: {model_name}')
    model = YOLO(model_name)

    # Named data_yaml so the imported `yaml` module is not shadowed.
    data_yaml = abspath(path.join(cur_dataset, 'data.yaml'))
    results = model.train(data=data_yaml,
                          epochs=120,
                          batch=11,
                          imgsz=1280,
                          save_period=1,
                          workers=24)

    model_dir = path.join(models_dir, cur_folder)
    shutil.copytree(results.save_dir, model_dir)

    model_path = path.join(models_dir, f'{prefix[:-1]}.pt')
    shutil.copy(path.join(model_dir, 'weights', 'best.pt'), model_path)
    shutil.rmtree('runs')
    return model_path
def convert2rknn():
    """Run the conversion script, then export the latest model to ONNX.

    ``convert.sh`` is executed with bash from the ``./orangepi5`` directory
    (its exit code is not inspected). The newest model reported by
    ``get_latest_model`` is then loaded and exported in ONNX format.
    """
    subprocess.call(['bash', 'convert.sh'], cwd="./orangepi5")
    _, newest_weights = get_latest_model()
    YOLO(newest_weights).export(format="onnx")
def validate(model_path):
    """Load the model at ``model_path`` and run its validation.

    The metrics returned by ``val()`` are not used or returned.
    """
    YOLO(model_path).val()
if __name__ == '__main__':
    # Script entry point: train from scratch, export to ONNX, upload the result.
    model_path = train_dataset(from_scratch=True)
    # validate(path.join('runs', 'detect', 'train7', 'weights', 'best.pt'))
    # form_data_sample(500)
    # convert2rknn()

    api_client = ApiClient()
    onnx_path = export_onnx('azaion.pt')
    onnx_bytes = Path(onnx_path).read_bytes()

    encryption_key = Security.get_model_encryption_key()
    api_client.upload_big_small_resource(onnx_bytes, onnx_path, constants.MODELS_FOLDER, encryption_key)