Dataloader freezes

I am using PyTorch on a Kubernetes pod running Ubuntu 20.04; the relevant package versions are:

  • pytorch=1.11.0=py3.8_cuda11.3_cudnn8.2.0_0
  • pytorch-mutex=1.0=cuda
  • torchaudio=0.11.0=py38_cu113
  • torchmetrics=0.8.2=pyhd8ed1ab_0
  • torchvision=0.12.0=py38_cu113

The dataloader, using four workers, randomly freezes. After interrupting the process, I get this traceback:

  File "/MYPACKAGES/encoder_trainers.py", line 271, in compute_predictions
    for batch in dl:
  File "/opt/conda/envs/torch/lib/python3.8/site-packages/torch/utils/data/dataloader.py", line 530, in __next__
    data = self._next_data()
  File "/opt/conda/envs/torch/lib/python3.8/site-packages/torch/utils/data/dataloader.py", line 1207, in _next_data
    idx, data = self._get_data()
  File "/opt/conda/envs/torch/lib/python3.8/site-packages/torch/utils/data/dataloader.py", line 1163, in _get_data
    success, data = self._try_get_data()
  File "/opt/conda/envs/torch/lib/python3.8/site-packages/torch/utils/data/dataloader.py", line 1011, in _try_get_data
    data = self._data_queue.get(timeout=timeout)
  File "/opt/conda/envs/torch/lib/python3.8/queue.py", line 179, in get
    self.not_empty.wait(remaining)
  File "/opt/conda/envs/torch/lib/python3.8/threading.py", line 306, in wait
    gotit = waiter.acquire(True, timeout)
KeyboardInterrupt

This is my custom dataset, instrumented with debug prints. All of the prints are executed, meaning that control leaves __getitem__ and returns to the dataloader. The item on which the dataloader freezes is completely random.

from os.path import join

import numpy as np
import pandas as pd
import torch
from PIL import Image
from torch.utils.data import Dataset


class ImageDataset(Dataset):
    def __init__(self, dataset: pd.DataFrame, img_fld: str, img_transforms=None, n_classes=None, img_size=224, l1normalization=False):
        super().__init__()
        self.n_classes = n_classes or dataset.shape[1]
        assert self.n_classes == dataset.shape[1]
        print("IMAGE DATASET, NCLASSES SET TO", self.n_classes)        
        self.ds = dataset
        self.img_fld = img_fld
        self.transforms = img_transforms
        self.img_size = img_size
        self.l1normalization = l1normalization
        if l1normalization:
            print("IMAGE DATASET, WARNING: l1 normalization of target values set to TRUE")
        else:
            print("IMAGE DATASET: l1 normalization of target values set to FALSE")

    def __len__(self):
        return len(self.ds)
    #< len

    def __getitem__(self, idx):
        print("Accessing item:", idx)
        assert (idx >=0) and (idx < len(self.ds))
        item = self.ds.iloc[idx]
        filename = item.name
        labels = item.values
        print("\t filename:", filename)
        print("\t labels:", sum(labels))

        if self.l1normalization:
            labels = labels / labels.sum()
        
        image = self.load_image(filename)
        print("image loaded")
        labels = torch.tensor(labels.astype(np.float32))
        print('about to return image and labels')
        return image, labels 
    #< getitem

    def load_image(self, img_filename):
        fn = join(self.img_fld, img_filename)
        img = Image.open(fn)
        print("ok, read image:", img_filename)
        if self.transforms is not None:
            img = self.transforms(img)
        print("returning image")

        return img
    #< load_image
#< ImageDataset
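
For reference, the dataset is built and can be indexed directly like this; the DataFrame layout, image folder, and transforms below are illustrative placeholders, not my exact configuration:

import pandas as pd
import torchvision.transforms as T

# Illustrative label table: one row per image file name, one column per class.
df = pd.DataFrame(
    {'class_a': [1, 0], 'class_b': [0, 1]},
    index=['CXR730_IM-2290-1002.png', 'CXR61_IM-2197-2001.png'],
)

transforms = T.Compose([
    T.Resize((224, 224)),
    T.ToTensor(),
])

dataset = ImageDataset(df, img_fld='/data/images', img_transforms=transforms)
image, labels = dataset[0]      # plain indexing, no DataLoader involved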

The dataset is managed by a DataLoader created as follows:

DataLoader(dataset, batch_size=128, shuffle=False, num_workers=4, drop_last=[False, False, False], pin_memory=False)

As a further note, the dataloaders for train, validation and test are kept in a dictionary:

dataloaders = {
    'train': DataLoader(train_dataset, batch_size=128, shuffle=False, num_workers=4, drop_last=[False, False, False], pin_memory=False),
    'valid': DataLoader(valid_dataset, batch_size=128, shuffle=False, num_workers=4, drop_last=[False, False, False], pin_memory=False),
    'test': DataLoader(test_dataset, batch_size=128, shuffle=False, num_workers=4, drop_last=[False, False, False], pin_memory=False),
}
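
Each of them is consumed with a plain loop; a simplified sketch of compute_predictions follows (the 'test' key and the omitted forward pass are placeholders for the real code):

# inside compute_predictions (encoder_trainers.py, around line 271)
dl = self.dataloaders['test']      # or 'train' / 'valid'
for batch in dl:                   # the hang happens here, while fetching the next batch
    images, labels = batch
    # ... forward pass on images, omitted ...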

In one run these are the last messages, then it got frozen:

Accessing item: 1151
         filename: CXR730_IM-2290-1002.png
         labels: 1
ok, read image: CXR730_IM-2290-1002.png
returning image
image loaded
about to return image and labels

in another run:

Accessing item: 767
         filename: CXR61_IM-2197-2001.png
         labels: 1
ok, read image: CXR61_IM-2197-2001.png
returning image, shape: torch.Size([3, 224, 224])
image loaded
about to return image and labels

This freezing started when I reorganized my code, moving from simple Python scripts to a set of (base) classes and subclasses. The dataloader is always managed in a base class that is never directly instantiated; only its subclasses are.
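
To make the reorganization concrete, the structure is roughly the following; class names and signatures are placeholders, not my exact code:

class BaseTrainer:
    """Base class, never instantiated directly; it owns the dataloaders."""
    def __init__(self, dataloaders: dict):
        self.dataloaders = dataloaders   # {'train': ..., 'valid': ..., 'test': ...}

    def compute_predictions(self, phase):
        # iterates self.dataloaders[phase]; this is where the freeze shows up
        ...

class EncoderTrainer(BaseTrainer):
    """One of the subclasses that actually gets instantiated."""
    def __init__(self, dataloaders, model):
        super().__init__(dataloaders)
        self.model = model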

Other info:

  • The freezing occurs more often with smaller batch sizes (e.g. 32) than with larger ones (e.g. 128).
  • It occurs very often, but not always, on the same item (1151 above). The problem does not seem to be tied to that specific image file, because the file is read, processed, and resized without errors.

It never freezes when the dataloaders are created with num_workers=0.
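
For example, a single-process loader like the following (note num_workers=0) works without hanging:

# Single-process loading: no worker processes, no inter-process result queue.
dl = DataLoader(dataset, batch_size=128, shuffle=False,
                num_workers=0, drop_last=False, pin_memory=False)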