mirror of
https://github.com/azaion/ai-training.git
synced 2026-04-22 08:56:35 +00:00
add train.py
form dataset for current date add exception catching
This commit is contained in:
@@ -1,4 +1,86 @@
|
||||
import os
|
||||
import shutil
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
|
||||
from ultralytics import YOLO
|
||||
from constants import current_images_dir, current_labels_dir, annotation_classes, today_dataset, prefix
|
||||
|
||||
yaml_name = 'data.yaml'
|
||||
yaml_path = os.path.join(today_dataset, yaml_name)
|
||||
train_set = 70
|
||||
valid_set = 20
|
||||
test_set = 10
|
||||
|
||||
|
||||
current_dataset_dir = os.path.join('datasets', 'zombobase-current')
|
||||
def form_dataset():
|
||||
os.makedirs(today_dataset, exist_ok=True)
|
||||
images = os.listdir(current_images_dir)
|
||||
|
||||
train_size = int(len(images) * train_set / 100.0)
|
||||
valid_size = int(len(images) * valid_set / 100.0)
|
||||
|
||||
move_annotations(images[:train_size], 'train')
|
||||
move_annotations(images[train_size:train_size + valid_size], 'valid')
|
||||
move_annotations(images[train_size + valid_size:], 'test')
|
||||
|
||||
create_yaml()
|
||||
|
||||
|
||||
def move_annotations(images, folder):
|
||||
destination_images = os.path.join(today_dataset, folder, 'images')
|
||||
os.makedirs(destination_images, exist_ok=True)
|
||||
destination_labels = os.path.join(today_dataset, folder, 'labels')
|
||||
os.makedirs(destination_labels, exist_ok=True)
|
||||
for image_name in images:
|
||||
image_path = os.path.join(current_images_dir, image_name)
|
||||
label_name = f'{Path(image_name).stem}.txt'
|
||||
label_path = os.path.join(current_labels_dir, label_name)
|
||||
os.replace(image_path, os.path.join(destination_images, image_name))
|
||||
os.replace(label_path, os.path.join(destination_labels, label_name))
|
||||
|
||||
|
||||
def create_yaml():
|
||||
lines = ['names:']
|
||||
for c in annotation_classes:
|
||||
lines.append(f'- {annotation_classes[c].name}')
|
||||
lines.append(f'nc: {len(annotation_classes)}')
|
||||
lines.append(f'test: test/images')
|
||||
lines.append(f'train: train/images')
|
||||
lines.append(f'val: valid/images')
|
||||
lines.append('')
|
||||
|
||||
with open(yaml_path, 'w', encoding='utf-8') as f:
|
||||
f.writelines([f'{line}\n' for line in lines])
|
||||
|
||||
|
||||
def get_recent_model():
|
||||
date_sets = []
|
||||
datasets = [next((file for file in os.listdir(os.path.join('datasets', d)) if file.endswith('pt')), None)
|
||||
for d in os.listdir('datasets')]
|
||||
|
||||
# date_str = d.replace(prefix, '')
|
||||
# if date_str == 'current' or date_str == f'{datetime.now():%Y-%m-%d}':
|
||||
# continue
|
||||
# if len(date_sets) == 0:
|
||||
# return None
|
||||
|
||||
recent = max(date_sets)
|
||||
return os.path.join('datasets', f'{prefix}{recent}', f'{prefix}{recent}.pt')
|
||||
|
||||
|
||||
def retrain():
|
||||
model = YOLO(get_recent_model() or 'yolov10x.yaml')
|
||||
model.train(data=yaml_path, save=True, cache=True)
|
||||
|
||||
|
||||
def revert_to_current(date):
|
||||
def revert_dir(dir):
|
||||
os.listdir(os.path.join(current_images_dir, 'images'))
|
||||
|
||||
date_dataset = f'{prefix}{date}'
|
||||
revert_dir(os.path.join(date_dataset, 'test'))
|
||||
|
||||
form_dataset()
|
||||
create_yaml()
|
||||
retrain()
|
||||
|
||||
Reference in New Issue
Block a user