How to apply augmentation to dataset

I’d like to augment the training and validation datasets that I currently have, but I’m not sure where in the main code to put this transform pipeline:

# Proposed augmentation pipeline; the steps run in the order listed.
transforms.Compose([
# Resize every image to 229x229.
transforms.Resize((229,229)),
# Take a random crop and resize it back to 229x229.
transforms.RandomResizedCrop((229,229)),
# Randomly mirror the image left-to-right.
transforms.RandomHorizontalFlip(),
# Convert the image to a torch tensor.
transforms.ToTensor(),
# NOTE(review): `unknown` is a placeholder, not a defined name; replace it
# with the per-channel mean/std values before running this code.
transforms.Normalize(mean=[unknown],
std=[unknown])
])

MAIN CODE:

class STLData(Dataset):
    """Dataset over one split of the arrays stored in an ``.npz`` archive.

    The archive must contain the arrays ``arr_0`` .. ``arr_5``, read as
    (images, labels) pairs for the training, validation and test splits in
    that order. Images are converted to ``float32`` once at construction
    time; their values are not rescaled.

    Args:
        trn_val_tst: Split selector. ``0`` selects the training pair
            (``arr_0``/``arr_1``), ``1`` the validation pair
            (``arr_2``/``arr_3``); any other value selects the test pair
            (``arr_4``/``arr_5``).
        transform: Optional callable applied to each image sample in
            ``__getitem__``. ``None`` returns the ``float32`` array as is.
        npz_path: Path of the archive to read. Defaults to ``'hw3.npz'``,
            resolved against the current working directory.
    """

    def __init__(self, trn_val_tst=0, transform=None, npz_path='hw3.npz'):
        # np.load returns an NpzFile that keeps the archive open until it is
        # closed; the ``with`` block releases the file handle as soon as the
        # two arrays of the requested split have been read into memory.
        with np.load(npz_path) as data:
            if trn_val_tst == 0:
                # training split
                images = data['arr_0']
                self.labels = data['arr_1']
            elif trn_val_tst == 1:
                # validation split
                images = data['arr_2']
                self.labels = data['arr_3']
            else:
                # test split (also the fallback for any other selector)
                images = data['arr_4']
                self.labels = data['arr_5']

        self.images = images.astype(np.float32)
        self.transform = transform

    def __len__(self):
        """Return the number of samples, taken from the label array."""
        return len(self.labels)

    def __getitem__(self, idx):
        """Return the ``(sample, label)`` pair stored at ``idx``.

        ``idx`` may be a torch tensor, in which case it is first converted
        to plain Python value(s) with ``tolist()``. The transform, if one
        was given, is applied to the image sample only.
        """
        if torch.is_tensor(idx):
            idx = idx.tolist()

        sample = self.images[idx, :]
        labels = self.labels[idx]
        # Compare against None explicitly: a transform is a callable, and a
        # truthiness test would skip one that defines __len__ and is empty.
        if self.transform is not None:
            sample = self.transform(sample)
        return sample, labels

# Build the three splits in order (0 = train, 1 = validation, 2 = test);
# each dataset gets its own ToTensor transform instance.
train_set, val_set, test_set = (
    STLData(trn_val_tst=split, transform=torchvision.transforms.ToTensor())
    for split in (0, 1, 2)
)

batch_size = 100
# One DataLoader worker process per CPU reported by the OS.
n_workers = multiprocessing.cpu_count()

# All three loaders share the same batch size, shuffling and worker count.
trainloader, valloader, testloader = (
    torch.utils.data.DataLoader(
        split_set,
        batch_size=batch_size,
        shuffle=True,
        num_workers=n_workers,
    )
    for split_set in (train_set, val_set, test_set)
)

You could pass the transforms.Compose as the transform argument to the STLData, if you would like to use it for this dataset.
Note that torchvision.transforms work on PIL.Images by default (in the nightly build, more transformations can also be applied to tensors directly), so you might need to convert the numpy arrays to PIL.Images first.

1 Like