AttributeError when using Dataset and DataLoader

Hi, I am trying to load the audio files using DataSet and Dataloader, however it shows the following error, it seems to link the iter of the dataset. Could you help me solve it?

import torch
import numpy as np
import os
import matplotlib.pyplot as plt
import torchaudio
from torch.utils.data import Dataset, DataLoader

class MyData(Dataset):
def init(self, root_dir, transform):
self.root_dir = root_dir
self.list_files = os.listdir(self.root_dir)
self.transform = transform

 def __getitem__(self,idx):
      wave_file = self.list_files[idx]
      data_path = os.path.join(self.roo_dir,wave_file)
      waveform, samp_freq = torchaudio.load(data_path)
      return waveform, samp_freq

 def __len__(self):
     return len(self.list_files)

if __name__ == "__main__":
   dataset = MyData(root_dir='E:\SSR\AudioAugmentation\AudioFile', transform = None)
   loader = DataLoader(dataset, batch_size=1)

   for waveform,f in loader:
       print(waveform.size())
       break

The error shows below:

Traceback (most recent call last):
  File "E:/SSR/AudioAugmentation/dataset.py", line 28, in <module>
    for waveform,f in loader:
  File "C:\Users\.conda\envs\lib\site-packages\torch\utils\data\dataloader.py", line 521, in __next__
    data = self._next_data()
  File "C:\Users\.conda\envs\lib\site-packages\torch\utils\data\dataloader.py", line 561, in _next_data
    data = self._dataset_fetcher.fetch(index)  # may raise StopIteration
  File "C:\Users\.conda\envs\lib\site-packages\torch\utils\data\_utils\fetch.py", line 49, in fetch
    data = [self.dataset[idx] for idx in possibly_batched_index]
  File "C:\Users\.conda\envs\lib\site-packages\torch\utils\data\_utils\fetch.py", line 49, in <listcomp>
    data = [self.dataset[idx] for idx in possibly_batched_index]
  File "E:/SSR/AudioAugmentation/dataset.py", line 17, in __getitem__
    data_path = os.path.join(self.roo_dir,wave_file)
  File "C:\Users\.conda\envs\lib\site-packages\torch\utils\data\dataset.py", line 83, in __getattr__
    raise AttributeError
AttributeError

Thanks!!!

You have a typo here:

data_path = os.path.join(self.roo_dir,wave_file)

and would need to use self.root_dir to avoid the AttributeError.

1 Like