EOFError when using one dataset (that loads from HDF5) with multiple Dataloaders

Hi,

I’m testing different Dataloader parameter settings as I recently found out that for num_workers > 0 to to actual aid in loading speed on windows you need to set persistent_workers = True. For this I’m using the below code.

class Dset(torch.utils.data.Dataset):

    def __init__(self, index_dict_fp, labels, X_filepath, y_filepath, sr=48000, test=None):
        self.X_filepath = X_filepath
        self.y_filepath = y_filepath

        # load in and unpickle index dictionary
        index_dict_temp = open(index_dict_fp, 'rb')
        index_dict = pickle.load(index_dict_temp)
        index_dict_temp.close()
        self.index = index_dict

        # load in fold info - lets us know which samples to grab for training/testing for this fold
        with open(labels) as f:
            self.labels = [line.rstrip('\n') for line in f] # gets file names to use for indexing

        # init transformations - at some point this needs to be parameterised so user can define transformations.

        self.Mel_spectra = torchaudio.transforms.MelSpectrogram(sample_rate=sr, n_fft=1024, hop_length=512, n_mels=64,
                                                                win_length=1024, center=False)
        self.amp2db = torchaudio.transforms.AmplitudeToDB()
        self.linear_spectra = torchaudio.transforms.Spectrogram(n_fft=1024, hop_length=512, center=False,
                                                                win_length=1024, power=None,
                                                                return_complex=True)  # needs to return complex

        self.test = test

    def open_hdf5(self):
        self.h5_Xfile = h5py.File(self.X_filepath, 'r')  # load h5 file for inputs
        self.h5_yfile = h5py.File(self.y_filepath, 'r')  # load file for target
        self.input = self.h5_Xfile['stereo_data']
        self.target = self.h5_yfile['data']

    def close_hdf5(self):
        self.h5_Xfile.close()
        self.h5_yfile.close()

    def normalise(self, x):
        x_norm = (x - torch.mean(x)) / torch.std(x)

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, index):
        if not hasattr(self, 'input'):
            self.open_hdf5()  # allows working with num_workers > 1 - checks if file has already been opened

        # get label for sample, use label to find index

        sample = self.labels[index]
        sample_idx = self.index[sample]

        # load training example
        X = torch.from_numpy(self.input[sample_idx, :, :])
        y = torch.from_numpy(self.target[sample_idx, :, :, :])

        _Nhop = np.ceil(X.shape[0] / 512) + 2

        # extract mel and gcc features from stereo signal
        # X is padded to match the padding used in DirAC analysis

        X = F.pad(X, (0, 0, 512, int(_Nhop * 512 - X.shape[0] - 512)))
        Xl = self.Mel_spectra(X[:, 0])
        Xl = self.amp2db(Xl)
        Xr = self.Mel_spectra(X[:, 1])
        Xr = self.amp2db(Xl)

        Xlog_mel = torch.stack((Xl, Xr), dim=2)  # (freq_bins, frames, channels)

        # normalise logmel spec to zero mean and unit variance
        Xlog_mel = (Xlog_mel - torch.mean(Xlog_mel)) / torch.std(Xlog_mel)

        if self.test:
            X, fs = sf.read(self.test)
            X = torch.from_numpy(X)

        # Extract GCCs
        Xl = self.linear_spectra(X[:, 0])
        Xr = self.linear_spectra(X[:, 1])

        lin_spec = torch.stack((Xl, Xr), dim=2)
        lin_spec = lin_spec.transpose(0, 1)  # (frames, freq_bins, channels)

        R = torch.conj(lin_spec[:, :, 0]) * lin_spec[:, :, 1]

        cc = torch.fft.irfft(torch.exp(1.j * torch.angle(R)))
        cc = torch.cat((cc[:, -Xlog_mel.shape[0] // 2:], cc[:, :Xlog_mel.shape[0] // 2]), dim=-1)

        features = torch.cat((Xlog_mel.transpose(0, 2), cc[None, :, :]), dim=0)  # (feature, frames, freq_bins)
        # y = y.permute((2, 0, 1))
        return features, y


def train(dataloader):
    start = time.time()
    for _ in tqdm(range(2)):
        for x in tqdm(dataloader):
            pass
    end = time.time()
    return end - start


if __name__ == '__main__':
    index_dict_fp = 'path\to\DirAC_indxes_dict'
    fold_info = 'path\to\fold1_train.txt'
    input_h5 = 'path\to\input.h5'
    target_h5 = 'path\to\target.h5'
    train_dataset = Dset(index_dict_fp=index_dict_fp, labels=fold_info, X_filepath=input_h5,
                                             y_filepath=target_h5)

    batch_size = 32
    train_loader1 = torch.utils.data.DataLoader(dataset=train_dataset,
                                                batch_size=batch_size, shuffle=True, num_workers=0)

    train_loader2 = torch.utils.data.DataLoader(dataset=train_dataset,
                                                batch_size=batch_size, shuffle=True, num_workers=8)

    train_loader3 = torch.utils.data.DataLoader(dataset=train_dataset,
                                                batch_size=batch_size, shuffle=True, num_workers=8,
                                                persistent_workers=True)

    print(train(train_loader1))
    train_dataset.close_hdf5()

    print(train(train_loader2))
    train_dataset.close_hdf5()

    print(train(train_loader3))
    train_dataset.close_hdf5()

However after the first print(train(trainloader)) call it crashed with the following error and stack trace. I had assumed this was because the HDF5 file didn’t like the dataset class trying to reopen it when it was already open, so I tried to close it after each call to train() using dset.close_hdf5(). This still hasn’t worked and I still get the same error and stack trace which is below. It does seem like it’s something to do with the HDF5 file and the pickle process - but I don’t know enough how about the Dataset and Dataloader classes operate under the hood to figure it out.

Traceback (most recent call last):
  File "C:\Users\Audio\anaconda3\lib\site-packages\IPython\core\interactiveshell.py", line 3437, in run_code
    exec(code_obj, self.user_global_ns, self.user_ns)
  File "<ipython-input-2-e34725483dd5>", line 1, in <module>
    runfile('D:/Dan_PC_Stuff/Pycharm_projects/DirAC_net/time_dataloader.py', wdir='D:/Dan_PC_Stuff/Pycharm_projects/DirAC_net')
  File "C:\Program Files\JetBrains\PyCharm 2021.1.3\plugins\python\helpers\pydev\_pydev_bundle\pydev_umd.py", line 197, in runfile
    pydev_imports.execfile(filename, global_vars, local_vars)  # execute the script
  File "C:\Program Files\JetBrains\PyCharm 2021.1.3\plugins\python\helpers\pydev\_pydev_imps\_pydev_execfile.py", line 18, in execfile
    exec(compile(contents+"\n", file, 'exec'), glob, loc)
  File "D:/Dan_PC_Stuff/Pycharm_projects/DirAC_net/time_dataloader.py", line 40, in <module>
    print(train(train_loader2))
  File "D:/Dan_PC_Stuff/Pycharm_projects/DirAC_net/time_dataloader.py", line 12, in train
    for x in tqdm(dataloader):
  File "C:\Users\Audio\anaconda3\lib\site-packages\tqdm\std.py", line 1178, in __iter__
    for obj in iterable:
  File "C:\Users\Audio\anaconda3\lib\site-packages\torch\utils\data\dataloader.py", line 359, in __iter__
    return self._get_iterator()
  File "C:\Users\Audio\anaconda3\lib\site-packages\torch\utils\data\dataloader.py", line 305, in _get_iterator
    return _MultiProcessingDataLoaderIter(self)
  File "C:\Users\Audio\anaconda3\lib\site-packages\torch\utils\data\dataloader.py", line 918, in __init__
    w.start()
  File "C:\Users\Audio\anaconda3\lib\multiprocessing\process.py", line 121, in start
    self._popen = self._Popen(self)
  File "C:\Users\Audio\anaconda3\lib\multiprocessing\context.py", line 224, in _Popen
    return _default_context.get_context().Process._Popen(process_obj)
  File "C:\Users\Audio\anaconda3\lib\multiprocessing\context.py", line 327, in _Popen
    return Popen(process_obj)
  File "C:\Users\Audio\anaconda3\lib\multiprocessing\popen_spawn_win32.py", line 93, in __init__
    reduction.dump(process_obj, to_child)
  File "C:\Users\Audio\anaconda3\lib\multiprocessing\reduction.py", line 60, in dump
    ForkingPickler(file, protocol).dump(obj)
  File "C:\Users\Audio\anaconda3\lib\site-packages\h5py\_hl\base.py", line 308, in __getnewargs__
    raise TypeError("h5py objects cannot be pickled")
TypeError: h5py objects cannot be pickled
Traceback (most recent call last):
  File "<string>", line 1, in <module>
  File "C:\Users\Audio\anaconda3\lib\multiprocessing\spawn.py", line 116, in spawn_main
    exitcode = _main(fd, parent_sentinel)
  File "C:\Users\Audio\anaconda3\lib\multiprocessing\spawn.py", line 126, in _main
    self = reduction.pickle.load(from_parent)
EOFError: Ran out of input

It seems you might be running into this issue and the linked post refers to this post which explains the limitation of Windows to fork processes.

1 Like