What’s the PyTorch way to introduce specific augmentations at a specific time/epoch during training with a DataLoader? Can we pass any parameter we want (e.g. the epoch) to the Dataset’s __getitem__()?
Yes, you could pass more arguments to __getitem__, but you would need to feed them from the sampler.
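A minimal sketch of the sampler route, assuming a map-style Dataset and the default batch sampler. For each element the sampler yields, the DataLoader calls `dataset[element]`, so a custom sampler can yield `(index, epoch)` tuples that `__getitem__` then unpacks. `EpochSampler`, its `set_epoch` method, `EpochAwareDataset` and `aug_epoch` are hypothetical names chosen for illustration, and `x + 10` is only a placeholder for a real augmentation.

```python
import torch
from torch.utils.data import DataLoader, Dataset, Sampler


class EpochSampler(Sampler):
    """Hypothetical sampler that yields (index, epoch) pairs instead of plain indices."""

    def __init__(self, num_samples):
        self.num_samples = num_samples
        self.epoch = 0

    def set_epoch(self, epoch):
        # Called from the training loop before iterating over the loader.
        self.epoch = epoch

    def __iter__(self):
        for index in range(self.num_samples):
            yield index, self.epoch

    def __len__(self):
        return self.num_samples


class EpochAwareDataset(Dataset):
    def __init__(self, data, aug_epoch):
        self.data = data
        self.aug_epoch = aug_epoch  # augmentation is applied from this epoch onward

    def __len__(self):
        return len(self.data)

    def __getitem__(self, item):
        index, epoch = item  # both values are delivered by the sampler
        x = self.data[index]
        if epoch >= self.aug_epoch:
            x = x + 10  # placeholder for a real augmentation
        return x


data = torch.arange(4, dtype=torch.float32)
sampler = EpochSampler(len(data))
loader = DataLoader(EpochAwareDataset(data, aug_epoch=1), sampler=sampler, batch_size=2)

results = []
for epoch in range(2):
    sampler.set_epoch(epoch)
    # Each batch is a 1-D tensor of 2 samples; concatenate them to get the whole epoch.
    results.append(torch.cat(list(loader)).tolist())

print(results)  # [[0.0, 1.0, 2.0, 3.0], [10.0, 11.0, 12.0, 13.0]]
```

Because the sampler runs in the main process and only its outputs are sent to the workers, the epoch travels together with each index. This keeps working with `num_workers > 0`, even when `persistent_workers=True`.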
Alternatively, you could change some internal attributes of the Dataset before starting each new epoch. During an epoch, however, the DataLoader (with num_workers > 0) works on a copy of the Dataset in each worker process, so changing internal attributes in the middle of the epoch won’t be reflected in those copies.
Hi @ptrblck, if I understood you correctly, something like the following MWE is what you’re describing?
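from torch.utils.data import Dataset

class MyDataset(Dataset):
    def __init__(self, data, transform, aug_epoch, epoch=0):
        self.data = data            # sequence of (x, y) pairs
        self.transform = transform  # augmentation to introduce
        self.aug_epoch = aug_epoch  # epoch from which the augmentation is applied
        self.epoch = epoch          # current epoch, updated from the training loop

    def __len__(self):
        return len(self.data)

    def __getitem__(self, index):
        x, y = self.data[index]
        if self.epoch >= self.aug_epoch:
            x = self.transform(x)
        return x, y

# then during training
for epoch in range(epochs):
    dataset.epoch = epoch  # set before iterating, i.e. before the workers copy the dataset
    for x, y in loader:
        pass  # train as usual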
Yes, with a minor change: since you’ve wrapped the Dataset in a DataLoader, you should be using loader.dataset.epoch = epoch in the outer loop.
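A runnable sketch of the MWE with that change, assuming tensor data, a dummy `x + 1.0` transform and a hypothetical `aug_epoch` threshold. `loader.dataset` is the same object that was passed to the DataLoader, so the update lands on that Dataset before each epoch's iteration starts. The sketch uses `num_workers=0` for simplicity. With `num_workers > 0` and the default `persistent_workers=False`, the workers are re-created for every epoch and pick up the new value. With `persistent_workers=True` they keep their original copy and would not see it.

```python
import torch
from torch.utils.data import DataLoader, Dataset


class MyDataset(Dataset):
    def __init__(self, data, targets, transform, aug_epoch, epoch=0):
        self.data = data
        self.targets = targets
        self.transform = transform
        self.aug_epoch = aug_epoch  # augmentation is applied from this epoch onward
        self.epoch = epoch          # updated from the training loop

    def __len__(self):
        return len(self.data)

    def __getitem__(self, index):
        x, y = self.data[index], self.targets[index]
        if self.epoch >= self.aug_epoch:
            x = self.transform(x)
        return x, y


data = torch.zeros(6, 3)
targets = torch.arange(6)
dataset = MyDataset(data, targets, transform=lambda x: x + 1.0, aug_epoch=2)
loader = DataLoader(dataset, batch_size=3, shuffle=True)

epoch_totals = []
for epoch in range(4):
    loader.dataset.epoch = epoch  # same object as `dataset`
    total = 0.0
    for x, y in loader:
        total += x.sum().item()  # training step would go here
    epoch_totals.append(total)

print(epoch_totals)  # [0.0, 0.0, 18.0, 18.0]
```

The totals stay at zero until `aug_epoch` is reached, after which every sample is shifted by the dummy transform.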