Katono88
(Katono Wong)
January 23, 2023, 4:22pm
1
I have already written a class to load the datasets and tried to load my data.
I used the code below to load my datasets:
# Shuffled training loader. With no collate_fn given, DataLoader uses the
# default collate function, which stacks the per-item tensors returned by the
# dataset along a new leading batch dimension.
train_dataloader = DataLoader(
    dataset=train_data,
    shuffle=True,
    batch_size=PARAMS.batch_size,  # 16 in this setup
    num_workers=PARAMS.num_workers,  # 3 in this setup
)
But it raises:
RuntimeError: Given groups=1, weight of size [8, 1, 5, 5], expected input[1, 64, 1, 16000] to have 1 channels, but got 64 channels instead
In my previous experience, the first dimension should be the batch size and the second should be the number of input channels, but that doesn't hold here, so I'm confused.
Katono88
(Katono Wong)
January 23, 2023, 4:23pm
2
Here is the code of my Load_datasets class:
import torch
import torch.utils.data as Data
import torchaudio
import glob
from params3 import *
# Module-level Params instance; `Params` is presumably provided by
# `from params3 import *`.
# NOTE(review): `param` is not referenced anywhere in the visible code — confirm
# it is used elsewhere, otherwise it can be removed.
param = Params()
class Read_datasets(Data.Dataset):
    """Audio classification dataset read from a ``<root>/<class_dir>/<file>`` tree.

    Every entry directly under the root directory is treated as one class.
    Its label is the entry's index in *sorted* order, so train and validation
    roots with the same class folder names get the same labels.

    Args:
        Params: config object exposing ``train_dir`` and ``valid_dir``.
            (Inside ``__init__`` this parameter shadows any module-level
            ``Params`` name.)
        transform: callable applied to every preprocessed waveform in
            ``__getitem__`` (e.g. a mel-spectrogram transform), or ``None``
            to return the waveform itself.
        target_samples: length every waveform is zero-padded/truncated to.
        target_sr: sample rate every waveform is resampled to.
        isTrain: read ``Params.train_dir`` if True, else ``Params.valid_dir``.
    """

    def __init__(self, Params, transform, target_samples, target_sr, isTrain=True):
        super().__init__()
        path = Params.train_dir if isTrain else Params.valid_dir

        files = []
        labels = []
        # glob() returns matches in arbitrary, filesystem-dependent order.
        # Sort them so the class -> label mapping is deterministic and
        # identical across the train and validation directories.
        for class_label, class_dir in enumerate(sorted(glob.glob(path + '/*'))):
            for audio_file in sorted(glob.glob(class_dir + '/*')):
                files.append(audio_file)
                labels.append(class_label)

        self.files = files
        self.labels = labels
        # One slot per file: False means "not loaded yet"; otherwise the slot
        # holds the preprocessed (waveform, label) pair.
        # NOTE(review): with num_workers > 0 each DataLoader worker process
        # works on its own copy of this list. Entries cached inside a worker
        # are therefore not shared, and they are lost whenever the workers
        # are recreated (every epoch unless persistent_workers=True).
        self.cache = [False] * len(files)
        self.target_samples = target_samples
        self.target_sr = target_sr
        self.transform = transform

    def __getitem__(self, index):
        """Return ``(features, label)`` for the file at *index*.

        The audio is loaded and preprocessed once, then cached. ``transform``
        is applied on every access, so the cache holds the untransformed
        waveform and any randomness in the transform is not frozen.
        """
        if self.cache[index] is False:
            waveform, sr = torchaudio.load(self.files[index])
            # Fix the length LAST. Resampling changes the number of samples,
            # so padding/truncating first would leave clips whose length is
            # not target_samples whenever sr != target_sr, and such clips
            # cannot be stacked into one batch. Mixing down before resampling
            # means only one channel has to be resampled.
            waveform = mix_down_if_necessary(waveform)
            waveform = resample_if_necessary(waveform, sr, self.target_sr)
            waveform = reshape_if_necessary(waveform, self.target_samples)
            self.cache[index] = waveform, self.labels[index]

        waveform, label = self.cache[index]
        # Without the transform, the model receives raw waveforms that batch
        # to a 3-D (batch, channels, samples) tensor. Conv2d treats a 3-D
        # input as unbatched (C, H, W), so the batch size shows up as the
        # channel count ("expected 1 channels, but got <batch> channels").
        if self.transform is not None:
            waveform = self.transform(waveform)
        return waveform, label

    def __len__(self):
        """Return the number of audio files found under the root directory."""
        return len(self.files)
# Mel-spectrogram front end with 40 mel bands.
# NOTE(review): sample_rate is hard-coded to 44100 — confirm it matches the
# target_sr the waveforms are resampled to before this transform is applied.
transform_mel = torchaudio.transforms.MelSpectrogram(
    n_mels=40,
    sample_rate=44100,
)
def reshape_if_necessary(waveform, target_samples):
    """Truncate or zero-pad *waveform* so it has exactly *target_samples* samples.

    Truncation slices dim 1 and padding extends the last dim, so the two
    agree for 2-D (channels, samples) tensors. If the length already
    matches, the input object itself is returned.
    """
    length = waveform.shape[1]
    if length == target_samples:
        return waveform
    if length > target_samples:
        return waveform[:, :target_samples]
    # Right-pad the last dimension only; F.pad's default fill value is 0.
    missing = target_samples - length
    return torch.nn.functional.pad(waveform, pad=(0, missing))
def resample_if_necessary(waveform, sr, target_sr):
    """Resample *waveform* from sample rate *sr* to *target_sr*.

    If the two rates are already equal, the input object itself is returned.
    """
    if sr == target_sr:
        return waveform
    return torchaudio.functional.resample(waveform, orig_freq=sr, new_freq=target_sr)
def mix_down_if_necessary(waveform):
    """Average all channels (dim 0) into one, keeping dim 0 with size 1.

    Input with a single channel is returned unchanged (same object).
    """
    if waveform.shape[0] <= 1:
        return waveform
    return waveform.mean(dim=0, keepdim=True)
eqy
January 23, 2023, 8:32pm
3
If you don't pass your own collate function, the DataLoader calls the default collate function to construct each batch. It may therefore help to inspect what happens when you manually pass a list of samples to that function, for example:
# Assemble one batch by hand, the same way DataLoader does with its default
# collate function, to inspect the resulting tensor shapes.
torch.utils.data.default_collate(
    [mydataset[k] for k in range(batch_size)]
)
1 Like