# ai-training/preprocessing-cuda.py
# (recovered file-listing metadata: 2025-03-05 10:47:13 +02:00, 171 lines, 5.9 KiB, Python)
import os
import time
import cv2
from pathlib import Path
import nvidia.dali as dali
import nvidia.dali.fn as fn
from constants import (
data_images_dir,
data_labels_dir,
processed_images_dir,
processed_labels_dir
)
# Number of augmented variants generated per source image by the DALI pipeline
# (the pipeline emits these plus the original image as its last output).
NUM_AUGMENTATIONS = 7
class DataLoader:
    """Batch-preprocess raw training images with GPU-accelerated DALI augmentations.

    For every image in ``data_images_dir`` that has not yet been processed and
    that has a matching label file in ``data_labels_dir``, writes the original
    image plus ``NUM_AUGMENTATIONS`` augmented copies (each with a copy of the
    original labels) into ``processed_images_dir`` / ``processed_labels_dir``.

    NOTE(review): the geometric augmentations (crop, flip, rotate) are applied
    to the images only — the label coordinates are copied unchanged, so they no
    longer match the augmented geometry. Confirm whether this is intentional.
    """

    def __init__(self, batch_size=32):
        self.batch_size = batch_size
        # Output directories must exist before any image/label writes.
        os.makedirs(processed_images_dir, exist_ok=True)
        os.makedirs(processed_labels_dir, exist_ok=True)

    def _read_labels(self, labels_path):
        """Parse a label file into ``[[x, y, w, h, class_id], ...]``.

        Each input row is ``"<class> <n1> <n2> <n3> <n4>"``. Decimal commas are
        normalized to dots; the class id stays a string and is moved to the end
        of each record. Blank rows (e.g. a trailing newline) are skipped.
        """
        with open(labels_path, 'r') as f:
            rows = f.readlines()
        annotations = []
        for row in rows:
            row = row.strip()
            if not row:
                # Skip blank lines — a trailing newline would otherwise
                # produce float('') and crash.
                continue
            # split() (not split(' ')) tolerates runs of whitespace.
            parts = row.split()
            class_num = parts.pop(0)
            coordinates = [float(n.replace(',', '.')) for n in parts]
            coordinates.append(class_num)
            annotations.append(coordinates)
        return annotations

    def _get_image_label_pairs(self):
        """Return (image_path, labels_path) pairs still awaiting processing.

        An image qualifies when its filename is absent from
        ``processed_images_dir`` and a same-stem ``.txt`` label file exists.
        """
        processed_images = set(f.name for f in os.scandir(processed_images_dir))
        pairs = []
        for image_file in os.scandir(data_images_dir):
            if image_file.is_file() and image_file.name not in processed_images:
                image_path = os.path.join(data_images_dir, image_file.name)
                labels_path = os.path.join(
                    data_labels_dir, f'{Path(image_path).stem}.txt'
                )
                if os.path.exists(labels_path):
                    pairs.append((image_path, labels_path))
        return pairs

    def create_dali_pipeline(self, file_paths):
        """Build a DALI pipeline producing NUM_AUGMENTATIONS augmented views plus the original.

        ``file_paths`` is the path of a text file listing one image per line.
        Outputs are ordered: augmentations 0..NUM_AUGMENTATIONS-1, then the
        undecorated decoded image last.
        """
        @dali.pipeline_def(batch_size=self.batch_size, num_threads=32, device_id=0)
        def augmentation_pipeline():
            # fn.readers.file is the modern name of the deprecated fn.file_reader.
            # random_shuffle=False keeps reader order aligned with the file list,
            # which process_batch relies on to match images with their labels.
            # NOTE(review): the DALI file_list format expects "<name> <label>"
            # per line; we write bare paths — confirm the reader accepts this.
            jpegs, _ = fn.readers.file(
                file_root=data_images_dir, file_list=file_paths, random_shuffle=False
            )
            # 'mixed' decodes on GPU; downstream ops stay on-device.
            images = fn.decoders.image(jpegs, device='mixed')
            augmented_images = []
            for _ in range(NUM_AUGMENTATIONS):
                aug_image = fn.random_resized_crop(
                    images,
                    device='gpu',
                    min_scale=0.8,
                    max_scale=1.0
                )
                aug_image = fn.flip(aug_image, horizontal=fn.random.coin_flip())
                aug_image = fn.brightness_contrast(
                    aug_image,
                    brightness=fn.random.uniform(range=(-0.05, 0.05)),
                    contrast=fn.random.uniform(range=(-0.05, 0.05))
                )
                aug_image = fn.rotate(
                    aug_image,
                    angle=fn.random.uniform(range=(-25, 25)),
                    fill_value=0
                )
                # Add noise and color jittering.
                aug_image = fn.noise.gaussian(
                    aug_image, mean=0, stddev=fn.random.uniform(range=(0, 0.1))
                )
                aug_image = fn.hsv(
                    aug_image,
                    hue=fn.random.uniform(range=(-8, 8)),
                    saturation=fn.random.uniform(range=(-8, 8)),
                    value=fn.random.uniform(range=(-8, 8))
                )
                augmented_images.append(aug_image)
            # Also include the original decoded image as the last output.
            augmented_images.append(images)
            return tuple(augmented_images)
        return augmentation_pipeline()

    def process_batch(self):
        """Run the DALI pipeline over all pending images and persist the results."""
        image_label_pairs = self._get_image_label_pairs()
        if not image_label_pairs:
            # Nothing new to process — avoid building a pipeline over an
            # empty file list.
            return
        # Write the file list DALI reads from.
        file_list_path = os.path.join(processed_images_dir, 'file_list.txt')
        with open(file_list_path, 'w') as f:
            f.writelines(f'{img_path}\n' for img_path, _ in image_label_pairs)
        pipeline = self.create_dali_pipeline(file_list_path)
        pipeline.build()
        for batch_idx in range(0, len(image_label_pairs), self.batch_size):
            batch_pairs = image_label_pairs[batch_idx:batch_idx + self.batch_size]
            # BUGFIX: Pipeline.run() *returns* the output TensorLists — the
            # original discarded them and indexed a nonexistent
            # ``pipeline.output`` attribute. The tensors live on the GPU
            # (device='mixed'/'gpu'), so copy each TensorList to host once
            # per batch before converting samples to numpy.
            outputs = pipeline.run()
            host_outputs = [out.as_cpu() for out in outputs]
            for img_idx, (orig_img_path, orig_labels_path) in enumerate(batch_pairs):
                orig_labels = self._read_labels(orig_labels_path)
                # Original image (last pipeline output), unchanged labels.
                self._write_image_and_labels(
                    host_outputs[NUM_AUGMENTATIONS].at(img_idx),
                    orig_img_path,
                    orig_labels,
                    is_original=True
                )
                # Augmented variants, each with a copy of the original labels.
                for aug_idx in range(NUM_AUGMENTATIONS):
                    self._write_image_and_labels(
                        host_outputs[aug_idx].at(img_idx),
                        orig_img_path,
                        orig_labels,
                        aug_idx=aug_idx
                    )

    def _write_image_and_labels(self, image, orig_img_path, labels,
                                is_original=False, aug_idx=None):
        """Persist one host-side numpy image and its label rows.

        ``image`` is an HWC uint8 numpy array (already copied off the GPU).
        Originals keep their filename; augmentations get a ``_<n>`` suffix.
        """
        path = Path(orig_img_path)
        if is_original:
            img_name = path.name
            label_name = f'{path.stem}.txt'
        else:
            img_name = f'{path.stem}_{aug_idx + 1}{path.suffix}'
            label_name = f'{path.stem}_{aug_idx + 1}.txt'
        # BUGFIX: DALI's image decoder outputs RGB by default, while OpenCV's
        # encoder expects BGR — convert so saved JPEGs have correct colors.
        bgr_image = cv2.cvtColor(image, cv2.COLOR_RGB2BGR)
        img_path = os.path.join(processed_images_dir, img_name)
        # imencode + tofile (rather than imwrite) keeps non-ASCII paths working;
        # check the success flag instead of silently writing garbage.
        ok, buf = cv2.imencode('.jpg', bgr_image)
        if not ok:
            raise RuntimeError(f'Failed to JPEG-encode image for {img_path}')
        buf.tofile(img_path)
        # Labels: class id first, then the four coordinates rounded to 5 places.
        label_path = os.path.join(processed_labels_dir, label_name)
        with open(label_path, 'w') as f:
            lines = [
                f'{ann[4]} {round(ann[0], 5)} {round(ann[1], 5)} '
                f'{round(ann[2], 5)} {round(ann[3], 5)}\n'
                for ann in labels
            ]
            f.writelines(lines)
def main():
    """Poll forever: process any new images, then sleep two minutes.

    The loader is constructed once outside the loop — its constructor only
    ensures the output directories exist, which is loop-invariant.
    """
    loader = DataLoader()
    while True:
        loader.process_batch()
        print('All processed, waiting for 2 minutes...')
        time.sleep(120)


if __name__ == '__main__':
    main()