Dataloader error when num_workers > 0

Hello. I am trying to load a video frame by frame using imageio, but facing an weird error when num_workers > 0. This is a short snippet to reproduce the error.

from skvideo import datasets
import imageio
import torch
from torch.utils.data import Dataset, DataLoader

class VideoDataset(Dataset):
    def __init__(self, video_path):
        super().__init__()
        self.video_reader = imageio.get_reader(video_path)
        self.metadata = self.video_reader.get_meta_data()
        self.nframes = self.metadata["nframes"]

    def __getitem__(self, idx):
        img = self.video_reader.get_data(idx)
        img = torch.from_numpy(img)

        return img

    def __len__(self):
        return self.nframes
video_path = datasets.bikes()
video_data = VideoDataset(video_path)
video_loader = DataLoader(video_data, batch_size=4, num_workers=4)

for imgs in video_loader:
    print(imgs.shape)

when I try to iteratively print the shape of the loaded tensors, it gets stuck mid way and throws an error:

torch.Size([8, 272, 640, 3])
torch.Size([8, 272, 640, 3])
torch.Size([8, 272, 640, 3])
torch.Size([8, 272, 640, 3])
torch.Size([8, 272, 640, 3])
torch.Size([8, 272, 640, 3])
torch.Size([8, 272, 640, 3])
torch.Size([8, 272, 640, 3])
---------------------------------------------------------------------------
CannotReadFrameError                      Traceback (most recent call last)
<ipython-input-20-453bfa4da85a> in <module>()
----> 1 for imgs in video_loader:
      2     print(imgs.shape)

Note that the code works fine when num_workers=0. I can’t figure out what’s going wrong.

class VideoDataset(Dataset):
    def __init__(self, video_path):
        super().__init__()
        self.video_path = video_path
        self.video_reader = imageio.get_reader(video_path)
        self.metadata = self.video_reader.get_meta_data()
        self.nframes = self.metadata["nframes"]

    def __getitem__(self, idx):
        video = imageio.get_reader(self.video_path)
        img = video.get_data(idx)
        img = torch.from_numpy(img)

        return img

    def __len__(self):
        return self.nframes

then it works