I am trying to run training with 3D patches of shape [1, 112, 112, 192], and I have been trying several combinations of batch_size and num_workers. From my experience with this dataset and a very similar network, I reckon batch_size=2 should work memory-wise on the GPU I am using (an NVIDIA P100 with 16 GB). CPU-wise, I am on a 2-core Intel Xeon @ 2.30 GHz with 27 GB of RAM. I do not fully understand how num_workers works, but if I set it too high I get an error that seems to come up a lot in online forums, something like:
Epoch: 0
Traceback (most recent call last):
  File "/usr/local/lib/python3.6/dist-packages/torch/utils/data/dataloader.py", line 761, in _try_get_data
    data = self._data_queue.get(timeout=timeout)
  File "/usr/lib/python3.6/queue.py", line 173, in get
    self.not_empty.wait(remaining)
  File "/usr/lib/python3.6/threading.py", line 299, in wait
    gotit = waiter.acquire(True, timeout)
  File "/usr/local/lib/python3.6/dist-packages/torch/utils/data/_utils/signal_handling.py", line 66, in handler
    _error_if_any_worker_fails()
RuntimeError: DataLoader worker (pid 1476) is killed by signal: Killed.

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "sv_unet/main.py", line 246, in <module>
    training(unet, train_loader, valid_loader, epochs=epochs, batch_size=BATCH_SIZE, device=device, fold_dir=fold_dir, restore=RESTORE)
  File "/content/drive/My Drive/TFM/monai_unet/sv_unet/training.py", line 119, in training
    for batch in train_loader:
  File "/usr/local/lib/python3.6/dist-packages/torch/utils/data/dataloader.py", line 345, in __next__
    data = self._next_data()
  File "/usr/local/lib/python3.6/dist-packages/torch/utils/data/dataloader.py", line 841, in _next_data
    idx, data = self._get_data()
  File "/usr/local/lib/python3.6/dist-packages/torch/utils/data/dataloader.py", line 798, in _get_data
    success, data = self._try_get_data()
  File "/usr/local/lib/python3.6/dist-packages/torch/utils/data/dataloader.py", line 774, in _try_get_data
    raise RuntimeError('DataLoader worker (pid(s) {}) exited unexpectedly'.format(pids_str))
RuntimeError: DataLoader worker (pid(s) 1472, 1474, 1476, 1478) exited unexpectedly
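From what I have read, "killed by signal: Killed" usually means the OS out-of-memory killer terminated the worker, and each extra worker holds its own prefetched batches in RAM. A small check I put together to watch peak resident memory while iterating a loader (the dummy dataset here is just illustrative, not my real code):

```python
import resource

import torch
from torch.utils.data import DataLoader, TensorDataset

# Dummy stand-in for my patch dataset, only to illustrate the memory check.
data = torch.zeros(8, 1, 16, 16, 16)
loader = DataLoader(TensorDataset(data), batch_size=2, num_workers=2)

for i, (batch,) in enumerate(loader):
    # ru_maxrss is reported in kilobytes on Linux.
    peak_mb = resource.getrusage(resource.RUSAGE_SELF).ru_maxrss / 1024
    print(f"batch {i}: shape {tuple(batch.shape)}, peak RSS ~{peak_mb:.0f} MB")
```

If peak RSS climbs towards the 27 GB limit as I raise num_workers, that would explain the workers getting killed.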
I read that this should be fixed by lowering num_workers enough, which I have done. But then I get this kind of error:
Epoch: 0
Batch 0
Batch 1
Batch 2
Batch 3
Batch 4
Batch 5
Traceback (most recent call last):
  File "sv_unet/main.py", line 246, in <module>
    training(unet, train_loader, valid_loader, epochs=epochs, batch_size=BATCH_SIZE, device=device, fold_dir=fold_dir, restore=RESTORE)
  File "/content/drive/My Drive/TFM/monai_unet/sv_unet/training.py", line 119, in training
    for batch in train_loader:
  File "/usr/local/lib/python3.6/dist-packages/torch/utils/data/dataloader.py", line 345, in __next__
    data = self._next_data()
  File "/usr/local/lib/python3.6/dist-packages/torch/utils/data/dataloader.py", line 856, in _next_data
    return self._process_data(data)
  File "/usr/local/lib/python3.6/dist-packages/torch/utils/data/dataloader.py", line 881, in _process_data
    data.reraise()
  File "/usr/local/lib/python3.6/dist-packages/torch/_utils.py", line 395, in reraise
    raise self.exc_type(msg)
RuntimeError: Caught RuntimeError in DataLoader worker process 0.

Original Traceback (most recent call last):
  File "/usr/local/lib/python3.6/dist-packages/torch/utils/data/_utils/worker.py", line 178, in _worker_loop
    data = fetcher.fetch(index)
  File "/usr/local/lib/python3.6/dist-packages/torch/utils/data/_utils/fetch.py", line 47, in fetch
    return self.collate_fn(data)
  File "/usr/local/lib/python3.6/dist-packages/torch/utils/data/_utils/collate.py", line 79, in default_collate
    return [default_collate(samples) for samples in transposed]
  File "/usr/local/lib/python3.6/dist-packages/torch/utils/data/_utils/collate.py", line 79, in <listcomp>
    return [default_collate(samples) for samples in transposed]
  File "/usr/local/lib/python3.6/dist-packages/torch/utils/data/_utils/collate.py", line 55, in default_collate
    return torch.stack(batch, 0, out=out)
  File "/usr/local/lib/python3.6/dist-packages/apex/amp/wrap.py", line 89, in wrapper
    return orig_fn(seq, *args, **kwargs)
RuntimeError: Expected object of scalar type double but got scalar type float for sequence element 1.
I do not understand this last error and have not found a fix for it. I have tried every num_workers from 0 to 4. Notice how in this last case (I believe I used num_workers=2 here) it processes a few batches and then crashes. Any ideas on what is going on and how to get this working with batch_size=2?
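In case it helps, my current guess is that the last error means one sample in the batch is float64 (double) while another is float32, so the torch.stack inside default_collate refuses to combine them. A minimal sketch of the workaround I am considering, casting everything to float32 inside __getitem__ (the class and names here are made up, not my actual code):

```python
import numpy as np
import torch
from torch.utils.data import DataLoader, Dataset


class PatchDataset(Dataset):
    """Toy stand-in for my real dataset; names here are made up."""

    def __init__(self, n=4):
        # np.random.rand returns float64, which is how stray "double"
        # samples could be sneaking into my batches.
        self.volumes = [np.random.rand(1, 4, 4, 4) for _ in range(n)]

    def __len__(self):
        return len(self.volumes)

    def __getitem__(self, idx):
        # Cast every sample to float32 so default_collate's torch.stack
        # never sees a float/double mix.
        return torch.from_numpy(self.volumes[idx].astype(np.float32))


loader = DataLoader(PatchDataset(), batch_size=2, num_workers=0)
for batch in loader:
    print(batch.dtype)  # torch.float32 for every batch
```

If this guess is right, I suppose I could equally cast the arrays once at load time instead of on every __getitem__ call.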