How can I reset the DataLoader, or count data batches by iteration instead of by epoch?

torch.utils.data.DataLoader only provides batches for one epoch. How can I reset it before it finishes an epoch, so that it does not raise a StopIteration error unexpectedly?
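For example, a minimal sketch of the behaviour (using a toy dataset just for illustration):

import torch
from torch.utils.data import DataLoader, TensorDataset

dataset = TensorDataset(torch.arange(8).float())  # 8 samples, for illustration only
loader = DataLoader(dataset, batch_size=4)

loader_iter = iter(loader)
next(loader_iter)   # batch 1
next(loader_iter)   # batch 2: the epoch is now exhausted
next(loader_iter)   # raises StopIteration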

If I understand you correctly, you want to loop over your dataloader infinitely until a breaking condition is met?

You could do something like this (assuming your DataLoader instance is stored in the variable loader):

loader_iter = iter(loader)

while True:
    try:
        current_batch = next(loader_iter)
    except StopIteration:
        # The epoch is exhausted: create a fresh iterator and keep going.
        loader_iter = iter(loader)
        current_batch = next(loader_iter)

    ...  # your following data processing steps

    if breaking_condition:
        break
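Note that each iter(loader) call starts a fresh pass over the dataset, so if your DataLoader was created with shuffle=True, the restarted pass will be reshuffled.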

I tried this method and it worked well, but I am afraid it might cause some multi-processing problems?

  File "train_net.py", line 16, in main
    train(args.cfg)
  File "/data/zhangzy/learn-projects/Segmentatron/SegNet/train.py", line 90, in train
    im, label = next(trainiter)
  File "/home/zhangzy/.local/lib/python2.7/site-packages/torch/utils/data/dataloader.py", line
 280, in __next__
    idx, batch = self._get_batch()
  File "/home/zhangzy/.local/lib/python2.7/site-packages/torch/utils/data/dataloader.py", line
 259, in _get_batch
    return self.data_queue.get()
  File "/usr/lib/python2.7/multiprocessing/queues.py", line 378, in get
    return recv()
  File "/home/zhangzy/.local/lib/python2.7/site-packages/torch/multiprocessing/queue.py", line
 22, in recv
    return pickle.loads(buf)
  File "/usr/lib/python2.7/pickle.py", line 1388, in loads
    return Unpickler(file).load()
  File "/usr/lib/python2.7/pickle.py", line 864, in load
    dispatch[key](self)
  File "/usr/lib/python2.7/pickle.py", line 1139, in load_reduce
    value = func(*args)
  File "/home/zhangzy/.local/lib/python2.7/site-packages/torch/multiprocessing/reductions.py",
 line 68, in rebuild_storage_fd
    fd = multiprocessing.reduction.rebuild_handle(df)
  File "/usr/lib/python2.7/multiprocessing/reduction.py", line 157, in rebuild_handle
    new_handle = recv_handle(conn)
  File "/usr/lib/python2.7/multiprocessing/reduction.py", line 83, in recv_handle
    return _multiprocessing.recvfd(conn.fileno())
OSError: [Errno 4] Interrupted system call
Exception NameError: "global name 'FileNotFoundError' is not defined" in <bound method _DataLo
aderIter.__del__ of <torch.utils.data.dataloader._DataLoaderIter object at 0x7f15845d99d0>> ig
nored

If you want to create an ‘infinite generator’, simply place the loop over your iterable inside a while True loop:

def cycle(iterable):
    while True:
        for x in iterable:
            yield x

However, there is a more elegant solution with a DataLoader: simply enumerate over the dataloader when training for one epoch:

data_loader = torch.utils.data.DataLoader(...)

def train(epoch):
    for batch_idx, (data, target) in enumerate(data_loader):
        ...  # Your code for training one epoch.

# Now perform training for however many epochs you want.
for epoch in range(num_epochs):
    train(epoch)

Internally, every time you call enumerate(data_loader), it calls the DataLoader’s __iter__ method, which creates and returns a new _DataLoaderIter object; that object in turn calls iter() on your batch sampler.
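A quick sketch of that (with a toy dataset, just to show that every iter() call produces a fresh iterator object):

import torch
from torch.utils.data import DataLoader, TensorDataset

dataset = TensorDataset(torch.arange(8).float())
data_loader = DataLoader(dataset, batch_size=4, shuffle=True)

it1 = iter(data_loader)   # this is what enumerate(data_loader) does under the hood
it2 = iter(data_loader)   # a second, independent iterator
print(it1 is it2)         # False: every epoch gets a brand-new iterator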


How could I make training stop at a certain iteration number in this way?

crossposted: https://stackoverflow.com/questions/60311307/how-does-one-reset-the-dataloader-in-pytorch

If you want to stop after a certain number of iterations, you could simply use epoch * len(data_loader) + batch_idx as a global iteration count, and then terminate training once that count reaches a certain threshold.
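For instance, a rough sketch of that idea (num_epochs, max_iters, data_loader, and train_step are placeholders for your own setup):

max_iters = 10000          # arbitrary iteration budget
global_iter = 0

for epoch in range(num_epochs):
    for batch_idx, (data, target) in enumerate(data_loader):
        train_step(data, target)      # your per-batch training code
        global_iter += 1
        if global_iter >= max_iters:
            break
    if global_iter >= max_iters:
        break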


Maybe itertools.cycle can do this?

from itertools import cycle

infinite_dataloader = cycle(dataloader)

# Just call next(infinite_dataloader) to get the next batch.
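For instance (dataloader and max_iters are placeholders for your own DataLoader and iteration budget):

from itertools import cycle

infinite_dataloader = cycle(dataloader)

for step in range(max_iters):
    batch = next(infinite_dataloader)   # never raises StopIteration
    # ... your training step here

One caveat: itertools.cycle saves the elements from the first pass and then replays that saved copy, so with shuffle=True you will see the first epoch’s batch order repeated (and all of the first epoch’s batches kept in memory), unlike the hand-written cycle() generator above, which re-iterates the DataLoader on every pass.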