Hello! I am constantly running into a RuntimeError that I cannot seem to understand, and I would be very grateful if someone can explain what is happening and how I can fix it.
I am creating a dataset class and dataloader for music genre classification. The main issue arises when I include a mel-spectrogram transform from torchaudio.
Here is a script that shows the error I have.
import os
import torch
from torch.utils.data import Dataset
from torch.utils import data
import torchaudio
import pandas as pd
sr = 22050
class Mp3Dataset(Dataset):
"""
Mp3 dataset class to work with the FMA dataset.
Input:
df - pandas dataframe containing track_id and genre.
audio_path - directory with mp3 files
duration - how much of the songs to sample
"""
def __init__(self, df: pd.DataFrame, audio_path: str, duration: float):
self.audio_path = audio_path
self.IDs = df['track_id'].astype(str).to_list()
self.genre_list = df.genre.to_list()
self.duration = duration
self.E = torchaudio.sox_effects.SoxEffectsChain()
self.E.append_effect_to_chain("trim", [0, self.duration])
self.E.append_effect_to_chain("rate", [sr])
self.E.append_effect_to_chain("channels", ["1"])
self.mel = torchaudio.transforms.MelSpectrogram(sample_rate=sr)
def __len__(self):
return len(self.IDs)
def __getitem__(self, index):
ID = self.IDs[index]
genre = self.genre_list[index]
# sox: set input file
self.E.set_input_file(self.get_path_from_ID(ID))
# use sox to read in the file using my effects
waveform, _ = self.E.sox_build_flow_effects() # size: [1, len * sr]
melspec = self.mel(waveform)
return melspec, genre
def get_path_from_ID(self, ID):
"""
Gets the audio path from the ID using the FMA dataset format
"""
track_id = ID.zfill(6)
return os.path.join(self.audio_path, track_id[:3], track_id + '.mp3')
if __name__ == '__main__':
# my path to audio files
audio_path = os.path.join('data', 'fma_small')
# my dataframe that has track_id and genre info
df = pd.read_csv('data/fma_metadata/small_track_info.csv')
torchaudio.initialize_sox()
dataset = Mp3Dataset(df, audio_path, 1.0)
params = {'batch_size': 8, 'shuffle': True, 'num_workers': 2}
dataset_loader = data.DataLoader(dataset, **params)
print(next(iter(dataset_loader)))
torchaudio.shutdown_sox()
When I run this code, I get the following errors thrown at me:
---------------------------------------------------------------------------
RuntimeError Traceback (most recent call last)
~/Documents/data_sci.nosync/fma/min_ex.py in <module>
71 dataset_loader = data.DataLoader(dataset, **params)
72
---> 73 print(next(iter(dataset_loader)))
74
75 torchaudio.shutdown_sox()
/usr/local/anaconda3/envs/pytorch_fma/lib/python3.7/site-packages/torch/utils/data/dataloader.py in __next__(self)
817 else:
818 del self.task_info[idx]
--> 819 return self._process_data(data)
820
821 next = __next__ # Python 2 compatibility
/usr/local/anaconda3/envs/pytorch_fma/lib/python3.7/site-packages/torch/utils/data/dataloader.py in _process_data(self, data)
844 self._try_put_index()
845 if isinstance(data, ExceptionWrapper):
--> 846 data.reraise()
847 return data
848
/usr/local/anaconda3/envs/pytorch_fma/lib/python3.7/site-packages/torch/_utils.py in reraise(self)
367 # (https://bugs.python.org/issue2651), so we work around it.
368 msg = KeyErrorMessage(msg)
--> 369 raise self.exc_type(msg)
RuntimeError: Caught RuntimeError in DataLoader worker process 0.
Original Traceback (most recent call last):
File "/usr/local/anaconda3/envs/pytorch_fma/lib/python3.7/site-packages/torch/utils/data/_utils/worker.py", line 178, in _worker_loop
data = fetcher.fetch(index)
File "/usr/local/anaconda3/envs/pytorch_fma/lib/python3.7/site-packages/torch/utils/data/_utils/fetch.py", line 47, in fetch
return self.collate_fn(data)
File "/usr/local/anaconda3/envs/pytorch_fma/lib/python3.7/site-packages/torch/utils/data/_utils/collate.py", line 80, in default_collate
return [default_collate(samples) for samples in transposed]
File "/usr/local/anaconda3/envs/pytorch_fma/lib/python3.7/site-packages/torch/utils/data/_utils/collate.py", line 80, in <listcomp>
return [default_collate(samples) for samples in transposed]
File "/usr/local/anaconda3/envs/pytorch_fma/lib/python3.7/site-packages/torch/utils/data/_utils/collate.py", line 56, in default_collate
return torch.stack(batch, 0, out=out)
RuntimeError: stack(): functions with out=... arguments don't support automatic differentiation, but one of the arguments requires grad.
Two things I observe with this script:
- Removing the mel-spectrogram transform and outputting the waveform instead removes this error.
- Changing
num_workers
to0
removes this error as well.
To be clear, these observations are independent. Performing only one of these modifications removes the error. However, I would like both workers and a melspectrogram transform
I am using PyTorch 1.2.0 and Torchaudio 0.3.0+bf88aef, as well as python 3.7.4.