Why does my Dataset make torch go into an infinite loop?

I want to load video frames to train my network, and to speed things up, I’d like to use multiple DataLoader workers.

I created a custom Dataset class for this, but when the DataLoader tries to iterate over it, it gets stuck. Using the Python debugger (pdb) I’ve seen that, internally, torch enters a while True loop and never exits. However, I can iterate over the dataset manually without problems.

I can reproduce this easily by creating an iterator from the DataLoader and requesting its next element. Here is a minimal example.

#!/bin/env python3

import sys
import math

import torch
import cv2

class CustomDataset(torch.utils.data.IterableDataset):

    def __init__(self, sources, max_iter=-1):

        # Load data
        self.vcap = cv2.VideoCapture(sources[0])

        # Load labels
        self.labelvec = []
        with open(sources[1], "r") as fobj:
            for l in fobj:
                # One numeric label per frame, one per line
                self.labelvec.append(float(l.strip()))

        # Store internal parameters

        # Frame size
        self.frame_size = (self.vcap.get(cv2.CAP_PROP_FRAME_HEIGHT),
                           self.vcap.get(cv2.CAP_PROP_FRAME_WIDTH))

        # Find dataset length property
        datum_count = int(self.vcap.get(cv2.CAP_PROP_FRAME_COUNT))
        label_count = len(self.labelvec)
        if max_iter != -1:
            assert max_iter < datum_count and max_iter < label_count
            self.length = max_iter
        else:
            assert datum_count == label_count
            self.length = datum_count

        # Setup iterator bounds
        self.iter_start = 0
        self.iter_end   = self.length-1


    def __del__(self):
        self.vcap.release()

    def __getitem__(self, index):

        if index < 0 or index >= self.length:
            raise IndexError

        # Get frame
        self.vcap.set(cv2.CAP_PROP_POS_FRAMES, index)
        frame = self.vcap.read()[1]
        # Get grayscale frame
        gray_frame = cv2.cvtColor(frame, cv2.COLOR_BGR2GRAY)
        # Flatten frame
        flat_frame = torch.tensor(gray_frame.tolist())

        # Get label
        label = self.labelvec[index]

        return (flat_frame, label)

    def __len__(self):
        return self.length

    def __iter__(self):
        # Based on __iter__ from the MyIterableDataset example implemented in
        # https://pytorch.org/docs/stable/data.html#multi-process-data-loading

        # get worker info
        worker_info = torch.utils.data.get_worker_info()

        # single-process loading
        if worker_info is None:
            iter_start = self.iter_start
            iter_end   = self.iter_end

        # multi-process loading: split the frame range across workers
        else:
            per_worker = int(math.ceil((self.iter_end - self.iter_start) / float(worker_info.num_workers)))
            worker_id  = worker_info.id
            iter_start = self.iter_start + worker_id*per_worker
            iter_end   = min(iter_start+per_worker-1, self.iter_end)

        # iter_end is inclusive, hence the +1
        for ix in range(iter_start, iter_end+1):
            yield self.__getitem__(ix)

if __name__ == "__main__":

    # Calling parameters
    data_fpath = sys.argv[1]
    labels_fpath = sys.argv[2]

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    max_iter = 50
    batch_size = 10
    num_workers = 3

    dataset = CustomDataset((data_fpath, labels_fpath),
                            max_iter=max_iter)

    dataloader = torch.utils.data.DataLoader(dataset=dataset,
                                             batch_size=batch_size,
                                             num_workers=num_workers)
    dataloader_iterator = iter(dataloader)

    element = next(dataloader_iterator)

When I run that, the script hangs indefinitely. When I stop it (e.g. by pressing Control + C with focus on the terminal), I get the following output:

  File "[redacted]/dataloader_infinite_loop.py", line 115, in <module>
    element = next(dataloader_iterator)
  File "/usr/lib/python3.9/site-packages/torch/utils/data/dataloader.py", line 435, in __next__
    data = self._next_data()
  File "/usr/lib/python3.9/site-packages/torch/utils/data/dataloader.py", line 1068, in _next_data
    idx, data = self._get_data()
  File "/usr/lib/python3.9/site-packages/torch/utils/data/dataloader.py", line 1034, in _get_data
    success, data = self._try_get_data()
  File "/usr/lib/python3.9/site-packages/torch/utils/data/dataloader.py", line 872, in _try_get_data
    data = self._data_queue.get(timeout=timeout)
  File "/usr/lib/python3.9/multiprocessing/queues.py", line 113, in get
    if not self._poll(timeout):
  File "/usr/lib/python3.9/multiprocessing/connection.py", line 262, in poll
    return self._poll(timeout)
  File "/usr/lib/python3.9/multiprocessing/connection.py", line 429, in _poll
    r = wait([self], timeout)
  File "/usr/lib/python3.9/multiprocessing/connection.py", line 936, in wait
    ready = selector.select(timeout)
  File "/usr/lib/python3.9/selectors.py", line 416, in select
    fd_event_list = self._selector.poll(timeout)

If you then open /usr/lib/python3.9/site-packages/torch/utils/data/dataloader.py at line 1034, you’ll see:

1033             while True:
1034                 success, data = self._try_get_data()
1035                 if success:
1036                     return data

Which is where torch gets stuck.

I’m not very confident that anybody will be able to pin this down, but it’s worth a try.

I’m not sure why it hangs, but as a first debugging step I would remove the OpenCV VideoCapture object and check whether it interacts badly with the DataLoader and causes the hang.

Thanks for your reply @ptrblck. That’s a very good suggestion. I created a DummyVideoCapture class to stand in for the real one, and with it the script ran without problems, which suggests the issue is indeed in the interaction between the OpenCV VideoCapture object and PyTorch’s worker processes.
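For completeness, a minimal sketch of such a stand-in (the OpenCV property constants are hard-coded by their numeric values so the snippet runs without cv2; the default frame size and count are arbitrary placeholders):

```python
import numpy as np

# Numeric values of the cv2.CAP_PROP_* constants the dataset uses:
# POS_FRAMES = 1, FRAME_WIDTH = 3, FRAME_HEIGHT = 4, FRAME_COUNT = 7.
CAP_PROP_POS_FRAMES, CAP_PROP_FRAME_WIDTH = 1, 3
CAP_PROP_FRAME_HEIGHT, CAP_PROP_FRAME_COUNT = 4, 7

class DummyVideoCapture:
    """Mimics the small subset of the cv2.VideoCapture API the dataset uses."""

    def __init__(self, frame_count=100, height=48, width=64):
        self.props = {CAP_PROP_POS_FRAMES: 0,
                      CAP_PROP_FRAME_WIDTH: width,
                      CAP_PROP_FRAME_HEIGHT: height,
                      CAP_PROP_FRAME_COUNT: frame_count}

    def get(self, prop):
        return self.props[prop]

    def set(self, prop, value):
        self.props[prop] = value

    def read(self):
        # Like VideoCapture.read(): returns (success flag, BGR frame)
        h = int(self.props[CAP_PROP_FRAME_HEIGHT])
        w = int(self.props[CAP_PROP_FRAME_WIDTH])
        return True, np.zeros((h, w, 3), dtype=np.uint8)

    def release(self):
        pass
```

Swapping this object in for self.vcap leaves the rest of the Dataset code untouched, which is what isolated the hang to the real VideoCapture.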

Also, I found this article, one of whose update notes suggests that an OpenCV VideoCapture can cause multi-process reading to get stuck:

Use Thread in multithreading module instead of Process in multiprocessing module. When using Process and importing cv2.VideoCapture() as input argument, you may suffer reading stuck during retrieving frames from video stream. It seems to be related to the multi-processing mechanism in Python.

For now I’ll resign myself to building my batches with a custom class and not use multi-worker reading yet.
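If I revisit multi-worker loading later, one workaround I might try (untested here) is to keep only the video path in the dataset and open a fresh VideoCapture inside __iter__, i.e. after each worker process exists, so the handle is never created before the fork and shared across processes. A sketch of the per-worker frame slicing follows; the cv2 part is shown only in comments, since it needs a real video file:

```python
import math

def worker_range(length, num_workers, worker_id):
    """[start, end) slice of frame indices handled by one DataLoader worker."""
    per_worker = math.ceil(length / num_workers)
    start = min(worker_id * per_worker, length)
    end = min(start + per_worker, length)
    return start, end

# Inside __iter__ the dataset would then do roughly:
#
#     info = torch.utils.data.get_worker_info()
#     start, end = (0, self.length) if info is None else \
#                  worker_range(self.length, info.num_workers, info.id)
#     vcap = cv2.VideoCapture(self.video_fpath)   # opened per worker, post-fork
#     vcap.set(cv2.CAP_PROP_POS_FRAMES, start)
#     for _ in range(start, end):
#         ok, frame = vcap.read()
#         if not ok:
#             break
#         yield ...
#     vcap.release()
```

The helper partitions the frame indices without gaps or overlap, so each worker reads a contiguous chunk of the video with its own handle.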