Files
ai-training/train.py
T
Alex Bezdieniezhnykh bb1dbfe1e7 fix folder names, refactoring
move to yolov8m
use checkpoint.txt instead of yaml
2024-06-29 22:18:34 +03:00

117 lines
4.0 KiB
Python

from os import path, replace, remove, listdir, makedirs
from os.path import abspath
import shutil
from datetime import datetime
from pathlib import Path
from ultralytics import YOLO
from constants import processed_images_dir, processed_labels_dir, annotation_classes, prefix, date_format, datasets_dir, models_dir
# Checkpoint file that training resumes from and overwrites after each run.
latest_model = path.join(models_dir, prefix + 'latest.pt')

# Name and location of the dataset folder stamped with the current date.
today_folder = f'{prefix}{format(datetime.now(), date_format)}'
today_dataset = path.join(datasets_dir, today_folder)

# Split percentages. Only train_set and valid_set are read by form_dataset();
# the test subset receives whatever remains after those two slices.
train_set = 70
valid_set = 20
test_set = 10
def form_dataset():
    """Build today's dataset folder from the processed images.

    The listing of ``processed_images_dir`` is cut into three consecutive
    slices: the first ``train_set`` percent goes to ``train``, the next
    ``valid_set`` percent to ``valid`` and the remainder to ``test``.
    Each slice is handed to ``move_annotations`` and ``data.yaml`` is
    written afterwards.
    """
    makedirs(today_dataset, exist_ok=True)

    all_images = listdir(processed_images_dir)
    total = len(all_images)
    train_end = int(total * train_set / 100.0)
    valid_end = train_end + int(total * valid_set / 100.0)

    subsets = (
        ('train', all_images[:train_end]),
        ('valid', all_images[train_end:valid_end]),
        ('test', all_images[valid_end:]),
    )
    for subset_name, subset_images in subsets:
        move_annotations(subset_images, subset_name)

    create_yaml()
def move_annotations(images, folder):
    """Move images and their label files into one subset of today's dataset.

    Args:
        images: file names (not paths) found in ``processed_images_dir``.
        folder: subset folder name under ``today_dataset``; the callers in
            this file pass ``'train'``, ``'valid'`` or ``'test'``.

    For every image the label ``<stem>.txt`` is looked up in
    ``processed_labels_dir``. When ``check_label`` accepts it, both files are
    moved into ``<folder>/images`` and ``<folder>/labels``; otherwise the
    image file is deleted.
    """
    images_target = path.join(today_dataset, folder, 'images')
    makedirs(images_target, exist_ok=True)
    labels_target = path.join(today_dataset, folder, 'labels')
    makedirs(labels_target, exist_ok=True)

    for file_name in images:
        source_image = path.join(processed_images_dir, file_name)
        label_file = Path(file_name).stem + '.txt'
        source_label = path.join(processed_labels_dir, label_file)

        if check_label(source_label):
            replace(source_image, path.join(images_target, file_name))
            replace(source_label, path.join(labels_target, label_file))
        else:
            # No usable label for this image, so it is dropped.
            remove(source_image)
def check_label(label_path):
    """Validate a label file and strip annotation lines with out-of-range values.

    Each line is treated as whitespace-separated fields; the first field is
    skipped and every remaining field is parsed as a float. A line is dropped
    when any of those values is greater than 1.

    Args:
        label_path: path of the label text file to check.

    Returns:
        ``False`` when the file does not exist or no line survives the check
        (the file is left untouched in that case); ``True`` otherwise. When
        some but not all lines were dropped, the file is rewritten with the
        surviving lines only.

    Raises:
        ValueError: if a field that has to be inspected is not a number.
    """
    if not path.exists(label_path):
        return False

    with open(label_path, 'r') as f:
        lines = f.readlines()

    # Build a filtered copy instead of removing from the list being iterated:
    # in-place removal skips the following line and removes a line twice when
    # more than one of its values is out of range.
    # split() without an argument also tolerates trailing spaces and tabs.
    valid_lines = [
        line for line in lines
        if not any(float(val) > 1 for val in line.split()[1:])
    ]

    if not valid_lines:
        return False
    if len(valid_lines) == len(lines):
        # Nothing was dropped, so there is no need to rewrite the file.
        return True

    with open(label_path, 'w') as label_write:
        label_write.writelines(valid_lines)
    return True
def create_yaml():
    """Write ``data.yaml`` into today's dataset folder.

    The file lists one ``- <name>`` entry per key of ``annotation_classes``
    under ``names:``, followed by the class count (``nc``) and the relative
    image folders of the test, train and validation subsets.
    """
    class_entries = [f'- {annotation_classes[key].name}' for key in annotation_classes]
    content = [
        'names:',
        *class_entries,
        f'nc: {len(annotation_classes)}',
        'test: test/images',
        'train: train/images',
        'val: valid/images',
        '',
    ]

    yaml_path = abspath(path.join(today_dataset, 'data.yaml'))
    with open(yaml_path, 'w', encoding='utf-8') as out:
        out.write(''.join(f'{row}\n' for row in content))
def revert_to_processed_data(date):
    """Undo a dataset split by moving its files back to the processed folders.

    Args:
        date: date part of the dataset folder name; the folder looked up is
            ``<datasets_dir>/<prefix><date>``.

    Every file under ``test``, ``train`` and ``valid`` is moved from
    ``images`` to ``processed_images_dir`` and from ``labels`` to
    ``processed_labels_dir``; the dataset folder is then deleted.
    """
    dataset_root = path.join(datasets_dir, f'{prefix}{date}')
    makedirs(processed_images_dir, exist_ok=True)
    makedirs(processed_labels_dir, exist_ok=True)

    targets = (
        ('images', processed_images_dir),
        ('labels', processed_labels_dir),
    )
    for subset in ('test', 'train', 'valid'):
        for kind, target_dir in targets:
            source_dir = path.join(dataset_root, subset, kind)
            for entry in listdir(source_dir):
                replace(path.join(source_dir, entry), path.join(target_dir, entry))

    shutil.rmtree(dataset_root)
if __name__ == '__main__':
    # Dataset formation is switched off here; training reads data.yaml from
    # the dataset folder selected below.
    # form_dataset()

    # Resume from the latest saved checkpoint when one exists, otherwise
    # start from the 'yolov8m.yaml' model definition.
    model_name = latest_model if path.isfile(latest_model) else 'yolov8m.yaml'
    print(f'Initial model: {model_name}')
    model = YOLO(model_name)

    # cur_folder = path.join(datasets_dir, f'{prefix}2024-06-18')
    cur_folder = today_dataset
    yaml = abspath(path.join(cur_folder, 'data.yaml'))
    results = model.train(data=yaml, epochs=100, batch=55, imgsz=640, save_period=1)

    # Publish the best weights as the new "latest" checkpoint.
    shutil.copy(f'{results.save_dir}/weights/best.pt', latest_model)

    # cur_folder already contains datasets_dir, so only its last component is
    # used as the archive folder name. Joining the full path would discard
    # models_dir when datasets_dir is absolute (copytree then targets the
    # existing dataset folder and fails) or nest the run under
    # models_dir/<datasets_dir>/... when it is relative.
    run_name = path.basename(path.normpath(cur_folder))
    shutil.copytree(results.save_dir, path.join(models_dir, run_name))

    shutil.rmtree('runs')
    # NOTE(review): this raises if the '<prefix>latest' directory does not
    # exist under models_dir; nothing in this file creates it - confirm.
    shutil.rmtree(path.join(models_dir, f'{prefix}latest'))