I was trying to generate some images in code and vary them using randomness. I found that when I wrap my dataset in a DataLoader and then call enumerate or iter on it, memory usage spikes unexpectedly. On debugging, I found that the
__getitem__ function is called as soon as I create the enumerate or iter object.
Q: How can I ensure that it is only called when I call next on the iterator, or extract the next element from the enumerate object?
For reference, here is my toy example, which shows that the issue lies in when
__getitem__ is called (tested on Python 3.7 and 3.8):
# Testing enumeration here
from torch.utils.data import Dataset, DataLoader
import torch

class test_enum(Dataset):
    def __init__(self, root="/home/aknirala/data/clocks/", transform=None):
        self.transform = transform
        self.root = root

    def __len__(self):
        return 1000000

    def __getitem__(self, idx):
        print(str(idx) + ", ")
        return idx

dataset = test_enum()
dataloader = torch.utils.data.DataLoader(dataset, batch_size=4,
                                         shuffle=True, num_workers=2)
print("Len of iterations per epochs: ", len(dataloader))
Now, the moment I do:
data_iter = iter(dataloader)
I get something like this (there are batch_size * num_workers * 2 = 16 numbers printed below):
143666, 315805, 357128, 851433, 298989, 706179, 432269, 11821, 995092, 911574, 812109, 598460, 276834, 563535, 551306, 343008,
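My guess, which I have not verified against the DataLoader source, is that the trailing 2 in that count is a per-worker prefetch depth: each of the num_workers workers seems to load two batches up front when the iterator is built. In recent torch releases this depth is exposed as the prefetch_factor argument of DataLoader (whether it applies to my installed version is an assumption on my part):

# Sketch, assuming a torch version that supports prefetch_factor
# (available in newer releases; only valid when num_workers > 0).
# It sets how many batches each worker loads ahead of time; default is 2.
dataloader = torch.utils.data.DataLoader(dataset, batch_size=4, shuffle=True,
                                         num_workers=2, prefetch_factor=2)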
Alternatively, if I instead do:
e = enumerate(dataloader)
I get something like:
640834, 336805, 792848, 183726, 727800, 322372, 703700, 870395, 856172, 370866, 792898, 415834, 700454, 516118, 916330, 733028,
Later, when I call next(data_iter) or run a for loop on e, I see 4 numbers each time, which is as expected (since batch_size is 4). However, why does __getitem__ get called while I am creating the enumerate or iter object? And how can I stop it?
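For comparison, here is a minimal sketch of what I would expect based on my (possibly wrong) understanding that the eager calls come from the worker processes. With num_workers=0 the loader fetches in the main process, so __getitem__ should only run when next is called:

# Same dataset as above; single-process loading, so no prefetching workers.
lazy_loader = torch.utils.data.DataLoader(dataset, batch_size=4,
                                          shuffle=True, num_workers=0)
lazy_iter = iter(lazy_loader)   # expectation: nothing printed here
batch = next(lazy_iter)         # expectation: exactly 4 indices printed here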