Yes, I am currently looping through a dict where flair, t1 and the mask are transformed. I plotted the transformed image vs. the original and it looked fine. This confuses me a lot. Here's the code for the transformation:
import numpy as np
import torch.nn.functional as nn_F
import torchvision.transforms.functional as F
import torchvision.transforms as transform
import torch
import nibabel as nib
import matplotlib.pyplot as plt
import math
from collections.abc import Sequence
class Random_Rotation(object):
    """Rotate every image of a sample by one shared, randomly drawn angle.

    A single angle is drawn uniformly from ``[-angle, angle]`` per call and
    applied to all entries of the sample, so entries that belong together
    receive exactly the same rotation.
    """

    def __init__(self, angle, interpolation_method=transform.InterpolationMode.BILINEAR):
        """Store the rotation range and the interpolation mode.

        Args:
            angle: Bound of the symmetric range the rotation angle is drawn
                from (int or float); the drawn value is handed to ``F.rotate``.
            interpolation_method: Interpolation mode forwarded to ``F.rotate``.

        Raises:
            TypeError: If ``angle`` is neither an int nor a float.
        """
        if not isinstance(angle, (float, int)):
            # Report the argument itself: ``self.angle`` does not exist yet at
            # this point, so referencing it would raise AttributeError instead.
            raise TypeError(f"Please provide a valid angle type (int/float), provided {type(angle)}")
        self.angle = angle
        self.interpolation_method = interpolation_method

    def __call__(self, image_sample):
        """Rotate all entries of ``image_sample`` in place and return it.

        Args:
            image_sample: A list or dict whose values are ``torch.Tensor``
                objects or NumPy arrays (arrays are converted with
                ``torch.from_numpy``).

        Returns:
            The same container object, with every entry replaced by its
            rotated tensor.

        Raises:
            TypeError: If ``image_sample`` is neither a list nor a dict.
        """
        # One draw per call keeps all entries of the sample aligned.
        current_angle = np.random.uniform(-self.angle, self.angle)
        if isinstance(image_sample, list):
            unpacked = enumerate(image_sample)
        elif isinstance(image_sample, dict):
            unpacked = image_sample.items()
        else:
            raise TypeError(f"Please provide a list or dict of images, given type {type(image_sample)}")
        # Only existing keys/indices are reassigned, so the container size
        # never changes while it is being iterated.
        for idx, samples in unpacked:
            if not isinstance(samples, torch.Tensor):
                samples = torch.from_numpy(samples)
            image_sample[idx] = F.rotate(samples, current_angle, interpolation=self.interpolation_method)
        return image_sample
class Random_Horizontal_Flip(object):
    """Horizontally flip all images of a sample together with a given probability.

    The flip decision is made once per call, so either every entry of the
    sample is flipped or none is.
    """

    def __init__(self, flip_probability=0.5):
        """Validate and store the flip probability.

        Args:
            flip_probability: Probability in ``[0, 1]`` that a call flips the
                sample.

        Raises:
            TypeError: If ``flip_probability`` is not an int or float.
            ValueError: If ``flip_probability`` lies outside ``[0, 1]``.
        """
        if not isinstance(flip_probability, (int, float)):
            raise TypeError(f"Please provide a valid value for the flip probability, you typed: {flip_probability}")
        if flip_probability > 1.0 or flip_probability < 0:
            raise ValueError(f"The probability has to be in the interval[0,1], you typed {flip_probability}")
        self.pvalue = flip_probability

    def __call__(self, image_sample):
        """Flip all entries of ``image_sample`` in place (or not) and return it.

        Args:
            image_sample: A list or dict whose values are ``torch.Tensor``
                objects or NumPy arrays (arrays are converted with
                ``torch.from_numpy`` when a flip happens).

        Returns:
            The same container object. It is returned whether or not a flip
            was applied, so the transform can be chained like the others.

        Raises:
            TypeError: If ``image_sample`` is neither a list nor a dict.
        """
        if isinstance(image_sample, list):
            unpacked = enumerate(image_sample)
        elif isinstance(image_sample, dict):
            unpacked = image_sample.items()
        else:
            raise TypeError(f"Please provide a list or dict of images, given type {type(image_sample)}")
        # A single random draw decides the flip for the whole sample.
        if torch.rand(1) < self.pvalue:
            for idx, samples in unpacked:
                if not isinstance(samples, torch.Tensor):
                    samples = torch.from_numpy(samples)
                image_sample[idx] = F.hflip(samples)
        # Without this return the transform evaluates to None and any caller
        # that uses the result loses the sample.
        return image_sample
class Random_Scale(object):
    """Randomly rescale all images of a sample while keeping their original size.

    A single scale is drawn uniformly from ``[1 - scale_factor, 1 + scale_factor]``
    per call. Every entry is resized by that scale and then zero-padded
    (when it shrank) or center-cropped (when it grew) back to the original
    spatial size.
    """

    def __init__(self, scale_factor, interpolation_method=transform.InterpolationMode.BILINEAR):
        """Store the scale range and the interpolation mode.

        Args:
            scale_factor: Half-width of the scale range around 1 (int or float).
            interpolation_method: Interpolation mode forwarded to ``F.resize``.

        Raises:
            TypeError: If ``scale_factor`` is not an int or float.
        """
        if not isinstance(scale_factor, (int, float)):
            raise TypeError(f"Please provide a valid value for the scaling factor, you typed: {scale_factor}")
        self.scale_factor = scale_factor
        self.interpolation_method = interpolation_method

    def __call__(self, image_sample):
        """Rescale all entries of ``image_sample`` in place and return it.

        Args:
            image_sample: A list or dict whose values are ``torch.Tensor``
                objects or NumPy arrays shaped ``[..., H, W]`` with ``H == W``.
                The spatial size is read from the first entry only; the other
                entries are assumed to share it (not checked here).

        Returns:
            The same container object, with every entry replaced by its
            rescaled tensor. An empty container is returned unchanged.

        Raises:
            TypeError: If ``image_sample`` is neither a list nor a dict.
            ValueError: If the reference image is not square.
        """
        # One draw per call keeps all entries of the sample aligned.
        current_scale = np.random.uniform(1 - self.scale_factor, 1 + self.scale_factor)
        if isinstance(image_sample, list):
            unpacked = enumerate(image_sample)
            references = iter(image_sample)
        elif isinstance(image_sample, dict):
            unpacked = image_sample.items()
            references = iter(image_sample.values())
        else:
            raise TypeError(f"Please provide a list or dict of images, given type {type(image_sample)}")
        reference = next(references, None)
        if reference is None:
            # Nothing to scale.
            return image_sample
        # The last two dims are the spatial ones for [..., H, W] inputs.
        height, width = reference.shape[-2:]
        # The pad/crop arithmetic below uses a single side length, so a
        # non-square image must be rejected in both directions.
        if height != width:
            raise ValueError("It seems that the height and the width are not equal")
        original_side = int(height)
        new_size = (int(round(height * current_scale)), int(round(width * current_scale)))
        new_side = new_size[0]
        shrunk = new_side < original_side
        if shrunk:
            # Split the missing pixels between both borders; an odd remainder
            # goes to the right/bottom.
            missing = original_side - new_side
            left = missing // 2
            right = left + missing % 2
            padding_ = (left, right) * 2
        else:
            # Centered window of the original size inside the enlarged image.
            x_min = (new_side - original_side) // 2
            x_max = x_min + original_side
        for idx, samples in unpacked:
            if not isinstance(samples, torch.Tensor):
                samples = torch.from_numpy(samples)
            resized = F.resize(samples, new_size, interpolation=self.interpolation_method)
            if shrunk:
                image_sample[idx] = nn_F.pad(resized, padding_, "constant", 0)
            else:
                image_sample[idx] = resized[..., x_min:x_max, x_min:x_max]
        return image_sample