Hi everyone,
I have made a custom dataloader which will take files from two directory and put them in a dict. This is my code for the same.
###DataLoder####
import numpy as np
import pandas as pd
from torch.utils.data import Dataset, DataLoader
import os
import hyperparams as hp
import librosa
from utils import get_spectrograms
from tqdm import tqdm
import glob
class PrepareDataset(Dataset):
def __init__(self, csv_file, source_dir, target_dir):
# self.landmarks_frame = pd.read_csv(csv_file, sep='|', header=None)
self.source_dir = source_dir
self.target_dir = target_dir
self.data = pd.read_csv(csv_file, sep=',', header=None)
def load_wav(self, filename):
return librosa.load(filename, sr=hp.sample_rate)
def __len__(self):
return len([f for f in os.listdir(self.source_dir)if os.path.isfile(os.path.join(self.source_dir, f))])
def __getitem__(self, idx):
source_file = glob.glob(self.source_dir + '/*.wav')
target_file = glob.glob(self.target_dir + '/*.wav')
source_wav_name = os.path.join(self.source_dir, source_file[idx])
mel_source, mag_source = get_spectrograms(source_wav_name)
target_wav_name = os.path.join(self.target_dir, target_file[idx])
mel_target, mag_target = get_spectrograms(target_wav_name)
np.save(source_wav_name[:-4] + '.pt', mel_source)
np.save(source_wav_name[:-4] + '.mag', mag_source)
np.save(target_wav_name[:-4] + '.pt', mel_target)
np.save(target_wav_name[:-4] + '.mag', mag_target)
sample = {'source_mel':mel_source, 'source_mag': mag_source, 'target_mel':mel_target, 'target_mag':mag_target}
return sample
root_dir = ‘/content/en-hi’
dataset = PrepareDataset(os.path.join(root_dir,‘metadata.csv’), os.path.join(root_dir,‘Source’), os.path.join(root_dir,‘Target’))
dataloader = DataLoader(dataset, batch_size=1, drop_last=False, num_workers=0)
pbar = tqdm(dataloader)
for d in pbar:
pass
The issue is when i am trying to unpack the dataloader using enumerate i am getting a ‘List index out of range error’. I was successful in finding what the issue is, but i dont know how to solve it. Apparantly the idx in getitem() method is runnning even after reaching the end of the list.
Is there any way in which i can stop the idx loop after it reached the end of the list.
Thank you