I was trying to generate some images in code and vary them using randomness. I found that when I wrap my dataset in a DataLoader and then call enumerate or iter on it, memory usage spikes unexpectedly. On debugging, I found that the
__getitem__ function is called as soon as I create the enumerate or iter object.
Q: How can I ensure that it is only called when I call next on the iterator, or extract the next element from the enumerate object?
For reference, here is my toy example, which shows that the issue lies in when
__getitem__ is called (tested on Python 3.7 and 3.8):
# Testing enumeration here
from torch.utils.data import Dataset, DataLoader
import torch

class test_enum(Dataset):
    def __init__(self, root="/home/aknirala/data/clocks/", transform=None):
        self.transform = transform
        self.root = root

    def __len__(self):
        return 1000000

    def __getitem__(self, idx):
        print(str(idx) + ", ")
        return idx

dataset = test_enum()
dataloader = torch.utils.data.DataLoader(dataset, batch_size=4,
                                         shuffle=True, num_workers=2)
print("Len of iterations per epochs: ", len(dataloader))
Now, the moment I do:
data_iter = iter(dataloader)
I get something like this (there are batch_size * num_workers * 2 = 16 numbers printed below):
143666, 315805, 357128, 851433, 298989, 706179, 432269, 11821, 995092, 911574, 812109, 598460, 276834, 563535, 551306, 343008,
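My guess, which I have not verified against the DataLoader source, is that the trailing 2 in that count is a per-worker prefetch depth: each of the num_workers workers seems to load two batches up front when the iterator is built. In recent torch releases this depth is exposed as the prefetch_factor argument of DataLoader (whether it applies to my installed version is an assumption on my part):

# Sketch, assuming a torch version that supports prefetch_factor
# (available in newer releases; only valid when num_workers > 0).
# It sets how many batches each worker loads ahead of time; default is 2.
dataloader = torch.utils.data.DataLoader(dataset, batch_size=4, shuffle=True,
                                         num_workers=2, prefetch_factor=2)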
Alternatively, if I instead do:
e = enumerate(dataloader)
I get something like:
640834, 336805, 792848, 183726, 727800, 322372, 703700, 870395, 856172, 370866, 792898, 415834, 700454, 516118, 916330, 733028,
Later, when I call next(data_iter) or run a for loop on e, I see 4 numbers each time, which is as expected (since batch_size is 4). However, why does __getitem__ get called while I am creating the enumerate or iter object? And how can I stop it?
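For comparison, here is a minimal sketch of what I would expect based on my (possibly wrong) understanding that the eager calls come from the worker processes. With num_workers=0 the loader fetches in the main process, so __getitem__ should only run when next is called:

# Same dataset as above; single-process loading, so no prefetching workers.
lazy_loader = torch.utils.data.DataLoader(dataset, batch_size=4,
                                          shuffle=True, num_workers=0)
lazy_iter = iter(lazy_loader)   # expectation: nothing printed here
batch = next(lazy_iter)         # expectation: exactly 4 indices printed here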