I want to load video frames to train my network, and to speed things up, I’d like to use multiple threads.
I created a custom Dataset class for this, but when the Dataloader tries to iterate over it, it gets stuck. With the python debugger (pdb) I’ve seen that internally, torch goes into a while True
loop and never exits. However, I can successfully iterate over the dataset manually.
I can reproduce this easily by instantiating an iterator of the Dataloader and asking it for the next element. Here is a minimal example of this.
#!/bin/env python3
import sys
import math
import torch
import cv2
class CustomDataset(torch.utils.data.IterableDataset):
def __init__(self, sources, max_iter=-1):
super().__init__()
# Load data
self.vcap = cv2.VideoCapture(sources[0])
# Load labels
self.labelvec = []
with open(sources[1], "r") as fobj:
for l in fobj:
self.labelvec.append(float(l))
# Store internal parameters
# Frame size
self.frame_size = (self.vcap.get(cv2.CAP_PROP_FRAME_HEIGHT),
self.vcap.get(cv2.CAP_PROP_FRAME_WIDTH))
# Find dataset length property
datum_count = int(self.vcap.get(cv2.CAP_PROP_FRAME_COUNT))
label_count = len(self.labelvec)
if max_iter != -1:
assert max_iter < datum_count, max_iter < label_count
self.length = max_iter
else:
assert datum_count == label_count
self.length = datum_count
# Setup iterator bounds
self.iter_start = 0
self.iter_end = self.length-1
return
def __del__(self):
self.vcap.release()
return
def __getitem__(self, index):
if index < 0 or index >= self.length:
raise IndexError
# Get frame
self.vcap.set(cv2.CAP_PROP_POS_FRAMES, index)
frame = self.vcap.read()[1]
# Get grayscale frame
gray_frame = cv2.cvtColor(frame, cv2.COLOR_BGR2GRAY)
# Flatten frame
flat_frame = torch.tensor(gray_frame.tolist())
# Get label
label = self.labelvec[index]
return (flat_frame, label)
def __len__(self):
return self.length
def __iter__(self):
# Based on __iter__ from the MyIterableDataset example implemented in
# https://pytorch.org/docs/stable/data.html#multi-process-data-loading
# get worker info
worker_info = torch.utils.data.get_worker_info()
# single-threaded loading
if worker_info is None:
iter_start = self.iter_start
iter_end = self.iter_end
# multi-threaded loading
else:
per_worker = int(math.ceil((self.iter_end - self.iter_start) / float(worker_info.num_workers)))
worker_id = worker_info.id
iter_start = self.iter_start + worker_id*per_worker
iter_end = min(iter_start+per_worker-1, self.iter_end)
for ix in range(iter_start, iter_end):
yield self.__getitem__(ix)
if __name__ == "__main__":
# Calling parameters
data_fpath = sys.argv[1]
labels_fpath = sys.argv[2]
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
max_iter = 50
batch_size = 10
num_workers = 3
dataset = CustomDataset((data_fpath,
labels_fpath),
max_iter=max_iter)
dataloader = torch.utils.data.DataLoader(dataset=dataset,
batch_size=batch_size,
num_workers=num_workers)
dataloader_iterator = iter(dataloader)
element = next(dataloader_iterator)
sys.exit(0)
When I run that, the script hangs indefinitely. When I stop it (e.g. by pressing Control + C with focus on the terminal), I get the following output:
File "[redacted]/dataloader_infinite_loop.py", line 115, in <module>
element = next(dataloader_iterator)
File "/usr/lib/python3.9/site-packages/torch/utils/data/dataloader.py", line 435, in __next__
data = self._next_data()
File "/usr/lib/python3.9/site-packages/torch/utils/data/dataloader.py", line 1068, in _next_data
idx, data = self._get_data()
File "/usr/lib/python3.9/site-packages/torch/utils/data/dataloader.py", line 1034, in _get_data
success, data = self._try_get_data()
File "/usr/lib/python3.9/site-packages/torch/utils/data/dataloader.py", line 872, in _try_get_data
data = self._data_queue.get(timeout=timeout)
File "/usr/lib/python3.9/multiprocessing/queues.py", line 113, in get
if not self._poll(timeout):
File "/usr/lib/python3.9/multiprocessing/connection.py", line 262, in poll
return self._poll(timeout)
File "/usr/lib/python3.9/multiprocessing/connection.py", line 429, in _poll
r = wait([self], timeout)
File "/usr/lib/python3.9/multiprocessing/connection.py", line 936, in wait
ready = selector.select(timeout)
File "/usr/lib/python3.9/selectors.py", line 416, in select
fd_event_list = self._selector.poll(timeout)
KeyboardInterrupt
If you then browse the source code for File "/usr/lib/python3.9/site-packages/torch/utils/data/dataloader.py", line 1034
, you’ll see:
1033 while True:
1034 success, data = self._try_get_data()
1035 if success:
1036 return data
1037
Which is where torch gets stuck.
I’m not very confident that anybody will pin this down, but here goes a try.