mirror of
https://github.com/azaion/ai-training.git
synced 2026-04-22 06:56:34 +00:00
142c6c4de8
- Replaced module-level path variables in constants.py with a structured Pydantic Config class. - Updated all relevant modules (train.py, augmentation.py, exports.py, dataset-visualiser.py, manual_run.py) to access paths through the new config structure. - Fixed bugs related to image processing and model saving. - Enhanced test infrastructure to accommodate the new configuration approach. This refactor improves code maintainability and clarity by centralizing configuration management.
153 lines
6.0 KiB
Python
153 lines
6.0 KiB
Python
import concurrent.futures
|
|
import os.path
|
|
import shutil
|
|
import time
|
|
from datetime import datetime
|
|
from pathlib import Path
|
|
|
|
import albumentations as A
|
|
import cv2
|
|
import numpy as np
|
|
|
|
import constants
|
|
from dto.imageLabel import ImageLabel
|
|
|
|
|
|
class Augmentator:
    """Produce augmented copies of YOLO-annotated images.

    Source images and label files are read from ``constants.config.data_images_dir``
    and ``constants.config.data_labels_dir``. For every image the original plus
    ``augmentations_per_image`` randomly transformed variants are written to
    ``constants.config.processed_images_dir`` / ``processed_labels_dir``.
    """

    # Number of transformed variants generated per source image.
    augmentations_per_image = 7

    def __init__(self):
        # Progress counters, used only for the console progress message.
        self.total_files_processed = 0
        self.total_images_to_process = 0

        # Boxes must stay this far inside the [0, 1] image borders.
        self.correct_margin = 0.0005
        # Boxes narrower or lower than this (after correction) are dropped.
        self.correct_min_bbox_size = 0.01

        self.transform = A.Compose([
            A.HorizontalFlip(p=0.6),
            A.RandomBrightnessContrast(p=0.4, brightness_limit=(-0.3, 0.3), contrast_limit=(-0.05, 0.05)),
            A.Affine(p=0.8, scale=(0.8, 1.2), rotate=(-35, 35), shear=(-10, 10)),

            A.MotionBlur(p=0.1, blur_limit=(1, 2)),
            A.HueSaturationValue(p=0.4, hue_shift_limit=10, sat_shift_limit=10, val_shift_limit=10)
        ], bbox_params=A.BboxParams(format='yolo'))

    def correct_bboxes(self, labels):
        """Shrink boxes that cross the image border and drop the ones that get too small.

        :param labels: iterable of ``[x, y, width, height, class]`` rows, where
            ``x``/``y`` are the box centre and all values are fractions of the image size.
        :return: new list of ``[x, y, width, height, class]`` rows. The centre is
            kept; width/height are reduced symmetrically by the overshoot.
        """
        res = []
        for bbox in labels:
            x, y, width, height, class_id = bbox[:5]
            half_width = 0.5 * width
            half_height = 0.5 * height

            # How far the box sticks out past the borders (plus a small margin).
            # The value is negative when it sticks out; otherwise 0 (no correction).
            w_diff = min((1 - self.correct_margin) - (x + half_width), (x - half_width) - self.correct_margin, 0)
            w = width + 2 * w_diff
            if w < self.correct_min_bbox_size:
                continue

            h_diff = min((1 - self.correct_margin) - (y + half_height), (y - half_height) - self.correct_margin, 0)
            h = height + 2 * h_diff
            if h < self.correct_min_bbox_size:
                continue

            res.append([x, y, w, h, class_id])
        return res

    def augment_inner(self, img_ann: 'ImageLabel') -> 'list[ImageLabel]':
        """Build the in-memory results for one image: the original plus its variants.

        :param img_ann: source image together with its labels and file paths.
        :return: list whose first item is the untouched image (with its original,
            uncorrected labels) followed by one item per successful transformation.
            A failed transformation is reported on stdout and skipped.
        """
        labels = self.correct_bboxes(img_ann.labels)
        if len(labels) == 0 and len(img_ann.labels) != 0:
            print('no labels but was!!!')

        path = Path(img_ann.image_path)
        results = [
            ImageLabel(
                image=img_ann.image,
                labels=img_ann.labels,
                image_path=os.path.join(constants.config.processed_images_dir, path.name),
                labels_path=os.path.join(constants.config.processed_labels_dir, Path(img_ann.labels_path).name)
            )
        ]

        for i in range(self.augmentations_per_image):
            try:
                res = self.transform(image=img_ann.image, bboxes=labels)
                name = f'{path.stem}_{i + 1}'
                results.append(ImageLabel(
                    image=res['image'],
                    labels=res['bboxes'],
                    image_path=os.path.join(constants.config.processed_images_dir, f'{name}{path.suffix}'),
                    labels_path=os.path.join(constants.config.processed_labels_dir, f'{name}.txt')
                ))
            except Exception as e:
                # Best effort: one failed variant must not cancel the others.
                print(f'Error during transformation: {e}')
        return results

    def read_labels(self, labels_path) -> 'list[list]':
        """Parse a label file into ``[x, y, width, height, class]`` rows.

        Each non-blank line is expected to be ``<class> <x> <y> <width> <height>``;
        a decimal comma is accepted in place of a decimal point. Blank lines and
        repeated whitespace between fields are tolerated.

        :param labels_path: path of the label text file.
        :return: one list per line: the coordinates as floats followed by the
            class as the string read from the file.
        :raises ValueError: if a coordinate cannot be parsed as a float.
        """
        arr = []
        with open(labels_path, 'r') as f:
            for row in f:
                # split() without arguments ignores the trailing newline and
                # collapses runs of whitespace, so blank lines yield no fields.
                fields = row.split()
                if not fields:
                    continue
                class_num, *str_coordinates = fields
                coordinates = [float(n.replace(',', '.')) for n in str_coordinates]
                # noinspection PyTypeChecker
                coordinates.append(class_num)
                arr.append(coordinates)
        return arr

    def augment_annotation(self, image_file):
        """Augment a single source image and write all results to disk.

        Never raises: every failure is reported on stdout so that one bad file
        does not stop the rest of the batch.

        :param image_file: directory entry (anything with a ``name`` attribute)
            of an image inside ``constants.config.data_images_dir``.
        """
        try:
            image_path = os.path.join(constants.config.data_images_dir, image_file.name)
            labels_path = os.path.join(constants.config.data_labels_dir, f'{Path(str(image_path)).stem}.txt')
            # np.fromfile + imdecode instead of cv2.imread, as in the original code.
            image = cv2.imdecode(np.fromfile(image_path, dtype=np.uint8), cv2.IMREAD_UNCHANGED)
            if image is None:
                # imdecode signals an unreadable/invalid buffer by returning None.
                print(f'Error appeared in thread for {image_file.name}: image cannot be decoded')
                self.total_files_processed += 1
                return

            img_ann = ImageLabel(
                image_path=image_path,
                image=image,
                labels_path=labels_path,
                labels=self.read_labels(labels_path)
            )
            try:
                for annotation in self.augment_inner(img_ann):
                    # Results are always JPEG-encoded, whatever suffix the path carries.
                    cv2.imencode('.jpg', annotation.image)[1].tofile(annotation.image_path)
                    lines = [
                        f'{label[4]} {round(label[0], 5)} {round(label[1], 5)} {round(label[2], 5)} {round(label[3], 5)}\n'
                        for label in annotation.labels
                    ]
                    with open(annotation.labels_path, 'w') as f:
                        f.writelines(lines)

                print(f'{datetime.now():%Y-%m-%d %H:%M:%S}: {self.total_files_processed + 1}/{self.total_images_to_process} : {image_file.name} has augmented')
            except Exception as e:
                print(e)
            self.total_files_processed += 1
        except Exception as e:
            print(f'Error appeared in thread for {image_file.name}: {e}')

    def augment_annotations(self, from_scratch=False):
        """Augment every source image that has no counterpart in the processed folder yet.

        :param from_scratch: when true, the whole ``constants.config.processed_dir``
            tree is removed first, so every image is processed again.
        """
        self.total_files_processed = 0

        # On a first run there is nothing to wipe; rmtree would raise FileNotFoundError.
        if from_scratch and os.path.isdir(constants.config.processed_dir):
            shutil.rmtree(constants.config.processed_dir)

        os.makedirs(constants.config.processed_images_dir, exist_ok=True)
        os.makedirs(constants.config.processed_labels_dir, exist_ok=True)

        with os.scandir(constants.config.processed_images_dir) as processed:
            processed_images = {entry.name for entry in processed}

        # The untouched copy keeps the source file name, so a name match means "already done".
        with os.scandir(constants.config.data_images_dir) as imd:
            images = [
                image_file for image_file in imd
                if image_file.is_file() and image_file.name not in processed_images
            ]
        self.total_images_to_process = len(images)

        # Leaving the ``with`` block waits for all submitted work to finish.
        with concurrent.futures.ThreadPoolExecutor() as executor:
            executor.map(self.augment_annotation, images)
|
|
|
|
|
|
if __name__ == '__main__':
    # Script entry point: run augmentation passes forever, sleeping between
    # passes so that newly added source images are picked up on the next one.
    pause_seconds = 5 * 60
    worker = Augmentator()
    while True:
        worker.augment_annotations()
        print('All processed, waiting for 5 minutes...')
        time.sleep(pause_seconds)
|