I have three types of custom augmentations to be performed on MNIST (I have written three different functions for them). How do I create a data loader comprising the augmented data? The method I’m currently using throws an error for augmentation functions that accept a parameter.
# the augmentation function
def pad_rotate_project(entry, theta, show = False, requires_theta = False):
    """Pad an image by one pixel, rotate it by ``theta`` and sum it over dim 2.

    Args:
        entry: The image. A ``torch.Tensor`` is used as given; anything else
            is converted with ``torch.tensor(entry.T)`` (i.e. transposed).
        theta: Rotation angle, passed as both bounds of
            ``transforms.RandomRotation(degrees=...)`` so the rotation is
            always exactly ``theta``.
        show: When true, display the rotated image with ``plt.imshow``.
            Nothing is plotted otherwise.
        requires_theta: When true, one extra column filled with ``theta`` is
            concatenated to the result along dim 1.

    Returns:
        The ``ToTensor`` output of the rotated image summed over ``dim=2``,
        with the optional ``theta`` column appended.
    """
    # isinstance (rather than an exact type comparison) keeps Tensor
    # subclasses from being re-wrapped and transposed by mistake.
    if not isinstance(entry, torch.Tensor):
        entry = torch.tensor(entry.T)

    image = entry.clone()
    if entry.dim() == 2:
        # Add a leading dimension so ToPILImage receives a 3-D tensor.
        image = image.unsqueeze(dim=0)

    # Zero-pad the last two dimensions by one element on every side.
    image = F.pad(image, pad=(1, 1, 1, 1), value=0)

    # Identical lower/upper bounds turn the random rotation into a fixed one.
    rotate = transforms.RandomRotation(degrees=(theta, theta))
    rotated = rotate(transforms.ToPILImage()(image))

    # Plot only on request; plotting on every call is useless (and slow)
    # when this function runs inside DataLoader workers.
    if show:
        plt.imshow(rotated)

    projected = torch.sum(transforms.ToTensor()(rotated), dim=2)

    # If this is true, then we append the rotation angle applied to the output
    if requires_theta:
        theta_column = torch.full((image.size(0), 1), theta, dtype=torch.float)
        projected = torch.cat([projected, theta_column], dim=1)

    return projected
# class to create custom dataset. I plan to create 3 custom datasets and then concatenate them
class pad_rotate_project_data(data.Dataset):
    """Dataset wrapper that applies a ``theta``-parameterised augmentation.

    Each item of the wrapped dataset is expected to be indexable as
    ``(image, label, ...)``; the image is passed through ``transform2`` (if
    given) and returned together with the unchanged label.
    """

    def __init__(self, X_data, theta = 100, transform2=None):
        """Store the wrapped dataset and the augmentation settings.

        Args:
            X_data: Indexable, sized collection of ``(image, label)`` samples.
            theta: Value forwarded to ``transform2`` as the ``theta`` keyword.
            transform2: Optional callable invoked as
                ``transform2(image, theta=theta)``.
        """
        super().__init__()
        self.X_data = X_data
        self.theta = theta
        self.transform2 = transform2

    def __getitem__(self, index):
        """Return ``(image, label)`` for ``index``, augmenting the image."""
        # Fetch the sample once: indexing the wrapped dataset twice would run
        # its own (possibly random or expensive) transform pipeline twice.
        sample = self.X_data[index]
        img, label = sample[0], sample[1]
        # perform augmentation
        if self.transform2 is not None:
            img = self.transform2(img, theta=self.theta)
        return img, label

    def __len__(self):
        """Return the number of samples in the wrapped dataset."""
        return len(self.X_data)
# part where i apply transformation
# Use pad_rotate_project_data here: it is the dataset class whose __init__
# accepts ``transform2``. convolve_noise_data does not, which is what raised
# "TypeError: __init__() got an unexpected keyword argument 'transform2'".
ds2 = pad_rotate_project_data(mnist_train, transform2=pad_rotate_project)
train_loader2 = data.DataLoader(
    ds2,
    batch_size=batch_size,
    # sampler = RandomSampler(train_ds),
    num_workers=THREADS,
    pin_memory=USE_CUDA,
)
Below is the error I face: ---------------------------------------------------------------------------
TypeError Traceback (most recent call last)
<ipython-input-58-2e42ff52d82a> in <module>()
----> 1 ds2 = convolve_noise_data(mnist_train, transform2 = pad_rotate_project)
2
3 train_loader2 = data.DataLoader(ds2, batch_size,
4 #sampler = RandomSampler(train_ds),
5 num_workers = THREADS,
TypeError: __init__() got an unexpected keyword argument 'transform2'