I am trying to train a convolutional network on images of variable size. For this purpose I use a `DataLoader` with a custom `collate_fn` function:
class ImagesFromList(data.Dataset):
    """Dataset of item pairs looked up in the module-level ``images`` table.

    Each record of the constructor argument is indexed as
    ``(key_a, key_b, val)``. The two keys select items from the *global*
    ``images`` container, and ``val`` is passed through unchanged.
    NOTE(review): the global presumably maps keys to image tensors of
    varying size — confirm against where ``images`` is built.
    """

    def __init__(self, images):
        # The argument is the list of (key_a, key_b, val) records. It shares
        # its name with the module-level lookup table only inside __init__.
        self.images_fn = images

    def __len__(self):
        return len(self.images_fn)

    def __getitem__(self, index):
        # Read the module-level lookup table (not an instance attribute).
        global images
        record = self.images_fn[index]
        pair = [images[record[0]], images[record[1]]]
        return pair, record[2]
# Training loader: one worker process, shuffled, batches built by my_collate.
# NOTE(review): the traceback for this loader fails inside
# rebuild_storage_fd / recvfds, i.e. while the main process receives a
# file descriptor for a tensor sent back by the worker. Returning many
# separate (unstacked) tensors per batch presumably multiplies the number of
# descriptors in flight. If the fd limit is the cause, the commonly suggested
# workarounds are torch.multiprocessing.set_sharing_strategy('file_system')
# or raising `ulimit -n` — unconfirmed here.
loader_train = torch.utils.data.DataLoader(
    ImagesFromList(images=trainset),
    batch_size=1,
    shuffle=True,
    num_workers=1,
    pin_memory=True,
    collate_fn=my_collate,
)
It works when I use the following `my_collate`:
def my_collate(batch):
    """Collate a batch the way ``default_collate`` does, stacking tensors.

    Tensors are stacked along a new leading dimension. This requires every
    tensor in the batch to have the same shape, which is why this version only
    works with ``batch_size=1`` for variable-size images. Integers become a
    ``torch.LongTensor``. Sequences are transposed and each field is collated
    with ``default_collate``. Any other element type is handed to
    ``default_collate`` rather than silently returning ``None``.

    Args:
        batch: list of samples as produced by the dataset's ``__getitem__``.

    Returns:
        The collated batch. An empty batch yields an empty list.
    """
    # ``collections.Sequence`` was a deprecated alias removed in Python 3.10.
    from collections.abc import Sequence

    if len(batch) == 0:
        return []
    elem = batch[0]
    if isinstance(elem, torch.Tensor):
        # torch.stack requires all tensors to share one shape.
        return torch.stack(batch, 0, out=None)
    # On Python 3, torch._six.int_classes is just ``int``; that private module
    # was removed in later torch releases.
    if isinstance(elem, int):
        return torch.LongTensor(batch)
    if isinstance(elem, (str, bytes)):
        # Strings are Sequences too; keep them as-is instead of transposing.
        return list(batch)
    if isinstance(elem, Sequence):
        # [(a1, b1), (a2, b2), ...] -> [collate([a1, a2, ...]), collate([b1, b2, ...])]
        return [default_collate(samples) for samples in zip(*batch)]
    return default_collate(batch)
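which is just a slightly modified version of `default_collate`. But it fails when `batch_size > 1` because of `torch.stack`, which cannot stack tensors of different sizes. Because of this I use the following `my_collate`, which returns a batch of tensors as a plain list instead of stacking it: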
def my_collate(batch):
    """Collate a batch whose tensors may differ in size.

    Tensors are never stacked. A batch of tensors is returned as a plain list,
    so images of different sizes can share a batch. Integers become a
    ``torch.LongTensor``. Sequences (e.g. the ``(files, val)`` samples and the
    ``[file1, file2]`` lists inside them) are transposed, and each field is
    collated by recursing into ``my_collate``.

    Recursing into ``default_collate`` here would ``torch.stack`` the nested
    image tensors again and fail for ``batch_size > 1``.

    Any other element type is handed to ``default_collate`` instead of
    silently returning ``None``.

    Args:
        batch: list (or tuple) of samples, or of one field of the samples.

    Returns:
        A list of tensors, a ``torch.LongTensor``, or a nested list of those.
        An empty batch yields an empty list.
    """
    # ``collections.Sequence`` was a deprecated alias removed in Python 3.10.
    from collections.abc import Sequence

    if len(batch) == 0:
        return []
    elem = batch[0]
    if isinstance(elem, torch.Tensor):
        # Variable-size tensors cannot be stacked; keep them as a list.
        return list(batch)
    # On Python 3, torch._six.int_classes is just ``int``; that private module
    # was removed in later torch releases.
    if isinstance(elem, int):
        return torch.LongTensor(batch)
    if isinstance(elem, (str, bytes)):
        # Strings are Sequences too; transposing them would recurse forever.
        return list(batch)
    if isinstance(elem, Sequence):
        # [(a1, b1), (a2, b2), ...] -> [my_collate((a1, a2, ...)), my_collate((b1, b2, ...))]
        return [my_collate(samples) for samples in zip(*batch)]
    return default_collate(batch)
In this case, even this simple loop
# Minimal reproduction: only iterate the loader; the batches themselves are unused.
for i, inp in enumerate(loader_train, start=0):
    continue
fails with the following error:
Exception in thread Thread-14:
Traceback (most recent call last):
File "/usr/lib/python3.5/threading.py", line 914, in _bootstrap_inner
self.run()
File "/usr/lib/python3.5/threading.py", line 862, in run
self._target(*self._args, **self._kwargs)
File "/home/stanismorozov/.local/lib/python3.5/site-packages/torch/utils/data/dataloader.py", line 71, in _worker_manager_loop
r = in_queue.get()
File "/usr/lib/python3.5/multiprocessing/queues.py", line 345, in get
return ForkingPickler.loads(res)
File "/home/stanismorozov/.local/lib/python3.5/site-packages/torch/multiprocessing/reductions.py", line 70, in rebuild_storage_fd
fd = df.detach()
File "/usr/lib/python3.5/multiprocessing/resource_sharer.py", line 58, in detach
return reduction.recv_handle(conn)
File "/usr/lib/python3.5/multiprocessing/reduction.py", line 181, in recv_handle
return recvfds(s, 1)[0]
File "/usr/lib/python3.5/multiprocessing/reduction.py", line 154, in recvfds
raise EOFError
EOFError
---------------------------------------------------------------------------
RuntimeError Traceback (most recent call last)
<ipython-input-55-3d1f5f03c1e9> in <module>()
----> 1 for i, inp in enumerate(loader_train):
2 pass
~/.local/lib/python3.5/site-packages/torch/utils/data/dataloader.py in __next__(self)
278 while True:
279 assert (not self.shutdown and self.batches_outstanding > 0)
--> 280 idx, batch = self._get_batch()
281 self.batches_outstanding -= 1
282 if idx != self.rcvd_idx:
~/.local/lib/python3.5/site-packages/torch/utils/data/dataloader.py in _get_batch(self)
257 raise RuntimeError('DataLoader timed out after {} seconds'.format(self.timeout))
258 else:
--> 259 return self.data_queue.get()
260
261 def __next__(self):
/usr/lib/python3.5/queue.py in get(self, block, timeout)
162 elif timeout is None:
163 while not self._qsize():
--> 164 self.not_empty.wait()
165 elif timeout < 0:
166 raise ValueError("'timeout' must be a non-negative number")
/usr/lib/python3.5/threading.py in wait(self, timeout)
291 try: # restore state no matter what (e.g., KeyboardInterrupt)
292 if timeout is None:
--> 293 waiter.acquire()
294 gotit = True
295 else:
~/.local/lib/python3.5/site-packages/torch/utils/data/dataloader.py in handler(signum, frame)
176 # This following call uses `waitid` with WNOHANG from C side. Therefore,
177 # Python can still get and update the process status successfully.
--> 178 _error_if_any_worker_fails()
179 if previous_handler is not None:
180 previous_handler(signum, frame)
RuntimeError: DataLoader worker (pid 25564) is killed by signal: Aborted.
I don’t understand what is going on, and I’m stuck on what to try next. Thank you in advance for any help or explanation.