Files
ai-training/preprocessing-cuda.py
T
zxsanny b5e5f0b297 correct albumentation
Try to perform image augmentation on the GPU.
saved llm prompt
2025-03-05 10:45:41 +02:00

171 lines
5.9 KiB
Python

import os
import time
import numpy as np
import cv2
from pathlib import Path
import concurrent.futures
import nvidia.dali as dali
import nvidia.dali.fn as fn
import nvidia.dali.types as types
from constants import (
data_images_dir,
data_labels_dir,
processed_images_dir,
processed_labels_dir
)
# Configurable number of augmentations per image
# (each source image yields this many augmented copies, plus the original).
NUM_AUGMENTATIONS: int = 7
class DataLoader:
    """Read raw image/label pairs, augment them on the GPU with NVIDIA DALI,
    and write the results to the processed images/labels directories."""

    def __init__(self, batch_size=32):
        """Create the loader and ensure the output directories exist.

        Args:
            batch_size: number of images per DALI pipeline batch.
        """
        self.batch_size = batch_size
        os.makedirs(processed_images_dir, exist_ok=True)
        os.makedirs(processed_labels_dir, exist_ok=True)

    def _read_labels(self, labels_path):
        """Parse a YOLO-style label file into ``[x, y, w, h, class]`` rows.

        Coordinates may use ',' as the decimal separator. The class id is
        kept as a string and moved to the end of each row (``_write_image_and_labels``
        reads it back from index 4).
        """
        annotations = []
        with open(labels_path, 'r') as f:
            for row in f:
                # split() (rather than split(' ')) tolerates repeated whitespace,
                # and skipping empty rows avoids crashing later on trailing
                # blank lines (the old code produced a ['\n'] row for them).
                parts = row.split()
                if not parts:
                    continue
                class_num = parts[0]
                coordinates = [float(n.replace(',', '.')) for n in parts[1:]]
                coordinates.append(class_num)
                annotations.append(coordinates)
        return annotations

    def _get_image_label_pairs(self):
        """Return (image_path, labels_path) pairs that have not been processed
        yet and whose label file exists."""
        processed_images = set(f.name for f in os.scandir(processed_images_dir))
        pairs = []
        for image_file in os.scandir(data_images_dir):
            if image_file.is_file() and image_file.name not in processed_images:
                image_path = os.path.join(data_images_dir, image_file.name)
                labels_path = os.path.join(data_labels_dir, f'{Path(image_path).stem}.txt')
                if os.path.exists(labels_path):
                    pairs.append((image_path, labels_path))
        return pairs

    def create_dali_pipeline(self, file_paths):
        """Build a DALI pipeline emitting NUM_AUGMENTATIONS augmented versions
        of each image, plus the unmodified original as the LAST output.

        Args:
            file_paths: path to the DALI file-list text file to read from.
        """
        @dali.pipeline_def(batch_size=self.batch_size, num_threads=32, device_id=0)
        def augmentation_pipeline():
            # NOTE(review): fn.file_reader is the legacy alias of
            # fn.readers.file; consider migrating.
            jpegs, _ = fn.file_reader(file_root=data_images_dir, file_list=file_paths, random_shuffle=False)
            # 'mixed' decodes on the GPU; DALI decodes to RGB by default.
            images = fn.decoders.image(jpegs, device='mixed')
            augmented_images = []
            for _ in range(NUM_AUGMENTATIONS):
                # NOTE(review): crop/flip/rotate move the objects, but
                # process_batch() reuses the ORIGINAL label coordinates for
                # every augmented copy, so the written bounding boxes are
                # wrong for these geometric transforms. Either restrict the
                # pipeline to photometric ops or transform the labels too.
                aug_image = fn.random_resized_crop(images, random_area=(0.8, 1.0))
                aug_image = fn.flip(aug_image, horizontal=fn.random.coin_flip())
                aug_image = fn.brightness_contrast(
                    aug_image,
                    brightness=fn.random.uniform(range=(-0.05, 0.05)),
                    contrast=fn.random.uniform(range=(-0.05, 0.05))
                )
                aug_image = fn.rotate(
                    aug_image,
                    angle=fn.random.uniform(range=(-25, 25)),
                    fill_value=0
                )
                # Noise and color jitter (photometric — safe for the labels).
                aug_image = fn.noise.gaussian(aug_image, mean=0, stddev=fn.random.uniform(range=(0, 0.1)))
                aug_image = fn.hsv(
                    aug_image,
                    hue=fn.random.uniform(range=(-8, 8)),
                    saturation=fn.random.uniform(range=(-8, 8)),
                    value=fn.random.uniform(range=(-8, 8))
                )
                augmented_images.append(aug_image)
            # Unmodified image last, i.e. at index NUM_AUGMENTATIONS.
            augmented_images.append(images)
            return augmented_images
        return augmentation_pipeline()

    def process_batch(self):
        """Augment every unprocessed image and persist images plus labels."""
        image_label_pairs = self._get_image_label_pairs()
        # Write the file list the DALI reader consumes.
        file_list_path = os.path.join(processed_images_dir, 'file_list.txt')
        with open(file_list_path, 'w') as f:
            for img_path, _ in image_label_pairs:
                f.write(f'{img_path}\n')
        pipeline = self.create_dali_pipeline(file_list_path)
        pipeline.build()
        for batch_idx in range(0, len(image_label_pairs), self.batch_size):
            batch_pairs = image_label_pairs[batch_idx:batch_idx + self.batch_size]
            # BUG FIX: DALI's Pipeline has no `output` attribute — run()
            # RETURNS the tuple of outputs for the batch it just executed.
            # The original code discarded run()'s result and indexed a
            # nonexistent `pipeline.output`, which raises AttributeError.
            outputs = pipeline.run()
            for img_idx, (orig_img_path, orig_labels_path) in enumerate(batch_pairs):
                orig_labels = self._read_labels(orig_labels_path)
                # The unmodified image is the last pipeline output.
                self._write_image_and_labels(
                    outputs[NUM_AUGMENTATIONS][img_idx],
                    orig_img_path,
                    orig_labels,
                    is_original=True
                )
                for aug_idx in range(NUM_AUGMENTATIONS):
                    self._write_image_and_labels(
                        outputs[aug_idx][img_idx],
                        orig_img_path,
                        orig_labels,
                        aug_idx=aug_idx
                    )

    def _write_image_and_labels(self, image, orig_img_path, labels, is_original=False, aug_idx=None):
        """Write one (possibly augmented) image and its label file.

        Args:
            image: a single DALI tensor sample.
            orig_img_path: source image path, used to derive output names.
            labels: rows of [x, y, w, h, class] as produced by _read_labels.
            is_original: keep the source file name when True.
            aug_idx: zero-based augmentation index used in the output name.
        """
        path = Path(orig_img_path)
        if is_original:
            img_name = path.name
            label_name = f'{path.stem}.txt'
        else:
            img_name = f'{path.stem}_{aug_idx + 1}{path.suffix}'
            label_name = f'{path.stem}_{aug_idx + 1}.txt'
        img_path = os.path.join(processed_images_dir, img_name)
        # NOTE(review): the 'mixed' decoder yields GPU tensors, which likely
        # need .as_cpu() before conversion here; also DALI decodes to RGB
        # while cv2.imencode expects BGR — verify channel order on output.
        cv2.imencode('.jpg', image.asnumpy())[1].tofile(img_path)
        label_path = os.path.join(processed_labels_dir, label_name)
        with open(label_path, 'w') as f:
            lines = [f'{ann[4]} {round(ann[0], 5)} {round(ann[1], 5)} {round(ann[2], 5)} {round(ann[3], 5)}\n' for ann in labels]
            f.writelines(lines)
def main():
    """Poll forever: process any new images, then sleep for two minutes."""
    while True:
        DataLoader().process_batch()
        print('All processed, waiting for 2 minutes...')
        time.sleep(120)


if __name__ == '__main__':
    main()