Num_workers > 0 works for some models but not others

I followed the tutorial on transfer learning at https://pytorch.org/tutorials/beginner/transfer_learning_tutorial.html#further-learning

This all runs with num_workers=6 in the dataloaders (though I do have some questions about it; see my other post, "Questions after following transfer learning tutorial").

Now I am trying to create dataloaders that will give a pair of images; the x is an image with random noise, the y is the original clean image (I would like to train a net that produces a clean image from a noisy one).

It works with num_workers=0. Unfortunately, whenever I use num_workers > 0, I get an error. From searching for a solution, this error seems to be quite common, but I have not found anything that makes it work. Could it be because I am returning a pair of images rather than an image and a class label? My code is below, along with the error.

In case it matters, the code is run in a Jupyter notebook. I have tried batch_size=1 just to make sure it is not an out-of-memory issue (6 GB of video RAM, which should be more than enough). If any further information would help, let me know and I will add it. Thanks!

# ImageNet per-channel statistics as a (mean, std) pair; unpacked into
# transforms.Normalize(mean, std) when items are produced.
imagenet_stats = (
    [0.485, 0.456, 0.406],  # mean, one entry per RGB channel
    [0.229, 0.224, 0.225],  # std, one entry per RGB channel
)
data_dir = 'random_images'
batch_size = 1

# Geometric PIL -> PIL transforms only; tensor conversion and normalization
# are applied separately when each item is produced.
transf = dict(
    trn=transforms.Compose([
        transforms.RandomResizedCrop(224),
        transforms.RandomHorizontalFlip(),
    ]),
    val=transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
    ]),
)

class NoisyDataset(Dataset):
    """Dataset of (noisy image, clean image) pairs for training a denoiser.

    Every regular file in ``main_dir`` is treated as an image, in natural
    sort order. For each item the image is opened, converted to RGB and
    passed through ``transform`` (a PIL -> PIL transform such as
    ``transf['trn']``). The transformed image is the clean target. A copy
    with Gaussian noise added by ``random_noise`` is the input. Both are
    converted to float32 tensors and normalized with ``imagenet_stats``.

    NOTE on ``num_workers > 0``: on Windows, DataLoader worker processes are
    started with the "spawn" method. They must unpickle the dataset, which
    means importing this class by name. A class defined inside a Jupyter
    notebook cannot be imported that way, so the workers die with
    "DataLoader worker ... exited unexpectedly". Put this class (together
    with ``imagenet_stats`` and the imports it uses) in a ``.py`` module and
    import it into the notebook, or use ``num_workers=0`` in the notebook.

    Args:
        main_dir: directory containing the image files.
        transform: callable applied to the RGB PIL image before noise is
            added; it must return something ``np.array`` and ``ToTensor``
            accept (e.g. a PIL image).
        noise_var: variance of the Gaussian noise passed to ``random_noise``
            (the default 0.05 matches the previous hard-coded value).
    """

    def __init__(self, main_dir, transform, noise_var=0.05):
        self.main_dir = main_dir
        self.transform = transform
        self.noise_var = noise_var
        # Skip sub-directories; Image.open cannot read them and would fail
        # for whichever index pointed at one.
        files = [
            name for name in os.listdir(main_dir)
            if os.path.isfile(os.path.join(main_dir, name))
        ]
        self.total_imgs = natsort.natsorted(files)
        # Identical for every item, so build once instead of per __getitem__.
        self._to_tensor = transforms.ToTensor()
        self._normalize = transforms.Normalize(*imagenet_stats)

    def __len__(self):
        return len(self.total_imgs)

    def __getitem__(self, idx):
        """Return ``(noisy, clean)`` normalized float32 tensors for ``idx``."""
        img_loc = os.path.join(self.main_dir, self.total_imgs[idx])
        # Close the file handle deterministically; convert() returns a new,
        # fully loaded image that does not depend on the open file.
        with Image.open(img_loc) as img:
            rgb = img.convert("RGB")
        clean = self.transform(rgb)
        # random_noise returns float64 values in [0, 1]. ToTensor does not
        # rescale or cast float arrays, so cast to float32 here to match the
        # float32 tensor ToTensor produces for the clean PIL image.
        noisy = random_noise(
            np.array(clean), mode='gaussian', mean=0,
            var=self.noise_var, clip=True,
        ).astype(np.float32)
        image = self._normalize(self._to_tensor(noisy))
        label = self._normalize(self._to_tensor(clean))
        return image, label
    
# NOTE: with num_workers > 0 on Windows, each DataLoader worker is a freshly
# spawned process that must unpickle the dataset, i.e. import its class by
# name. A Dataset subclass defined in a Jupyter notebook cannot be imported
# that way, so the workers exit with "DataLoader worker ... exited
# unexpectedly". Either move NoisyDataset (and the names it uses) into a .py
# module and import it, or set num_workers = 0 when running from the
# notebook. When running as a script instead, keep the DataLoader iteration
# under `if __name__ == '__main__':`.
num_workers = 6

t_ds = NoisyDataset(os.path.join(data_dir, 'train'), transform=transf['trn'])
# NOTE(review): the validation set also reads from 'train' (only the
# transform differs). If a separate validation split exists, this was
# probably meant to be 'val'; confirm.
v_ds = NoisyDataset(os.path.join(data_dir, 'train'), transform=transf['val'])

t_dl = DataLoader(t_ds, batch_size=batch_size, shuffle=True,
                  num_workers=num_workers, drop_last=False)
v_dl = DataLoader(v_ds, batch_size=2 * batch_size, shuffle=False,
                  num_workers=num_workers, drop_last=False)

# Smoke test: each batch should be a pair of tensors (noisy input, clean
# target). With num_workers=0 this prints
# 0 [<class 'torch.Tensor'>, <class 'torch.Tensor'>]
# 1 [<class 'torch.Tensor'>, <class 'torch.Tensor'>]
# ... as expected.
for idx, batch in enumerate(v_dl):
    print(idx, [type(i) for i in batch])


Empty Traceback (most recent call last)
~\Anaconda3\envs\dsai\lib\site-packages\torch\utils\data\dataloader.py in _try_get_data(self, timeout)
778 try:
--> 779 data = self._data_queue.get(timeout=timeout)
780 return (True, data)

~\Anaconda3\envs\dsai\lib\multiprocessing\queues.py in get(self, block, timeout)
104 if not self._poll(timeout):
--> 105 raise Empty
106 elif not self._poll():

Empty:

During handling of the above exception, another exception occurred:

RuntimeError Traceback (most recent call last)
in <module>
----> 1 for idx, img in enumerate(v_dl):
2 print(idx, type(img))

~\Anaconda3\envs\dsai\lib\site-packages\torch\utils\data\dataloader.py in __next__(self)
361
362 def __next__(self):
--> 363 data = self._next_data()
364 self._num_yielded += 1
365 if self._dataset_kind == _DatasetKind.Iterable and \

~\Anaconda3\envs\dsai\lib\site-packages\torch\utils\data\dataloader.py in _next_data(self)
972
973 assert not self._shutdown and self._tasks_outstanding > 0
--> 974 idx, data = self._get_data()
975 self._tasks_outstanding -= 1
976

~\Anaconda3\envs\dsai\lib\site-packages\torch\utils\data\dataloader.py in _get_data(self)
939 else:
940 while True:
--> 941 success, data = self._try_get_data()
942 if success:
943 return data

~\Anaconda3\envs\dsai\lib\site-packages\torch\utils\data\dataloader.py in _try_get_data(self, timeout)
790 if len(failed_workers) > 0:
791 pids_str = ', '.join(str(w.pid) for w in failed_workers)
--> 792 raise RuntimeError('DataLoader worker (pid(s) {}) exited unexpectedly'.format(pids_str))
793 if isinstance(e, queue.Empty):
794 return (False, None)

RuntimeError: DataLoader worker (pid(s) 11752, 8848, 6272, 920, 8248, 11444) exited unexpectedly