Files
ai-training/train.py
Oleksandr Bezdieniezhnykh a9b777acc4 small refactoring
2024-06-08 19:34:17 +03:00

101 lines
3.5 KiB
Python

import os
import shutil
from datetime import datetime
from pathlib import Path
from ultralytics import YOLOv10
from constants import current_images_dir, current_labels_dir, annotation_classes, prefix
# Folder of the dataset formed today: datasets/<prefix>YYYY-MM-DD.
# The date is captured once, when the module is imported.
today_dataset = os.path.join('datasets', f'{prefix}{datetime.now().strftime("%Y-%m-%d")}')
# Dataset description written by create_yaml() and passed to model.train(data=...).
yaml_path = os.path.join(today_dataset, 'data.yaml')

# Split proportions in percent. form_dataset() only reads train_set and
# valid_set; the test subset simply receives every remaining image.
train_set, valid_set, test_set = 70, 20, 10
def form_dataset():
    """Split the current annotated images into today's train/valid/test subsets.

    The first ``train_set`` percent of the entries listed in ``current_images_dir``
    go to ``train``, the next ``valid_set`` percent to ``valid`` and everything
    left over to ``test``. Each entry is handed to ``move_annotations``, which
    moves it (or deletes it when it has no label), after which ``create_yaml``
    writes the dataset description.
    """
    os.makedirs(today_dataset, exist_ok=True)
    images = os.listdir(current_images_dir)
    total = len(images)
    train_end = int(total * train_set / 100.0)
    valid_end = train_end + int(total * valid_set / 100.0)
    # No shuffling: the split follows whatever order os.listdir returns.
    splits = (
        ('train', images[:train_end]),
        ('valid', images[train_end:valid_end]),
        ('test', images[valid_end:]),
    )
    for folder, subset in splits:
        move_annotations(subset, folder)
    create_yaml()
def move_annotations(images, folder):
    """Move images and their ``.txt`` labels into ``today_dataset/<folder>``.

    :param images: file names (not paths) located in ``current_images_dir``.
    :param folder: subset name, e.g. ``'train'``, ``'valid'`` or ``'test'``.

    The label for ``name.ext`` is expected at ``current_labels_dir/name.txt``.
    An image whose label is missing is deleted instead of moved.
    """
    subset_dir = os.path.join(today_dataset, folder)
    images_dst = os.path.join(subset_dir, 'images')
    labels_dst = os.path.join(subset_dir, 'labels')
    for target in (images_dst, labels_dst):
        os.makedirs(target, exist_ok=True)
    for image_name in images:
        src_image = os.path.join(current_images_dir, image_name)
        label_name = Path(image_name).stem + '.txt'
        src_label = os.path.join(current_labels_dir, label_name)
        if os.path.exists(src_label):
            os.replace(src_image, os.path.join(images_dst, image_name))
            os.replace(src_label, os.path.join(labels_dst, label_name))
        else:
            # Unlabelled image: drop it so it never enters the dataset.
            os.remove(src_image)
def create_yaml():
    """Write the ``data.yaml`` description of today's dataset to ``yaml_path``.

    The file lists the class names (in iteration order of ``annotation_classes``),
    the class count ``nc`` and the ``test``/``train``/``val`` image directories,
    followed by a trailing blank line.

    The dataset folder name is derived from ``today_dataset`` rather than from a
    fresh ``datetime.now()``. If the script runs across midnight, the yaml
    therefore still points at the same folder as ``yaml_path``.
    """
    # today_dataset is 'datasets/<name>'; keep <name> with POSIX separators for YAML.
    dataset_name = Path(today_dataset).relative_to('datasets').as_posix()
    # NOTE(review): the '../../' form presumably resolves relative to the trainer's
    # datasets directory; confirm against the training setup before changing it.
    main_dir = f'../../{dataset_name}'

    lines = ['names:']
    lines.extend(f'- {annotation_classes[c].name}' for c in annotation_classes)
    lines.append(f'nc: {len(annotation_classes)}')
    lines.append(f'test: {main_dir}/test/images')
    lines.append(f'train: {main_dir}/train/images')
    lines.append(f'val: {main_dir}/valid/images')
    lines.append('')
    with open(yaml_path, 'w', encoding='utf-8') as f:
        f.writelines(f'{line}\n' for line in lines)
def get_recent_model():
    """Return a ``.pt`` file from the newest dated dataset folder, or ``None``.

    Scans ``datasets/`` for directories named ``<prefix>%Y-%m-%d``. It skips:

    * today's dataset (``today_dataset``);
    * any entry that does not start with ``prefix``;
    * any entry that is not a directory;
    * any entry whose suffix is not a valid date (e.g. ``<prefix>current``).

    Among the remaining folders, the one with the latest date that contains a
    ``*.pt`` file wins. Within that folder, the alphabetically first ``*.pt``
    file is returned.
    """
    today_name = os.path.basename(today_dataset)
    cur_model = None
    cur_date = None
    for d in sorted(os.listdir('datasets')):
        dir_path = os.path.join('datasets', d)
        if not d.startswith(prefix) or d == today_name or not os.path.isdir(dir_path):
            continue
        try:
            # Slice instead of str.replace so only the leading prefix is removed.
            date = datetime.strptime(d[len(prefix):], '%Y-%m-%d')
        except ValueError:
            # Not a dated dataset folder (e.g. '<prefix>current'); skip it.
            continue
        if cur_date is not None and date <= cur_date:
            continue
        for file in sorted(os.listdir(dir_path)):
            # '.pt' (not just 'pt') so files like '*.ckpt' are not mistaken for weights.
            if file.endswith('.pt'):
                cur_model = os.path.join(dir_path, file)
                cur_date = date
                break
    return cur_model
def revert_to_current(date):
    """Undo ``form_dataset`` for the dataset folder created on *date*.

    Moves every image and label from the ``test``, ``train`` and ``valid`` subsets
    of ``datasets/<prefix><date>`` back into ``current_images_dir`` and
    ``current_labels_dir``. These are the same directories ``move_annotations``
    takes them from. The dated dataset directory is then deleted.

    :param date: dataset date string as used in the folder name, e.g. ``'2024-06-06'``.

    WARNING: ``shutil.rmtree`` removes everything else still in the dated folder,
    including ``data.yaml`` and any ``*.pt`` weights kept there. Images that
    ``move_annotations`` deleted for lacking a label cannot be restored.
    """
    def revert_dir(src_dir, dest_dir):
        # Move every entry of src_dir into dest_dir, overwriting same-named files.
        os.makedirs(dest_dir, exist_ok=True)
        for file in os.listdir(src_dir):
            os.replace(os.path.join(src_dir, file), os.path.join(dest_dir, file))

    date_dataset = os.path.join('datasets', f'{prefix}{date}')
    for subset in ['test', 'train', 'valid']:
        subset_dir = os.path.join(date_dataset, subset)
        revert_dir(os.path.join(subset_dir, 'images'), current_images_dir)
        revert_dir(os.path.join(subset_dir, 'labels'), current_labels_dir)
    shutil.rmtree(date_dataset)
if __name__ == '__main__':
    # Dataset-preparation steps, enabled manually by uncommenting as needed:
    # revert_to_current('2024-06-06')
    # form_dataset()
    # create_yaml()

    # Continue from the newest weights found in an earlier dated dataset folder;
    # otherwise start from the 'yolov10x.yaml' model config. Requires a CUDA device.
    recent = get_recent_model()
    weights = recent if recent else 'yolov10x.yaml'
    model = YOLOv10(model=weights, task='detect').to('cuda')
    results = model.train(
        data=yaml_path,
        epochs=200,
        imgsz=1280,
        save=True,
        cache=True,
    )