The torch.utils.data.DataLoader can only provide batches for one epoch. How can I reset it before it finishes an epoch, so that it does not raise a StopIteration error unexpectedly?
If I understand you correctly, you want to loop infinitely over your dataloader until a breaking condition is met?
You could do something like this (assuming your DataLoader instance is stored in the variable loader):
loader_iter = iter(loader)
while True:
    try:
        current_batch = next(loader_iter)
    except StopIteration:
        loader_iter = iter(loader)
        current_batch = next(loader_iter)
    ...  # your following data processing steps
    if breaking_condition:
        break
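The pattern above can also be packaged into a small helper generator. Here is a minimal sketch; a plain list of batches stands in for the real DataLoader, since both support iter() and raise StopIteration when exhausted:

```python
def infinite_batches(loader):
    """Yield batches forever by re-creating the loader's iterator each epoch."""
    it = iter(loader)
    while True:
        try:
            yield next(it)
        except StopIteration:
            # The "epoch" ended: build a fresh iterator and continue.
            it = iter(loader)
            yield next(it)

# A plain list of batches stands in for a torch.utils.data.DataLoader here.
loader = [[0, 1], [2, 3]]
stream = infinite_batches(loader)
print([next(stream) for _ in range(5)])  # [[0, 1], [2, 3], [0, 1], [2, 3], [0, 1]]
```

Calling next(stream) then never raises StopIteration; the loader is silently restarted at each epoch boundary.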
I tried this method and it worked well, but I am afraid it might have some multi-processing problems?
File "train_net.py", line 16, in main
    train(args.cfg)
File "/data/zhangzy/learn-projects/Segmentatron/SegNet/train.py", line 90, in train
    im, label = next(trainiter)
File "/home/zhangzy/.local/lib/python2.7/site-packages/torch/utils/data/dataloader.py", line 280, in __next__
    idx, batch = self._get_batch()
File "/home/zhangzy/.local/lib/python2.7/site-packages/torch/utils/data/dataloader.py", line 259, in _get_batch
    return self.data_queue.get()
File "/usr/lib/python2.7/multiprocessing/queues.py", line 378, in get
    return recv()
File "/home/zhangzy/.local/lib/python2.7/site-packages/torch/multiprocessing/queue.py", line 22, in recv
    return pickle.loads(buf)
File "/usr/lib/python2.7/pickle.py", line 1388, in loads
    return Unpickler(file).load()
File "/usr/lib/python2.7/pickle.py", line 864, in load
    dispatch[key](self)
File "/usr/lib/python2.7/pickle.py", line 1139, in load_reduce
    value = func(*args)
File "/home/zhangzy/.local/lib/python2.7/site-packages/torch/multiprocessing/reductions.py", line 68, in rebuild_storage_fd
    fd = multiprocessing.reduction.rebuild_handle(df)
File "/usr/lib/python2.7/multiprocessing/reduction.py", line 157, in rebuild_handle
    new_handle = recv_handle(conn)
File "/usr/lib/python2.7/multiprocessing/reduction.py", line 83, in recv_handle
    return _multiprocessing.recvfd(conn.fileno())
OSError: [Errno 4] Interrupted system call
Exception NameError: "global name 'FileNotFoundError' is not defined" in <bound method _DataLoaderIter.__del__ of <torch.utils.data.dataloader._DataLoaderIter object at 0x7f15845d99d0>> ignored
If you want to create an ‘infinite generator’, simply place it in a while loop:

def cycle(iterable):
    while True:
        for x in iterable:
            yield x
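For completeness, a runnable sketch of this helper; any iterable works, so a list of batches stands in for the DataLoader here:

```python
def cycle(iterable):
    # Each pass of the for-loop calls iter(iterable) again, so a shuffling
    # DataLoader would be reshuffled at every "epoch".
    while True:
        for x in iterable:
            yield x

# A plain list stands in for the DataLoader.
batches = ["b0", "b1", "b2"]
stream = cycle(batches)
first_seven = [next(stream) for _ in range(7)]
print(first_seven)  # ['b0', 'b1', 'b2', 'b0', 'b1', 'b2', 'b0']
```

Note that, unlike itertools.cycle, this version re-iterates the source on every pass instead of caching elements, which is usually what you want with a shuffling DataLoader.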
However, there is a more elegant solution with a DataLoader. Simply enumerate over the dataloader when training for one epoch:

data_loader = torch.utils.data.DataLoader(...)

def train(epoch):
    for batch_idx, (data, target) in enumerate(data_loader):
        ...  # Your code for training one epoch.

# Now perform training for however many epochs you want.
for epoch in range(num_epochs):
    train(epoch)
Internally, every time you call enumerate(data_loader) it calls the object’s __iter__ method, which creates and returns a new _DataLoaderIter object; that iterator in turn calls iter() on your batch sampler.
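You can see this fresh-iterator behaviour with a toy stand-in class (ToyLoader is a hypothetical name; DataLoader does the same thing with _DataLoaderIter):

```python
class ToyLoader:
    """Stand-in for DataLoader: __iter__ returns a fresh iterator on each call."""
    def __init__(self, batches):
        self.batches = batches

    def __iter__(self):
        # DataLoader similarly constructs a brand-new _DataLoaderIter here.
        return iter(list(self.batches))

loader = ToyLoader([1, 2, 3])
it1 = iter(loader)
it2 = iter(loader)
print(it1 is it2)             # False: two independent iterators
print(list(it1), list(it2))   # [1, 2, 3] [1, 2, 3]: exhausting one leaves the other intact
```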
How could I make training stop at a certain iteration count with this approach?
crossposted: https://stackoverflow.com/questions/60311307/how-does-one-reset-the-dataloader-in-pytorch
if you want to stop after a certain number of iterations, you could simply use epoch * len(data_loader) + batch_idx as a global iteration count, and terminate the code once it reaches a certain threshold.
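A hedged sketch of that stopping logic; data_loader and max_iters are placeholder names, and a 4-element list stands in for a real DataLoader:

```python
data_loader = [f"batch{i}" for i in range(4)]  # stands in for a DataLoader with 4 batches
max_iters = 10                                  # hypothetical stop threshold

iterations_seen = []
stop = False
for epoch in range(100):
    for batch_idx, batch in enumerate(data_loader):
        # Global iteration count across epochs.
        iteration = epoch * len(data_loader) + batch_idx
        iterations_seen.append(iteration)
        if iteration + 1 >= max_iters:
            stop = True
            break
    if stop:
        break

print(len(iterations_seen))  # 10, i.e. iterations 0 through 9
```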
Maybe itertools.cycle can do this?

from itertools import cycle

infinite_dataloader = cycle(dataloader)
# Just call next(infinite_dataloader)
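A quick runnable illustration, again with a list of batches standing in for the DataLoader:

```python
from itertools import cycle

batches = [[0], [1], [2]]  # stands in for a DataLoader's batches
infinite_dataloader = cycle(batches)
out = [next(infinite_dataloader) for _ in range(7)]
print(out)  # [[0], [1], [2], [0], [1], [2], [0]]
```

One caveat: itertools.cycle caches every element from its first pass, so with a shuffling DataLoader the first epoch's batch order would repeat verbatim and all batches stay referenced in memory. The hand-written while/for generator posted earlier in the thread avoids both issues.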