Introducing augmentations at specific times during training?

What’s the PyTorch way to introduce specific augmentations at a specific time/epoch during training when using a DataLoader? Can we pass any parameter we want (e.g. the epoch) to the Dataset's __getitem__()?

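Yes, you could pass more arguments to __getitem__, but you would need to feed them from the sampler, since __getitem__ receives whatever the sampler yields as its index.
Alternatively, you could change some internal attributes of the Dataset before starting each new epoch.
Make these changes between epochs, not during one: while an epoch is running, the DataLoader (with num_workers > 0) gives each worker its own copy of the Dataset, so an attribute changed mid-epoch won’t be reflected in those copies.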
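A minimal sketch of the sampler route, assuming a map-style Dataset: the sampler yields (index, epoch) tuples and __getitem__ unpacks them. EpochSampler, EpochAwareDataset, aug_epoch and the "+ 100.0" stand-in augmentation are hypothetical names chosen for illustration, not a PyTorch API. Since indices are generated in the main process, updating sampler.epoch between epochs takes effect even when num_workers > 0.

```python
import torch
from torch.utils.data import Dataset, DataLoader, Sampler


class EpochSampler(Sampler):
    """Hypothetical sampler that yields (index, epoch) tuples in random order."""

    def __init__(self, data_len):
        self.data_len = data_len
        self.epoch = 0  # set from the training loop before each epoch

    def __iter__(self):
        for idx in torch.randperm(self.data_len).tolist():
            yield idx, self.epoch

    def __len__(self):
        return self.data_len


class EpochAwareDataset(Dataset):
    """Hypothetical dataset whose __getitem__ receives (index, epoch)."""

    def __init__(self, data, aug_epoch):
        self.data = data            # sequence of (x, y) pairs
        self.aug_epoch = aug_epoch  # first epoch that uses the augmentation

    def __len__(self):
        return len(self.data)

    def __getitem__(self, item):
        index, epoch = item  # whatever the sampler yielded
        x, y = self.data[index]
        if epoch >= self.aug_epoch:
            x = x + 100.0  # placeholder for a real augmentation
        return x, y


data = [(torch.tensor(float(i)), i) for i in range(4)]
sampler = EpochSampler(len(data))
loader = DataLoader(EpochAwareDataset(data, aug_epoch=1), batch_size=4, sampler=sampler)

results = []
for epoch in range(2):
    sampler.epoch = epoch  # update before iterating over the loader
    for x, y in loader:
        results.append(sorted(x.tolist()))

print(results)  # [[0.0, 1.0, 2.0, 3.0], [100.0, 101.0, 102.0, 103.0]]
```

The design choice here is that the Dataset stays stateless: the epoch travels with each index, so worker copies of the Dataset never need updating.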

Hi @ptrblck, if I understood you correctly, is something like the following MWE what you’re describing?

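from torch.utils.data import Dataset

class MyDataset(Dataset):
    def __init__(self, data, transform, aug_epoch, epoch=0):
        self.data = data            # sequence of (x, y) pairs
        self.transform = transform  # augmentation to introduce
        self.aug_epoch = aug_epoch  # epoch at which to apply it
        self.epoch = epoch          # updated from the training loop

    def __len__(self):
        return len(self.data)

    def __getitem__(self, index):
        x, y = self.data[index]
        if self.epoch == self.aug_epoch:
            x = self.transform(x)
        return x, y

# then during training
for epoch in range(epochs):
    dataset.epoch = epoch
    for x, y in loader:
        ...  # train as usual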

Yes, with a minor change: since you’ve wrapped the Dataset in a DataLoader, you should be using loader.dataset.epoch = epoch in the outer loop.
