Source code for MRIsegm.datagenerators

from tensorflow.keras.preprocessing.image import ImageDataGenerator
import numpy as np
import cv2
import pydicom
import glob
from skimage.restoration import denoise_nl_means, estimate_sigma


# Module authorship metadata (both values are single-element lists).
__author__ = ['Giuseppe Filitto']
__email__ = ['giuseppe.filitto@studio.unibo.it']


def create_segmentation_generator(img_path, mask_path, BATCH_SIZE, IMG_SIZE, SEED, data_gen_args_img, data_gen_args_mask):  # pragma: no cover
    '''
    Build a paired (image, ground-truth) batch generator from two directories.

    Two keras ImageDataGenerator objects are created, one configured with
    data_gen_args_img and one with data_gen_args_mask. Each one reads its own
    directory in grayscale mode, without class labels, and both are given the
    same target size, batch size and seed. The two flows are then zipped.

    Parameters
    ----------
    img_path : str
        path for the images directory.
    mask_path : str
        path for the ground-truth directory.
    BATCH_SIZE : int
        size of the batches of data.
    IMG_SIZE : tuple
        (image height, image width).
    SEED : int
        seed for randomness control, passed unchanged to both flows.
    data_gen_args_img : dict
        keras ImageDataGenerator keyword arguments used for the images.
    data_gen_args_mask : dict
        keras ImageDataGenerator keyword arguments used for the masks.

    Returns
    -------
    zip object
        tuples of (x, y) where x is a batch of input images and y is the
        batch of ground-truth masks produced by the matching flow.
    '''
    image_augmenter = ImageDataGenerator(**data_gen_args_img)
    mask_augmenter = ImageDataGenerator(**data_gen_args_mask)

    # Options shared by both directory flows; only the directory differs.
    flow_options = dict(
        target_size=IMG_SIZE,
        class_mode=None,
        color_mode='grayscale',
        batch_size=BATCH_SIZE,
        seed=SEED,
    )

    image_batches = image_augmenter.flow_from_directory(img_path, **flow_options)
    mask_batches = mask_augmenter.flow_from_directory(mask_path, **flow_options)

    return zip(image_batches, mask_batches)
class DataGenerator:
    '''
    Endless batch iterator pairing DICOM slices with their PNG label images.

    Files are discovered with glob ('*.dcm' in the source directory, '*.png'
    in the label directory), sorted, shuffled with the same permutation and
    split into a training and a validation subset. Each call to next() loads
    one batch, resizes it to 512x512, rescales it to float32 in [0, 1],
    denoises and gamma-corrects the images, optionally applies random flips,
    and returns the batch as numpy arrays with a trailing channel axis.
    '''

    # Spatial size every image / label is brought to before batching.
    _TARGET_SIZE = (512, 512)

    def __init__(self, batch_size, source_path, label_path, aug=False, seed=123, validation_split=0., subset='training'):
        '''
        Custom DataGenerator

        Parameters
        ----------
        batch_size : int
            batch size
        source_path : str
            path of the directory containing dicom files
        label_path : str
            path of the directory containing labels images
        aug : bool, optional
            data augmentation (random flips), by default False
        seed : int, optional
            random seed, by default 123. Note that it seeds the global
            numpy random state.
        validation_split : float, optional
            validation split rate, by default 0.
        subset : str, optional
            set 'training' or 'validation' subset of data, by default 'training'

        Raises
        ------
        ValueError
            if subset is neither 'training' nor 'validation'.
        '''
        # Fail fast: an unknown subset used to leave the object half-built
        # and only blew up later, on the first access to num_data / next().
        if subset not in ('training', 'validation'):
            raise ValueError(
                "subset must be 'training' or 'validation', got {!r}".format(subset))

        np.random.seed(seed)

        source_files = np.asarray(sorted(glob.glob(source_path + '/*.dcm')))
        labels_files = np.asarray(sorted(glob.glob(label_path + '/*.png')))

        assert source_files.size == labels_files.size, (
            'number of source files ({}) and label files ({}) differ'.format(
                source_files.size, labels_files.size))

        source_files, labels_files = self.randomize(source_files, labels_files)

        # The second shuffle is kept on purpose: removing it would change the
        # train/validation split obtained for a given seed.
        idx = np.arange(0, source_files.size)
        np.random.shuffle(idx)

        # First n_val shuffled entries go to validation, the rest to training.
        n_val = int(source_files.size * validation_split)
        self._source_trainfiles = source_files[idx[n_val:]]
        self._labels_trainfiles = labels_files[idx[n_val:]]
        self._source_valfiles = source_files[idx[:n_val]]
        self._labels_valfiles = labels_files[idx[:n_val]]

        self.subset = subset
        if self.subset == 'training':
            self._num_data = self._source_trainfiles.size
        else:
            self._num_data = self._source_valfiles.size

        self.aug = aug
        self._batch = batch_size
        self._cbatch = 0  # index of the first file of the next batch
        self._data, self._label = (None, None)

    @property
    def num_data(self):
        '''
        check the number of files for the relative subset

        Returns
        -------
        int
            number of files (images)
        '''
        return self._num_data

    def randomize(self, source, label):
        '''
        Shuffle data, applying the same permutation to sources and labels

        Parameters
        ----------
        source : array
            files paths
        label : array
            labels paths

        Returns
        -------
        tuple
            randomized source and label paths
        '''
        random_index = np.arange(0, source.size)
        np.random.shuffle(random_index)
        return (source[random_index], label[random_index])

    def _fit_to_target(self, array, is_label):
        '''
        Return array resized to 512x512, or array itself if it already fits.
        '''
        if tuple(array.shape[:2]) == self._TARGET_SIZE:
            return array
        # Nearest-neighbour keeps label values intact; interpolating a mask
        # would create intermediate values along the region borders.
        # INTER_LINEAR is the cv2.resize default, so images are unchanged.
        interpolation = cv2.INTER_NEAREST if is_label else cv2.INTER_LINEAR
        return cv2.resize(array, self._TARGET_SIZE, interpolation=interpolation)

    def resize(self, img, lbl):
        '''
        resize input and labels images to 512x512

        Both height and width are checked, independently for the image and
        the label; arrays that already are 512x512 are returned untouched.

        Parameters
        ----------
        img : image
            input image
        lbl : image
            label image

        Returns
        -------
        tuple
            resized input and label image
        '''
        return (self._fit_to_target(img, is_label=False),
                self._fit_to_target(lbl, is_label=True))

    def crop(self, img, lbl):
        '''
        crop input and labels images

        The pair is first brought to 512x512, then the central 256x256
        window (rows and columns 128 to 383) is extracted from both.

        Parameters
        ----------
        img : image
            input image
        lbl : image
            label image

        Returns
        -------
        tuple
            cropped input and label image
        '''
        img, lbl = self.resize(img, lbl)

        center = self._TARGET_SIZE[0] // 2   # 256
        half = center // 2                   # 128
        window = slice(center - half, center + half)

        return (img[window, window], lbl[window, window])

    def _random_flip(self, img, lbl, flip_code):
        '''
        Flip both arrays with cv2.flip(flip_code) with probability one half.
        '''
        if np.random.uniform(low=0., high=1.) > 0.5:
            return (cv2.flip(img, flip_code), cv2.flip(lbl, flip_code))
        return (img, lbl)

    def random_vflip(self, img, lbl):
        '''
        random vertical flip input and label images

        Parameters
        ----------
        img : image
            input image
        lbl : image
            label image

        Returns
        -------
        tuple
            randomly vertical flipped input and label image
        '''
        return self._random_flip(img, lbl, 0)

    def random_hflip(self, img, lbl):
        '''
        random horizontal flip input and label images

        Parameters
        ----------
        img : image
            input image
        lbl : image
            label image

        Returns
        -------
        tuple
            randomly horizontal flipped input and label image
        '''
        return self._random_flip(img, lbl, 1)

    def rescale(self, img):
        '''
        Min-max normalize the image to the range [0, 1] as 32-bit float

        Parameters
        ----------
        img : image
            image to be normalized and rescaled

        Returns
        -------
        image
            normalized and rescaled image
        '''
        return cv2.normalize(img, dst=None, alpha=0, beta=1,
                             norm_type=cv2.NORM_MINMAX, dtype=cv2.CV_32F)

    def denoise(self, img):
        '''
        Denoise the image using non-local means algorithm

        The noise standard deviation is estimated from the image itself and
        used both as sigma and, multiplied by 10, as the cut-off h.

        Parameters
        ----------
        img : image
            image to be denoised

        Returns
        -------
        image
            smoothed denoised image
        '''
        patch_kw = dict(patch_size=5, patch_distance=6)
        sigma_est = np.mean(estimate_sigma(img))
        return denoise_nl_means(img, h=10 * sigma_est, sigma=sigma_est,
                                fast_mode=True, **patch_kw)

    def gamma_correction(self, img, gamma=1.0):
        '''
        Perform gamma correction.
        The true value of gamma used in the formula is 1/gamma.

        The image is mapped to [0, 1], raised to the power 1/gamma and mapped
        back to its original [min, max] range. A constant image has no range
        to stretch and is returned unchanged.

        Parameters
        ----------
        img : image
            image to be filtered
        gamma : float, optional
            gamma value, by default 1.0

        Returns
        -------
        image
            gamma corrected image (a new array, the input is not modified)
        '''
        igamma = 1.0 / gamma

        imin, imax = img.min(), img.max()
        span = imax - imin
        if span == 0:
            # Avoid 0/0 -> NaN. True division by 1 yields a new array with the
            # same dtype promotion as the regular path below.
            return np.true_divide(img, 1)

        img_c = ((img - imin) / span) ** igamma
        return img_c * span + imin

    def __iter__(self):
        '''Restart from the first batch and return the generator itself.'''
        self._cbatch = 0
        return self

    def __next__(self):
        '''
        Load, preprocess and return the next batch of the selected subset.

        The generator never stops: when fewer than batch_size files are left,
        both subsets are reshuffled and reading restarts from the beginning
        (a trailing incomplete batch is skipped).

        Returns
        -------
        tuple
            (images, labels) numpy arrays; for 2D inputs each has shape
            (batch, 512, 512, 1).

        Raises
        ------
        ValueError
            if the selected subset contains no files.
        '''
        if self._num_data == 0:
            raise ValueError(
                "the '{}' subset is empty: no batch can be generated".format(self.subset))

        # Strictly greater: with '>=' the last full batch of every pass was
        # never served when num_data is a multiple of the batch size.
        if self._cbatch + self._batch > self._num_data:
            self._cbatch = 0
            self._source_trainfiles, self._labels_trainfiles = self.randomize(
                self._source_trainfiles, self._labels_trainfiles)
            self._source_valfiles, self._labels_valfiles = self.randomize(
                self._source_valfiles, self._labels_valfiles)

        batch = slice(self._cbatch, self._cbatch + self._batch)
        if self.subset == 'training':
            c_sources = self._source_trainfiles[batch]
            c_labels = self._labels_trainfiles[batch]
        else:
            c_sources = self._source_valfiles[batch]
            c_labels = self._labels_valfiles[batch]

        # load the data (labels are read as single-channel grayscale)
        images = [pydicom.dcmread(f).pixel_array for f in c_sources]
        labels = [cv2.imread(f, 0) for f in c_labels]

        # check size
        images, labels = zip(*[self.resize(im, lbl) for im, lbl in zip(images, labels)])

        # cast
        images = [self.rescale(im) for im in images]
        labels = [self.rescale(lbl) for lbl in labels]

        # denoise
        images = [self.denoise(im) for im in images]

        # gamma correction
        images = [self.gamma_correction(im, gamma=1.5) for im in images]

        if self.aug:
            # Stage by stage (not sample by sample) so the order of the random
            # draws, and therefore seeded results, stays the same.
            # random horizontal flip
            images, labels = zip(*[self.random_hflip(im, lbl) for im, lbl in zip(images, labels)])
            # random vertical flip
            images, labels = zip(*[self.random_vflip(im, lbl) for im, lbl in zip(images, labels)])

        # add the channel axis and stack to numpy
        images = np.array([im[..., np.newaxis] for im in images])
        labels = np.array([lbl[..., np.newaxis] for lbl in labels])

        self._cbatch += self._batch

        return (images, labels)