I have a dataset of images and their corresponding labels, but the relation is not one-to-one: one image can have multiple correct labels. During training, I want to randomly sample a few labels per image (ideally different ones in different epochs, to reduce overfitting).
This is from a Python and PyTorch noob, so take it with a grain of salt, but I think it would work. It solves the problem without exposing the epoch number to the dataset, although I'll explain how that could be done afterwards.
This assumes that your dataset inherits from PyTorch's Dataset class:
from torch.utils.data import Dataset
In the dataset's __init__() method, create an itertools.cycle object for each image that cycles through that image's list of labels. Store these objects in a dictionary keyed by the image's id.
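Every time you call next() on the cycle object for a given image, it returns the next label, wrapping around to the first label after the last one. If you want the order of the labels to be more random, call random.shuffle() on the list before creating the cycle object. Shuffling the list afterwards won't reliably re-shuffle anything, because itertools.cycle keeps its own saved copy of the elements after its first pass; to re-shuffle, shuffle the list and create a new cycle object.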
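A minimal sketch of the idea above, assuming the labels are already available as a dict mapping a hypothetical image id to a list of labels (`labels_by_id`) and leaving out the actual image loading. In practice the class would subclass `torch.utils.data.Dataset`; it is a plain class here only so the sketch runs without PyTorch installed. The class name and the `seed` parameter are illustrative assumptions.

```python
import itertools
import random


class CyclingLabelDataset:
    """Map-style dataset that hands out a different label for an image on each access.

    In real code this would subclass torch.utils.data.Dataset; image loading is omitted.
    """

    def __init__(self, labels_by_id, shuffle=True, seed=None):
        rng = random.Random(seed)
        self.image_ids = list(labels_by_id)  # dataset index -> image id
        self.cycles = {}  # image id -> itertools.cycle over that image's labels
        for image_id, labels in labels_by_id.items():
            labels = list(labels)  # copy so the caller's list is not mutated
            if shuffle:
                rng.shuffle(labels)  # randomize the order once, before building the cycle
            self.cycles[image_id] = itertools.cycle(labels)

    def __len__(self):
        return len(self.image_ids)

    def __getitem__(self, index):
        image_id = self.image_ids[index]
        label = next(self.cycles[image_id])  # next label, wrapping around at the end
        return image_id, label  # a real dataset would return the loaded image here


labels_by_id = {"img0": ["cat", "pet", "animal"], "img1": ["car"]}
ds = CyclingLabelDataset(labels_by_id, seed=0)
print([ds[0][1] for _ in range(6)])  # each of img0's 3 labels appears exactly twice
```

One thing worth checking if you use this with a DataLoader and num_workers > 0: __getitem__ then runs in worker processes that each hold their own copy of the dataset, so the cycles in the main process never advance. With the default (non-persistent) workers, which are recreated from the main-process dataset every epoch, each epoch may then start from the same cycle state.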
Using a cycle rather than picking a label from the list at random guarantees that all of an image's labels are seen in as few epochs as possible: assuming each image is accessed once per epoch, an image with k labels has every label seen within k epochs. With purely random sampling, you might have to wait many epochs for some label to be chosen. A cycle also ensures that the same label isn't re-sampled until every label has been used once.
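To make the comparison above concrete, here is a small simulation that treats each draw as one epoch (one access per image per epoch) and counts how many draws it takes until every label has been seen. The helper name `draws_until_all_seen` and the five-label list are made up for illustration. For random sampling this is the coupon collector problem, whose expected value for 5 labels is 5 × (1 + 1/2 + 1/3 + 1/4 + 1/5) ≈ 11.4 draws.

```python
import itertools
import random


def draws_until_all_seen(draw, labels):
    """Count how many calls to draw() it takes until every label has been returned once."""
    seen = set()
    draws = 0
    while len(seen) < len(labels):
        seen.add(draw())
        draws += 1
    return draws


labels = ["a", "b", "c", "d", "e"]
rng = random.Random(0)

# Cycle over a shuffled copy: every label is seen within len(labels) draws.
shuffled = labels[:]
rng.shuffle(shuffled)
cycler = itertools.cycle(shuffled)
cycle_draws = draws_until_all_seen(lambda: next(cycler), labels)

# Independent random picks: repeats are possible, so it usually takes longer.
random_draws = [draws_until_all_seen(lambda: rng.choice(labels), labels) for _ in range(1000)]

print(cycle_draws)  # 5
print(sum(random_draws) / len(random_draws))  # average over 1000 simulated runs
```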
If you do want to pass the epoch number into the dataset, you can add a custom method to the dataset class that stores it in an attribute on the dataset instance.
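A minimal sketch of such a method, assuming a hypothetical name `set_epoch` (it mirrors the method of the same name on `torch.utils.data.distributed.DistributedSampler`) and a toy dataset that uses the epoch to step through an image's labels. The label-selection rule in `__getitem__` is only an illustration of "using the epoch for something"; it is not the only option. As before, real code would subclass `torch.utils.data.Dataset`.

```python
class EpochAwareDataset:
    """Toy map-style dataset that picks a label based on the current epoch.

    In real code this would subclass torch.utils.data.Dataset; image loading is omitted.
    """

    def __init__(self, labels_by_id):
        self.image_ids = list(labels_by_id)
        self.labels_by_id = {k: list(v) for k, v in labels_by_id.items()}
        self.epoch = 0  # instance attribute, updated from the training loop

    def set_epoch(self, epoch):
        # Called by the training loop at the start of every epoch.
        self.epoch = epoch

    def __len__(self):
        return len(self.image_ids)

    def __getitem__(self, index):
        image_id = self.image_ids[index]
        labels = self.labels_by_id[image_id]
        # Illustrative use of the epoch: move to the next label each epoch.
        return image_id, labels[self.epoch % len(labels)]


ds = EpochAwareDataset({"img0": ["cat", "pet", "animal"]})
for epoch in range(4):
    ds.set_epoch(epoch)  # call before iterating over this epoch's data
    print(epoch, ds[0])
# 0 ('img0', 'cat')
# 1 ('img0', 'pet')
# 2 ('img0', 'animal')
# 3 ('img0', 'cat')
```

If the dataset is wrapped in a DataLoader with num_workers > 0, the call needs to happen before iteration over the loader starts: the workers get their copy of the dataset when iteration begins, so changes made mid-epoch are not seen, and with persistent_workers=True changes made between epochs are not seen either.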
Call that method from the training loop at the start of each epoch; the dataset then knows which epoch training is on, and you can use that information however you like. A similar approach can be used to change anything else in the dataset as training progresses, such as re-shuffling the lists used by the cycle objects.
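One way to combine the epoch number with re-shuffling, sketched here as a swapped-in stateless variant rather than mutating existing cycle objects: the label is computed from the epoch alone, and each image's label order is re-shuffled once per full pass over its labels, so every label is still seen within k epochs for an image with k labels. The class name, the `seed` parameter, the `set_epoch` name and the string-based per-image seed are all assumptions for illustration. Since `__getitem__` changes no state, it should not matter which process calls it, as long as `set_epoch` runs before each epoch's iteration starts.

```python
import random


class EpochShuffledLabelDataset:
    """Returns one label per image per epoch, re-shuffling each image's order once per pass.

    In real code this would subclass torch.utils.data.Dataset; image loading is omitted.
    """

    def __init__(self, labels_by_id, seed=0):
        self.image_ids = list(labels_by_id)
        self.labels_by_id = {k: list(v) for k, v in labels_by_id.items()}
        self.seed = seed
        self.epoch = 0

    def set_epoch(self, epoch):
        self.epoch = epoch

    def __len__(self):
        return len(self.image_ids)

    def __getitem__(self, index):
        image_id = self.image_ids[index]
        labels = self.labels_by_id[image_id]
        pass_number, position = divmod(self.epoch, len(labels))
        # Same shuffled order for every epoch within one pass, a new order for the next pass.
        rng = random.Random(f"{self.seed}-{index}-{pass_number}")
        order = labels[:]
        rng.shuffle(order)
        return image_id, order[position]


ds = EpochShuffledLabelDataset({"img0": ["cat", "pet", "animal"], "img1": ["car", "road"]}, seed=1)
for epoch in range(6):
    ds.set_epoch(epoch)  # start of the epoch, before iterating any DataLoader
    print(epoch, [ds[i][1] for i in range(len(ds))])
```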
Thank you so much for your detailed response! The first method seems interesting, and I will try it out when I have time.
I ended up recreating the dataset object every epoch. I had also tried the second method before (accessing and modifying the dataset through the DataLoader), but it didn't work.
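A minimal sketch of the recreate-every-epoch approach, assuming a hypothetical dataset whose `__init__` draws up to `k` labels per image with `random.sample`, so that each new instance gets a fresh draw. The class name and `k` are illustrative; loading images (and the HDF5 caching mentioned in this thread) is left out, and in a real training loop you would also rebuild the DataLoader around each new dataset object.

```python
import random


class SampledLabelDataset:
    """Draws up to k distinct labels per image once, when the object is created.

    In real code this would subclass torch.utils.data.Dataset; image loading is omitted.
    """

    def __init__(self, labels_by_id, k, rng=None):
        rng = rng if rng is not None else random.Random()
        self.image_ids = list(labels_by_id)
        # Sample without replacement; images with fewer than k labels keep all of them.
        self.sampled = {
            image_id: rng.sample(list(labels), min(k, len(labels)))
            for image_id, labels in labels_by_id.items()
        }

    def __len__(self):
        return len(self.image_ids)

    def __getitem__(self, index):
        image_id = self.image_ids[index]
        return image_id, self.sampled[image_id]


labels_by_id = {"img0": ["cat", "pet", "animal", "mammal"], "img1": ["car"]}
rng = random.Random(0)
for epoch in range(3):
    ds = SampledLabelDataset(labels_by_id, k=2, rng=rng)  # fresh random draw each epoch
    print(epoch, ds[0], ds[1])
```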
I cached the HDF5 file being used, so speed doesn't seem to be affected.