mirror of
https://github.com/azaion/ai-training.git
synced 2026-04-22 07:06:36 +00:00
add train.py
form dataset for current date add exception catching
This commit is contained in:
@@ -0,0 +1,13 @@
|
|||||||
|
import os
|
||||||
|
from datetime import datetime
|
||||||
|
|
||||||
|
from dto.annotationClass import AnnotationClass
|
||||||
|
|
||||||
|
current_dataset_dir = os.path.join('datasets', 'zombobase-current')
|
||||||
|
current_images_dir = os.path.join(current_dataset_dir, 'images')
|
||||||
|
current_labels_dir = os.path.join(current_dataset_dir, 'labels')
|
||||||
|
annotation_classes = AnnotationClass.read_json()
|
||||||
|
|
||||||
|
|
||||||
|
prefix = 'zombobase-'
|
||||||
|
today_dataset = os.path.join('datasets', f'{prefix}{datetime.now():%Y-%m-%d}')
|
||||||
+27
-18
@@ -3,15 +3,11 @@ import time
|
|||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
import albumentations as A
|
import albumentations as A
|
||||||
import cv2
|
import cv2
|
||||||
from dto.annotationClass import AnnotationClass
|
from constants import current_images_dir, current_labels_dir, annotation_classes
|
||||||
from dto.imageLabel import ImageLabel
|
from dto.imageLabel import ImageLabel
|
||||||
|
|
||||||
labels_dir = 'labels'
|
labels_dir = 'labels'
|
||||||
images_dir = 'images'
|
images_dir = 'images'
|
||||||
current_dataset_dir = os.path.join('datasets', 'zombobase-current')
|
|
||||||
current_images_dir = os.path.join(current_dataset_dir, 'images')
|
|
||||||
current_labels_dir = os.path.join(current_dataset_dir, 'labels')
|
|
||||||
annotation_classes = AnnotationClass.read_json()
|
|
||||||
|
|
||||||
|
|
||||||
def image_processing(img_ann: ImageLabel) -> [ImageLabel]:
|
def image_processing(img_ann: ImageLabel) -> [ImageLabel]:
|
||||||
@@ -37,16 +33,19 @@ def image_processing(img_ann: ImageLabel) -> [ImageLabel]:
|
|||||||
|
|
||||||
results = []
|
results = []
|
||||||
for i, transform in enumerate(transforms):
|
for i, transform in enumerate(transforms):
|
||||||
res = transform(image=img_ann.image, bboxes=img_ann.labels)
|
try:
|
||||||
path = Path(img_ann.image_path)
|
res = transform(image=img_ann.image, bboxes=img_ann.labels)
|
||||||
name = f'{path.stem}_{i+1}'
|
path = Path(img_ann.image_path)
|
||||||
img = ImageLabel(
|
name = f'{path.stem}_{i+1}'
|
||||||
image=res['image'],
|
img = ImageLabel(
|
||||||
labels=res['bboxes'],
|
image=res['image'],
|
||||||
image_path=os.path.join(current_images_dir, f'{name}{path.suffix}'),
|
labels=res['bboxes'],
|
||||||
labels_path=os.path.join(current_labels_dir, f'{name}.txt')
|
image_path=os.path.join(current_images_dir, f'{name}{path.suffix}'),
|
||||||
)
|
labels_path=os.path.join(current_labels_dir, f'{name}.txt')
|
||||||
results.append(img)
|
)
|
||||||
|
results.append(img)
|
||||||
|
except Exception as e:
|
||||||
|
print(f'Error during transformtation: {e}')
|
||||||
return results
|
return results
|
||||||
|
|
||||||
|
|
||||||
@@ -74,7 +73,7 @@ def read_labels(labels_path) -> [[]]:
|
|||||||
for row in rows:
|
for row in rows:
|
||||||
str_coordinates = row.split(' ')
|
str_coordinates = row.split(' ')
|
||||||
class_num = str_coordinates.pop(0)
|
class_num = str_coordinates.pop(0)
|
||||||
coordinates = [float(n) for n in str_coordinates]
|
coordinates = [float(n.replace(',', '.')) for n in str_coordinates]
|
||||||
coordinates.append(class_num)
|
coordinates.append(class_num)
|
||||||
arr.append(coordinates)
|
arr.append(coordinates)
|
||||||
return arr
|
return arr
|
||||||
@@ -111,8 +110,18 @@ def main():
|
|||||||
labels_path=labels_path,
|
labels_path=labels_path,
|
||||||
labels=read_labels(labels_path)
|
labels=read_labels(labels_path)
|
||||||
))
|
))
|
||||||
except FileNotFoundError:
|
except Exception as e:
|
||||||
print(f'No labels file {labels_path} found')
|
print(f'Error appeared {e}')
|
||||||
|
|
||||||
|
try:
|
||||||
|
os.remove(image_path)
|
||||||
|
except OSError:
|
||||||
|
pass
|
||||||
|
|
||||||
|
try:
|
||||||
|
os.remove(labels_path)
|
||||||
|
except OSError:
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
|
|||||||
@@ -1,19 +1,16 @@
|
|||||||
from pathlib import Path
|
|
||||||
import cv2
|
import cv2
|
||||||
import os.path
|
|
||||||
|
|
||||||
from dto.annotationClass import AnnotationClass
|
from dto.annotationClass import AnnotationClass
|
||||||
from dto.imageLabel import ImageLabel
|
from dto.imageLabel import ImageLabel
|
||||||
from preprocessing import read_labels
|
from preprocessing import read_labels
|
||||||
|
|
||||||
images_dir = '../images'
|
|
||||||
labels_dir = '../labels'
|
|
||||||
annotation_classes = AnnotationClass.read_json()
|
annotation_classes = AnnotationClass.read_json()
|
||||||
|
|
||||||
|
images_dir = ''
|
||||||
|
|
||||||
image = os.listdir(images_dir)[0]
|
image_path = 'test01.jpg'
|
||||||
image_path = os.path.join(images_dir, image)
|
labels_path = 'test01.txt'
|
||||||
labels_path = os.path.join(labels_dir, f'{Path(image_path).stem}.txt')
|
|
||||||
|
|
||||||
img = ImageLabel(
|
img = ImageLabel(
|
||||||
image_path=image_path,
|
image_path=image_path,
|
||||||
|
|||||||
Binary file not shown.
|
After Width: | Height: | Size: 105 KiB |
@@ -0,0 +1 @@
|
|||||||
|
0 0.3809 0.49269 0.21636 0.39129
|
||||||
@@ -1,4 +1,86 @@
|
|||||||
import os
|
import os
|
||||||
|
import shutil
|
||||||
|
from datetime import datetime
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
from ultralytics import YOLO
|
||||||
|
from constants import current_images_dir, current_labels_dir, annotation_classes, today_dataset, prefix
|
||||||
|
|
||||||
|
yaml_name = 'data.yaml'
|
||||||
|
yaml_path = os.path.join(today_dataset, yaml_name)
|
||||||
|
train_set = 70
|
||||||
|
valid_set = 20
|
||||||
|
test_set = 10
|
||||||
|
|
||||||
|
|
||||||
current_dataset_dir = os.path.join('datasets', 'zombobase-current')
|
def form_dataset():
|
||||||
|
os.makedirs(today_dataset, exist_ok=True)
|
||||||
|
images = os.listdir(current_images_dir)
|
||||||
|
|
||||||
|
train_size = int(len(images) * train_set / 100.0)
|
||||||
|
valid_size = int(len(images) * valid_set / 100.0)
|
||||||
|
|
||||||
|
move_annotations(images[:train_size], 'train')
|
||||||
|
move_annotations(images[train_size:train_size + valid_size], 'valid')
|
||||||
|
move_annotations(images[train_size + valid_size:], 'test')
|
||||||
|
|
||||||
|
create_yaml()
|
||||||
|
|
||||||
|
|
||||||
|
def move_annotations(images, folder):
|
||||||
|
destination_images = os.path.join(today_dataset, folder, 'images')
|
||||||
|
os.makedirs(destination_images, exist_ok=True)
|
||||||
|
destination_labels = os.path.join(today_dataset, folder, 'labels')
|
||||||
|
os.makedirs(destination_labels, exist_ok=True)
|
||||||
|
for image_name in images:
|
||||||
|
image_path = os.path.join(current_images_dir, image_name)
|
||||||
|
label_name = f'{Path(image_name).stem}.txt'
|
||||||
|
label_path = os.path.join(current_labels_dir, label_name)
|
||||||
|
os.replace(image_path, os.path.join(destination_images, image_name))
|
||||||
|
os.replace(label_path, os.path.join(destination_labels, label_name))
|
||||||
|
|
||||||
|
|
||||||
|
def create_yaml():
|
||||||
|
lines = ['names:']
|
||||||
|
for c in annotation_classes:
|
||||||
|
lines.append(f'- {annotation_classes[c].name}')
|
||||||
|
lines.append(f'nc: {len(annotation_classes)}')
|
||||||
|
lines.append(f'test: test/images')
|
||||||
|
lines.append(f'train: train/images')
|
||||||
|
lines.append(f'val: valid/images')
|
||||||
|
lines.append('')
|
||||||
|
|
||||||
|
with open(yaml_path, 'w', encoding='utf-8') as f:
|
||||||
|
f.writelines([f'{line}\n' for line in lines])
|
||||||
|
|
||||||
|
|
||||||
|
def get_recent_model():
|
||||||
|
date_sets = []
|
||||||
|
datasets = [next((file for file in os.listdir(os.path.join('datasets', d)) if file.endswith('pt')), None)
|
||||||
|
for d in os.listdir('datasets')]
|
||||||
|
|
||||||
|
# date_str = d.replace(prefix, '')
|
||||||
|
# if date_str == 'current' or date_str == f'{datetime.now():%Y-%m-%d}':
|
||||||
|
# continue
|
||||||
|
# if len(date_sets) == 0:
|
||||||
|
# return None
|
||||||
|
|
||||||
|
recent = max(date_sets)
|
||||||
|
return os.path.join('datasets', f'{prefix}{recent}', f'{prefix}{recent}.pt')
|
||||||
|
|
||||||
|
|
||||||
|
def retrain():
|
||||||
|
model = YOLO(get_recent_model() or 'yolov10x.yaml')
|
||||||
|
model.train(data=yaml_path, save=True, cache=True)
|
||||||
|
|
||||||
|
|
||||||
|
def revert_to_current(date):
|
||||||
|
def revert_dir(dir):
|
||||||
|
os.listdir(os.path.join(current_images_dir, 'images'))
|
||||||
|
|
||||||
|
date_dataset = f'{prefix}{date}'
|
||||||
|
revert_dir(os.path.join(date_dataset, 'test'))
|
||||||
|
|
||||||
|
form_dataset()
|
||||||
|
create_yaml()
|
||||||
|
retrain()
|
||||||
|
|||||||
@@ -0,0 +1,40 @@
|
|||||||
|
# Parameters
|
||||||
|
nc: 50 # number of classes
|
||||||
|
scales: # model compound scaling constants, i.e. 'model=yolov8n.yaml' will call yolov8.yaml with scale 'n'
|
||||||
|
# [depth, width, max_channels]
|
||||||
|
x: [1.00, 1.25, 512]
|
||||||
|
|
||||||
|
# YOLOv8.0n backbone
|
||||||
|
backbone:
|
||||||
|
# [from, repeats, module, args]
|
||||||
|
- [-1, 1, Conv, [64, 3, 2]] # 0-P1/2
|
||||||
|
- [-1, 1, Conv, [128, 3, 2]] # 1-P2/4
|
||||||
|
- [-1, 3, C2f, [128, True]]
|
||||||
|
- [-1, 1, Conv, [256, 3, 2]] # 3-P3/8
|
||||||
|
- [-1, 6, C2f, [256, True]]
|
||||||
|
- [-1, 1, SCDown, [512, 3, 2]] # 5-P4/16
|
||||||
|
- [-1, 6, C2fCIB, [512, True]]
|
||||||
|
- [-1, 1, SCDown, [1024, 3, 2]] # 7-P5/32
|
||||||
|
- [-1, 3, C2fCIB, [1024, True]]
|
||||||
|
- [-1, 1, SPPF, [1024, 5]] # 9
|
||||||
|
- [-1, 1, PSA, [1024]] # 10
|
||||||
|
|
||||||
|
# YOLOv8.0n head
|
||||||
|
head:
|
||||||
|
- [-1, 1, nn.Upsample, [None, 2, "nearest"]]
|
||||||
|
- [[-1, 6], 1, Concat, [1]] # cat backbone P4
|
||||||
|
- [-1, 3, C2fCIB, [512, True]] # 13
|
||||||
|
|
||||||
|
- [-1, 1, nn.Upsample, [None, 2, "nearest"]]
|
||||||
|
- [[-1, 4], 1, Concat, [1]] # cat backbone P3
|
||||||
|
- [-1, 3, C2f, [256]] # 16 (P3/8-small)
|
||||||
|
|
||||||
|
- [-1, 1, Conv, [256, 3, 2]]
|
||||||
|
- [[-1, 13], 1, Concat, [1]] # cat head P4
|
||||||
|
- [-1, 3, C2fCIB, [512, True]] # 19 (P4/16-medium)
|
||||||
|
|
||||||
|
- [-1, 1, SCDown, [512, 3, 2]]
|
||||||
|
- [[-1, 10], 1, Concat, [1]] # cat head P5
|
||||||
|
- [-1, 3, C2fCIB, [1024, True]] # 22 (P5/32-large)
|
||||||
|
|
||||||
|
- [[16, 19, 22], 1, v10Detect, [nc]] # Detect(P3, P4, P5)
|
||||||
Reference in New Issue
Block a user