Transform and Image Data Augmentation

The documentation for torchvision.transforms doesn't seem clear enough to me. I have several questions.

  1. Does Compose apply each transform to every image sequentially? If the order matters, what if I don't want to apply the transforms in a composite way? (e.g. what if I want to apply either flipping followed by normalization, or cropping followed by normalization, to every image?)
  2. How do I know how many extra images are gained through augmentation? Does the number of transforms in Compose determine the amount of augmentation? I'm asking because I want heavy augmentation, as I only have 10 samples per class.

  1. Yes, the order of the transformations stays the same unless you use transforms.RandomOrder or manipulate the list in another way. You could use transforms.RandomApply or transforms.RandomChoice if that fits your use case. Otherwise, just add a condition to switch between both approaches (e.g. in your Dataset's __getitem__).
  2. Each image is transformed randomly on the fly, so no new images are generated and the length of your Dataset stays the same.

For point 2, what if I want more augmented images?

Each iteration transforms your images anew, so you could either train for more epochs or artificially increase the length of your dataset by returning the desired length from its __len__ method and mapping the index back to a real sample in its __getitem__, e.g. with a modulo operation.
However, both approaches yield the same result, so I would just increase the number of epochs.
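To make the "same result" claim concrete, here is a toy sketch (the CountingDataset class and its noise "augmentation" are hypothetical, made up for illustration) that counts how many random transforms are drawn. Five epochs over 10 samples and one epoch over a 5x "virtual" dataset both produce 50 freshly augmented samples.

```python
import torch
from torch.utils.data import Dataset, DataLoader


class CountingDataset(Dataset):
    """Hypothetical toy dataset: adds random noise as a stand-in augmentation
    and counts how many augmented samples it has produced."""

    def __init__(self, n, repeat=1):
        self.data = torch.arange(n).float().view(n, 1)
        self.repeat = repeat
        self.calls = 0

    def __len__(self):
        return self.repeat * len(self.data)

    def __getitem__(self, idx):
        idx = idx % len(self.data)  # map the virtual index back to a real sample
        self.calls += 1             # one random transform drawn per access
        return self.data[idx] + 0.1 * torch.randn(1)


def run(dataset, epochs):
    # num_workers defaults to 0, so __getitem__ runs in this process
    loader = DataLoader(dataset, batch_size=2, shuffle=True)
    for _ in range(epochs):
        for _ in loader:
            pass
    return dataset.calls


calls_epochs = run(CountingDataset(10), epochs=5)
calls_repeat = run(CountingDataset(10, repeat=5), epochs=1)
print(calls_epochs, calls_repeat)  # 50 50
```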

Hello, I would like to increase the amount of data using torchvision.transforms.Compose, but instead of increasing the number of epochs I would like to go with the second option proposed by @ptrblck in the latest reply.

I have changed the length of my dataset in __len__ by multiplying it by 5 (so now I have five times as many training samples), but I do not see how to change the index in __getitem__.

I am not sure how to modify the index. An example of this modification would be highly appreciated.

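import numpy as np
import torch
from torch.utils.data import Dataset

# load_image, load_mask and img_to_tensor are helper functions defined elsewhere in my project
class PatientDataset(Dataset):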
    def __init__(self, file_names, to_augment=False, transform=None, mode='train', problem_type=None):
        self.file_names = file_names
        self.to_augment = to_augment
        self.transform = transform
        self.mode = mode
        self.problem_type = problem_type

    def __len__(self):
        return 5*len(self.file_names)

    def __getitem__(self, idx):
        img_file_name = self.file_names[idx]
        image = load_image(img_file_name)
        mask = load_mask(img_file_name, self.problem_type)

        data = {"image": image, "mask": mask}
        
        augmented = self.transform(**data)
        image, mask = augmented["image"], augmented["mask"]

        if self.mode == 'train':
            if self.problem_type == 'binary':
                return img_to_tensor(image), torch.from_numpy(np.expand_dims(mask, 0)).float()
            else:
                return img_to_tensor(image), torch.from_numpy(mask).long()
        else:
            return img_to_tensor(image), str(img_file_name)

The modulo approach I mentioned should work:

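import torch
from torch.utils.data import Dataset, DataLoader

class PatientDataset(Dataset):
    def __init__(self):
        # toy data: 10 samples with a single feature each
        self.data = torch.arange(10).float().view(10, 1)

    def __len__(self):
        # report 5x the number of stored samples
        return 5*len(self.data)

    def __getitem__(self, idx):
        # map the "virtual" index in [0, 50) back to a real sample in [0, 10)
        idx = idx % len(self.data)
        x = self.data[idx]
        return x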

dataset = PatientDataset()
loader = DataLoader(dataset, batch_size=2)

for x in loader:
    print(x)
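Applied to the PatientDataset above, the change amounts to indexing with idx % len(self.file_names) before looking up the file name. An alternative is to leave the original dataset untouched and wrap it. The sketch below assumes that design; RepeatDataset and ToyAugmentedDataset are hypothetical names, and ToyAugmentedDataset only stands in for a dataset whose __getitem__ applies a random transform.

```python
import torch
from torch.utils.data import Dataset, DataLoader


class RepeatDataset(Dataset):
    """Hypothetical wrapper that exposes `factor` copies of `base`.
    Every copy goes through base.__getitem__ again, so any random
    transform inside the base dataset is re-sampled for each copy."""

    def __init__(self, base, factor):
        self.base = base
        self.factor = factor

    def __len__(self):
        return self.factor * len(self.base)

    def __getitem__(self, idx):
        if idx < 0 or idx >= len(self):
            raise IndexError(idx)
        # modulo maps each virtual index back to a real sample
        return self.base[idx % len(self.base)]


class ToyAugmentedDataset(Dataset):
    """Stand-in for PatientDataset: returns (noisy sample, original index)."""

    def __init__(self, n):
        self.data = torch.arange(n).float().view(n, 1)

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        return self.data[idx] + 0.01 * torch.randn(1), idx


base = ToyAugmentedDataset(10)
dataset = RepeatDataset(base, factor=5)
loader = DataLoader(dataset, batch_size=4, shuffle=True)
print(len(dataset))  # 50
```

With shuffle=True the five copies of each sample are spread across the epoch rather than appearing back to back. torch.utils.data.ConcatDataset([base] * 5) is a built-in alternative that gives the same length and the same re-sampling behavior.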

It worked. Thank you very much.

From what I understand, a new transformed image is generated in every epoch. Is it possible to reuse the transformed images generated in the first epoch, i.e. apply the PyTorch transformations just once?

You could iterate the Dataset containing the random transformations once and store all augmented tensors. Once this is done, you could create a TensorDataset that only returns the already transformed tensors.
Let me know if that would work for you.
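A minimal sketch of that idea, assuming the Dataset returns a pair of tensors (e.g. image and target or mask) in training mode. RandomAugDataset and freeze_augmentations are hypothetical names, and random noise stands in for the real transforms. All augmented tensors are kept in memory, which is fine for a small dataset but may not scale.

```python
import torch
from torch.utils.data import Dataset, DataLoader, TensorDataset


class RandomAugDataset(Dataset):
    """Hypothetical stand-in for a Dataset with random transforms."""

    def __init__(self, n=10):
        self.data = torch.arange(n).float().view(n, 1)
        self.targets = torch.arange(n)

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        # random noise plays the role of a random augmentation
        return self.data[idx] + 0.1 * torch.randn(1), self.targets[idx]


def freeze_augmentations(dataset, batch_size=4):
    """Run the random transforms once and cache the results in a TensorDataset."""
    loader = DataLoader(dataset, batch_size=batch_size, shuffle=False)
    xs, ys = [], []
    with torch.no_grad():
        for x, y in loader:
            xs.append(x)
            ys.append(y)
    # concatenate the batches back into one tensor per field
    return TensorDataset(torch.cat(xs), torch.cat(ys))


frozen = freeze_augmentations(RandomAugDataset(10))
# Every epoch now sees the same augmented tensors.
train_loader = DataLoader(frozen, batch_size=2, shuffle=True)
```

To store several augmented versions per sample, pass a dataset whose __len__ is multiplied and whose __getitem__ uses the modulo index, as discussed earlier in the thread.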


I think this suggestion by @ptrblck would work for me. I'm relatively new to PyTorch and not exactly sure how to iterate the dataset and store the augmented tensors. @Sameer_Verma, would you be willing to post the code you used to do this?
