DataLoader Runtime Error with Melspectrograms and workers

Hello! I am constantly running into a RuntimeError that I cannot seem to understand, and I would be very grateful if someone can explain what is happening and how I can fix it.

I am creating a dataset class and dataloader for music genre classification. The main issue arises when I include a mel-spectrogram transform from torchaudio.

Here is a script that shows the error I have.

import os
import torch
from torch.utils.data import Dataset
from torch.utils import data
import torchaudio
import pandas as pd

sr = 22050

class Mp3Dataset(Dataset):
    """
    Mp3 dataset class to work with the FMA dataset.
    Input:
    df - pandas dataframe containing track_id and genre.
    audio_path - directory with mp3 files
    duration - how much of the songs to sample
    """

    def __init__(self, df: pd.DataFrame, audio_path: str, duration: float):

        self.audio_path = audio_path
        self.IDs = df['track_id'].astype(str).to_list()
        self.genre_list = df.genre.to_list()
        self.duration = duration

        self.E = torchaudio.sox_effects.SoxEffectsChain()
        self.E.append_effect_to_chain("trim", [0, self.duration])
        self.E.append_effect_to_chain("rate", [sr])
        self.E.append_effect_to_chain("channels", ["1"])

        self.mel = torchaudio.transforms.MelSpectrogram(sample_rate=sr)

    def __len__(self):
        return len(self.IDs)

    def __getitem__(self, index):
        ID = self.IDs[index]
        genre = self.genre_list[index]

        # sox: set input file
        self.E.set_input_file(self.get_path_from_ID(ID))

        # use sox to read in the file using my effects
        waveform, _ = self.E.sox_build_flow_effects()  # size: [1, len * sr]

        melspec = self.mel(waveform)

        return melspec, genre

    def get_path_from_ID(self, ID):
        """
        Gets the audio path from the ID using the FMA dataset format
        """
        track_id = ID.zfill(6)

        return os.path.join(self.audio_path, track_id[:3], track_id + '.mp3')

if __name__ == '__main__':

    # my path to audio files
    audio_path = os.path.join('data', 'fma_small')

    # my dataframe that has track_id and genre info
    df = pd.read_csv('data/fma_metadata/small_track_info.csv')

    torchaudio.initialize_sox()

    dataset = Mp3Dataset(df, audio_path, 1.0)

    params = {'batch_size': 8, 'shuffle': True, 'num_workers': 2}

    dataset_loader = data.DataLoader(dataset, **params)

    print(next(iter(dataset_loader)))

    torchaudio.shutdown_sox()

When I run this code, I get the following errors thrown at me:

---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
~/Documents/data_sci.nosync/fma/min_ex.py in <module>
     71     dataset_loader = data.DataLoader(dataset, **params)
     72
---> 73     print(next(iter(dataset_loader)))
     74
     75     torchaudio.shutdown_sox()

/usr/local/anaconda3/envs/pytorch_fma/lib/python3.7/site-packages/torch/utils/data/dataloader.py in __next__(self)
    817             else:
    818                 del self.task_info[idx]
--> 819                 return self._process_data(data)
    820
    821     next = __next__  # Python 2 compatibility

/usr/local/anaconda3/envs/pytorch_fma/lib/python3.7/site-packages/torch/utils/data/dataloader.py in _process_data(self, data)
    844         self._try_put_index()
    845         if isinstance(data, ExceptionWrapper):
--> 846             data.reraise()
    847         return data
    848

/usr/local/anaconda3/envs/pytorch_fma/lib/python3.7/site-packages/torch/_utils.py in reraise(self)
    367             # (https://bugs.python.org/issue2651), so we work around it.
    368             msg = KeyErrorMessage(msg)
--> 369         raise self.exc_type(msg)

RuntimeError: Caught RuntimeError in DataLoader worker process 0.
Original Traceback (most recent call last):
  File "/usr/local/anaconda3/envs/pytorch_fma/lib/python3.7/site-packages/torch/utils/data/_utils/worker.py", line 178, in _worker_loop
    data = fetcher.fetch(index)
  File "/usr/local/anaconda3/envs/pytorch_fma/lib/python3.7/site-packages/torch/utils/data/_utils/fetch.py", line 47, in fetch
    return self.collate_fn(data)
  File "/usr/local/anaconda3/envs/pytorch_fma/lib/python3.7/site-packages/torch/utils/data/_utils/collate.py", line 80, in default_collate
    return [default_collate(samples) for samples in transposed]
  File "/usr/local/anaconda3/envs/pytorch_fma/lib/python3.7/site-packages/torch/utils/data/_utils/collate.py", line 80, in <listcomp>
    return [default_collate(samples) for samples in transposed]
  File "/usr/local/anaconda3/envs/pytorch_fma/lib/python3.7/site-packages/torch/utils/data/_utils/collate.py", line 56, in default_collate
    return torch.stack(batch, 0, out=out)
RuntimeError: stack(): functions with out=... arguments don't support automatic differentiation, but one of the arguments requires grad.

Two things I observe with this script:

  1. Removing the mel-spectrogram transform and outputting the waveform instead removes this error.
  2. Changing num_workers to 0 removes this error as well.

To be clear, these observations are independent. Performing only one of these modifications removes the error. However, I would like both workers and a melspectrogram transform

I am using PyTorch 1.2.0 and Torchaudio 0.3.0+bf88aef, as well as python 3.7.4.

1 Like