Files
ai-training/train.py
T
2024-08-23 13:47:54 +03:00

149 lines
5.2 KiB
Python

import random
import shutil
from datetime import datetime
from os import path, replace, remove, listdir, makedirs, scandir
from os.path import abspath
from pathlib import Path
from typing import Optional

from ultralytics import YOLO

from constants import (processed_images_dir,
                       processed_labels_dir,
                       annotation_classes,
                       prefix, date_format,
                       datasets_dir, models_dir,
                       corrupted_images_dir, corrupted_labels_dir)
# Name of this run's dataset/model folder: the configured prefix followed by
# the current date rendered with date_format. Evaluated once, at import time.
today_folder = f'{prefix}{datetime.now():{date_format}}'
today_dataset = path.join(datasets_dir, today_folder)

# Percentage of the shuffled images assigned to each split.
# NOTE: form_dataset sizes only the train and valid splits from these values;
# the test split simply receives the remaining images, so test_set is not read.
train_set, valid_set, test_set = 70, 20, 10
def _collect_images(from_date):
    """Return DirEntry objects for regular files in processed_images_dir, keeping
    only those modified after ``from_date`` (all of them when it is None)."""
    with scandir(processed_images_dir) as entries:
        # The list is fully materialised before the scandir handle closes;
        # later callers only read the cached DirEntry .path/.name attributes.
        return [
            entry for entry in entries
            if entry.is_file()
            # Short-circuit: only stat() the file when a date filter is active.
            and (from_date is None
                 or datetime.fromtimestamp(entry.stat().st_mtime) > from_date)
        ]


def form_dataset(from_date: Optional[datetime] = None):
    """Build today's train/valid/test dataset from the processed images.

    Selects the regular files in ``processed_images_dir``, shuffles them, and
    splits them into ``train_set``% / ``valid_set``% / the remainder for test.
    Each split is copied into ``today_dataset`` via ``copy_annotations``, then
    ``data.yaml`` is written via ``create_yaml``.

    Args:
        from_date: Only images whose modification time is strictly later than
            this (naive, local-time) datetime are used. ``None`` means use every
            image; ``get_latest_model`` returns ``None`` when no model exists
            yet, so this is an expected value rather than an error.
    """
    makedirs(today_dataset, exist_ok=True)
    images = _collect_images(from_date)

    print('shuffling images')
    random.shuffle(images)

    # Floor both sizes; whatever is left over goes to the test split.
    train_size = int(len(images) * train_set / 100.0)
    valid_size = int(len(images) * valid_set / 100.0)
    valid_end = train_size + valid_size

    print(f'copy train dataset, size: {train_size} annotations')
    copy_annotations(images[:train_size], 'train')
    print(f'copy valid set, size: {valid_size} annotations')
    copy_annotations(images[train_size:valid_end], 'valid')
    print(f'copy test set, size: {len(images) - valid_end} annotations')
    copy_annotations(images[valid_end:], 'test')

    print('creating yaml...')
    create_yaml()
def copy_annotations(images, folder):
    """Copy images and their label files into one split of today's dataset.

    Labels are looked up as ``processed_labels_dir/<image stem>.txt``. Pairs
    whose label passes ``check_label`` go to ``today_dataset/<folder>/images``
    and ``.../labels``. Otherwise the image is copied to
    ``corrupted_images_dir``, and the label to ``corrupted_labels_dir`` when
    the label file exists.

    Args:
        images: Iterable of ``os.DirEntry``-like objects (``.path``/``.name``).
        folder: Split name, e.g. ``'train'``, ``'valid'`` or ``'test'``.
    """
    destination_images = path.join(today_dataset, folder, 'images')
    destination_labels = path.join(today_dataset, folder, 'labels')
    for directory in (destination_images, destination_labels,
                      corrupted_images_dir, corrupted_labels_dir):
        makedirs(directory, exist_ok=True)

    for image in images:
        label_name = f'{Path(image.path).stem}.txt'
        label_path = path.join(processed_labels_dir, label_name)
        if check_label(label_path):
            shutil.copy(image.path, path.join(destination_images, image.name))
            shutil.copy(label_path, path.join(destination_labels, label_name))
            continue

        shutil.copy(image.path, path.join(corrupted_images_dir, image.name))
        # check_label also returns False for a missing label file; copying a
        # non-existent file would raise FileNotFoundError and abort the split.
        if path.exists(label_path):
            shutil.copy(label_path, path.join(corrupted_labels_dir, label_name))
            print(f'Label {label_path} is corrupted! Copy with its image to the corrupted directory ({corrupted_labels_dir})')
        else:
            print(f'Label {label_path} is missing! Copy its image to the corrupted directory ({corrupted_images_dir})')
def check_label(label_path):
    """Return True if ``label_path`` is a well-formed YOLO label file.

    A file is accepted when it exists and, on every line, every whitespace-
    separated token after the first (the class id) parses as a float within
    [0, 1], as YOLO's normalized coordinates require. Blank lines and empty
    files are accepted. A missing file, a non-numeric token, or an
    out-of-range/NaN value makes the label invalid (False) instead of raising.
    """
    if not path.exists(label_path):
        return False
    with open(label_path, 'r', encoding='utf-8') as f:
        for line in f:
            # split() (no argument) tolerates trailing/duplicate whitespace,
            # which split(' ') turned into '' / '\n' tokens that crashed float().
            for token in line.split()[1:]:
                try:
                    value = float(token)
                except ValueError:
                    return False
                # The chained comparison is also False for NaN.
                if not 0.0 <= value <= 1.0:
                    return False
    return True
def create_yaml():
    """Write ``data.yaml`` describing today's dataset for Ultralytics.

    The file lists one ``- <name>`` entry per entry of ``annotation_classes``
    (in its iteration order), the class count ``nc``, and the test/train/val
    image folders relative to ``today_dataset``. Every line, including a
    final empty one, is terminated by a newline.
    """
    class_lines = [f'- {annotation_classes[key].name}' for key in annotation_classes]
    content_lines = [
        'names:',
        *class_lines,
        f'nc: {len(annotation_classes)}',
        'test: test/images',
        'train: train/images',
        'val: valid/images',
        '',
    ]
    yaml_path = abspath(path.join(today_dataset, 'data.yaml'))
    with open(yaml_path, 'w', encoding='utf-8') as out:
        out.write(''.join(entry + '\n' for entry in content_lines))
def revert_to_processed_data(date):
    """Move a dated dataset's files back into the processed folders, then delete it.

    The dataset folder is ``datasets_dir/<prefix><date>``. For each split
    (test, train, valid), files under ``images`` go to
    ``processed_images_dir`` and files under ``labels`` go to
    ``processed_labels_dir``. They are moved with ``os.replace``, which
    overwrites same-named files. Finally the dataset folder is removed
    recursively.
    """
    dataset_root = path.join(datasets_dir, f'{prefix}{date}')
    makedirs(processed_images_dir, exist_ok=True)
    makedirs(processed_labels_dir, exist_ok=True)

    kind_targets = (('images', processed_images_dir),
                    ('labels', processed_labels_dir))
    for split in ('test', 'train', 'valid'):
        for kind, target_dir in kind_targets:
            source_dir = path.join(dataset_root, split, kind)
            for file_name in listdir(source_dir):
                replace(path.join(source_dir, file_name),
                        path.join(target_dir, file_name))

    shutil.rmtree(dataset_root)
def get_latest_model():
    """Locate the newest dated model folder in ``models_dir``.

    Folder names are ``prefix`` followed by a date rendered with
    ``date_format`` (the scheme used for ``today_folder``, under which trained
    models are saved). A leading ``prefix`` is stripped before parsing.
    Entries whose name does not parse as such a date (stray files, other
    folders) are skipped instead of aborting the lookup.

    Returns:
        ``(date, weights_path)`` for the newest folder, where ``weights_path``
        is ``<folder>/weights/best.pt`` (its existence is not checked), or
        ``(None, None)`` when ``models_dir`` is missing or holds no dated
        folders.
    """
    if not path.isdir(models_dir):
        return None, None

    candidates = []
    for entry in listdir(models_dir):
        date_part = entry[len(prefix):] if entry.startswith(prefix) else entry
        try:
            entry_date = datetime.strptime(date_part, date_format)
        except ValueError:
            continue  # not a dated model folder
        candidates.append(
            (entry_date, path.join(models_dir, entry, 'weights', 'best.pt')))

    if not candidates:
        return None, None
    latest_date, latest_path = max(candidates, key=lambda item: item[0])
    return latest_date, latest_path
if __name__ == '__main__':
    latest_date, latest_model = get_latest_model()
    # Dataset regeneration is currently switched off. To rebuild today's split
    # from images modified after the latest model's date, re-enable:
    #     form_dataset(latest_date)

    # Continue from the newest saved weights when that file exists; otherwise
    # build a fresh model from the 'yolov8m.yaml' config.
    has_weights = latest_model is not None and path.isfile(latest_model)
    model_name = latest_model if has_weights else 'yolov8m.yaml'
    print(f'Initial model: {model_name}')
    model = YOLO(model_name)

    # To train on an older dataset, point cur_folder at it instead, e.g.
    #     path.join(datasets_dir, f'{prefix}2024-06-18')
    cur_folder = today_dataset
    data_yaml = abspath(path.join(cur_folder, 'data.yaml'))
    results = model.train(data=data_yaml, epochs=100, batch=60, imgsz=640)

    # Archive this run's output under models_dir/<today_folder>, then delete
    # the local 'runs' folder (presumably the Ultralytics output root).
    shutil.copytree(results.save_dir, path.join(models_dir, today_folder))
    shutil.rmtree('runs')