IndexError: list index out of range when iterating a PyTorch DataLoader

Hi everyone,

I have written a custom Dataset that takes .wav files from two directories (source and target), computes their spectrograms and returns them in a dict. Here is my code:

### DataLoader ###

import numpy as np
import pandas as pd
from torch.utils.data import Dataset, DataLoader
import os
import hyperparams as hp
import librosa
from utils import get_spectrograms
from tqdm import tqdm
import glob


class PrepareDataset(Dataset):

    def __init__(self, csv_file, source_dir, target_dir):
        # self.landmarks_frame = pd.read_csv(csv_file, sep='|', header=None)
        self.source_dir = source_dir
        self.target_dir = target_dir
        self.data = pd.read_csv(csv_file, sep=',', header=None)

    def load_wav(self, filename):
        return librosa.load(filename, sr=hp.sample_rate)

    def __len__(self):
        return len([f for f in os.listdir(self.source_dir) if os.path.isfile(os.path.join(self.source_dir, f))])

    def __getitem__(self, idx):

        source_file = glob.glob(self.source_dir + '/*.wav')
        target_file = glob.glob(self.target_dir + '/*.wav')
        source_wav_name = os.path.join(self.source_dir, source_file[idx])
        mel_source, mag_source = get_spectrograms(source_wav_name)

        target_wav_name = os.path.join(self.target_dir, target_file[idx])
        mel_target, mag_target = get_spectrograms(target_wav_name)

        np.save(source_wav_name[:-4] + '.pt', mel_source)
        np.save(source_wav_name[:-4] + '.mag', mag_source)
        np.save(target_wav_name[:-4] + '.pt', mel_target)
        np.save(target_wav_name[:-4] + '.mag', mag_target)

        sample = {'source_mel': mel_source, 'source_mag': mag_source, 'target_mel': mel_target, 'target_mag': mag_target}

        return sample

root_dir = '/content/en-hi'
dataset = PrepareDataset(os.path.join(root_dir, 'metadata.csv'), os.path.join(root_dir, 'Source'), os.path.join(root_dir, 'Target'))
dataloader = DataLoader(dataset, batch_size=1, drop_last=False, num_workers=0)
pbar = tqdm(dataloader)
for d in pbar:
    pass

The issue is that when I iterate over the DataLoader (e.g. with enumerate), I get an IndexError: list index out of range. I managed to narrow down what the issue is, but I don't know how to solve it: apparently the idx in the __getitem__() method keeps running even after reaching the end of the list.

Is there any way to stop the idx loop once it has reached the end of the list?

Thank you

Hi,
This might be directly related to the __len__ method.
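In the posted code, __len__ and __getitem__ look at different sets of files: __len__ counts every regular file in source_dir, while __getitem__ only indexes the *.wav files returned by glob. The np.save calls also write new files into source_dir, and np.save appends .npy to any name that lacks it, so 'clip1.pt' is saved as 'clip1.pt.npy'. The sketch below uses a temporary directory and placeholder file names (assumptions for illustration, not the poster's data) to show the two counts drifting apart:

```python
import glob
import os
import tempfile

import numpy as np


def count_like_len(source_dir):
    # Same logic as the posted __len__: every regular file counts.
    return len([f for f in os.listdir(source_dir)
                if os.path.isfile(os.path.join(source_dir, f))])


def count_like_getitem(source_dir):
    # Same logic as the posted __getitem__: only *.wav files are indexable.
    return len(glob.glob(os.path.join(source_dir, "*.wav")))


with tempfile.TemporaryDirectory() as source_dir:
    # Two empty placeholder .wav files (contents don't matter for counting).
    for name in ("clip1.wav", "clip2.wav"):
        open(os.path.join(source_dir, name), "wb").close()
    before = (count_like_len(source_dir), count_like_getitem(source_dir))

    # Mimic the np.save calls from __getitem__ for one clip.
    # np.save appends ".npy", so these become clip1.pt.npy and clip1.mag.npy.
    wav_path = os.path.join(source_dir, "clip1.wav")
    np.save(wav_path[:-4] + ".pt", np.zeros(3))
    np.save(wav_path[:-4] + ".mag", np.zeros(3))
    after = (count_like_len(source_dir), count_like_getitem(source_dir))
    saved = sorted(f for f in os.listdir(source_dir) if f.endswith(".npy"))

print(before)  # (2, 2)
print(after)   # (4, 2)
print(saved)   # ['clip1.mag.npy', 'clip1.pt.npy']
```

So any .npy files left over from an earlier run (or any other non-.wav file in source_dir) make __len__ larger than the number of .wav files __getitem__ can index. The same failure happens on the target side if Target holds fewer .wav files than Source.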

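Many implementations of the Sampler classes that the DataLoader uses to generate indices rely directly on the dataset's __len__ method. The default SequentialSampler (used when shuffle=False), for example, simply yields the indices 0 to len(dataset) - 1.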
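The DataLoader never checks whether an index is still valid; it just passes whatever the sampler yields to __getitem__. A minimal pure-Python sketch reproduces the error without PyTorch. Here sequential_indices is a hypothetical stand-in for SequentialSampler and MismatchedDataset is a toy dataset, both assumptions for illustration:

```python
class MismatchedDataset:
    """Toy map-style dataset whose __len__ overstates what __getitem__ can serve."""

    def __init__(self):
        # What __getitem__ indexes (like glob('*.wav') in the question).
        self.wav_files = ["a.wav", "b.wav"]
        # What __len__ counts (like os.listdir in the question, with leftover .npy files).
        self.all_files = self.wav_files + ["a.pt.npy", "a.mag.npy"]

    def __len__(self):
        return len(self.all_files)  # 4

    def __getitem__(self, idx):
        return self.wav_files[idx]  # only indices 0 and 1 are valid


def sequential_indices(dataset):
    # Stand-in for torch.utils.data.SequentialSampler: yields 0 .. len(dataset) - 1.
    return iter(range(len(dataset)))


dataset = MismatchedDataset()
served = []
error = None
try:
    for idx in sequential_indices(dataset):
        served.append(dataset[idx])  # idx == 2 is past the end of wav_files
except IndexError as exc:
    error = str(exc)

print(served)  # ['a.wav', 'b.wav']
print(error)   # list index out of range
```

In other words, there is no separate idx loop to stop: iteration ends at len(dataset), so the fix is to make __len__ agree with what __getitem__ can actually index.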

Make sure that method returns the correct length of your dataset, and also that the indices of your map-style dataset match the indices the Sampler is using. I guess the latter is already taken care of in your case.
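One way to keep the two methods consistent is to build the file lists once in __init__ and derive __len__ from the same list __getitem__ indexes. The sketch below is hypothetical: PairedWavDataset and its load_fn parameter (standing in for the poster's get_spectrograms) are illustrative names, and pairing by sorted order assumes the i-th source file corresponds to the i-th target file. It is a plain class so it runs without PyTorch; in real code it would subclass torch.utils.data.Dataset.

```python
import glob
import os


class PairedWavDataset:
    """Sketch of a map-style dataset whose __len__ and __getitem__ share one file list."""

    def __init__(self, source_dir, target_dir, load_fn):
        # Snapshot and sort the .wav lists once. Sorting assumes matching
        # source/target files sort into the same positions (an assumption).
        self.source_files = sorted(glob.glob(os.path.join(source_dir, "*.wav")))
        self.target_files = sorted(glob.glob(os.path.join(target_dir, "*.wav")))
        if len(self.source_files) != len(self.target_files):
            raise ValueError(
                f"{len(self.source_files)} source vs "
                f"{len(self.target_files)} target .wav files"
            )
        # Hypothetical hook standing in for get_spectrograms: path -> (mel, mag).
        self.load_fn = load_fn

    def __len__(self):
        # Derived from the list __getitem__ indexes, so the two always agree,
        # even if .npy files are later written into source_dir.
        return len(self.source_files)

    def __getitem__(self, idx):
        # glob already returns full paths, so no extra os.path.join is needed.
        mel_source, mag_source = self.load_fn(self.source_files[idx])
        mel_target, mag_target = self.load_fn(self.target_files[idx])
        return {
            "source_mel": mel_source,
            "source_mag": mag_source,
            "target_mel": mel_target,
            "target_mag": mag_target,
        }
```

With the poster's helper, construction would look like PairedWavDataset(os.path.join(root_dir, 'Source'), os.path.join(root_dir, 'Target'), get_spectrograms). Saving the cached .npy spectrograms to a separate directory would also keep them out of later directory listings.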

Hi, thanks for the reply. I was able to solve the problem.