Input error, but I cannot figure out why

I have already written a class to load the datasets and tried to load my data. I used the code below to build the DataLoader:

    from torch.utils.data import DataLoader

    train_dataloader = DataLoader(
        dataset=train_data,
        batch_size=PARAMS.batch_size,    # batch_size = 16
        num_workers=PARAMS.num_workers,  # num_workers = 3
        shuffle=True,
    )

But it returns

    RuntimeError: Given groups=1, weight of size [8, 1, 5, 5], expected input[1, 64, 1, 16000] to have 1 channels, but got 64 channels instead

In my previous experience, the first dimension should be the batch size and the second the input channels, but that doesn't work out here, and I'm confused.
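A weight of size [8, 1, 5, 5] corresponds to an nn.Conv2d(1, 8, kernel_size=5), so here is a minimal sketch (with a random input of the reported shape) that reproduces the same error:

    import torch
    import torch.nn as nn

    conv = nn.Conv2d(in_channels=1, out_channels=8, kernel_size=5)  # weight: [8, 1, 5, 5]
    x = torch.randn(1, 64, 1, 16000)  # [batch, channels, height, width]
    conv(x)  # RuntimeError: expected input[1, 64, 1, 16000] to have 1 channels, but got 64

So somewhere my batch must be ending up with 64 in the channel dimension instead of 1.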

This is my dataset class code (Read_datasets):

    import torch
    import torch.utils.data as Data
    import torchaudio
    import glob
    from params3 import *
    param = Params()

    class Read_datasets(Data.Dataset):
        def __init__(self, Params, transform, target_samples, target_sr, isTrain=True):
            super().__init__()
            if isTrain:
                path = Params.train_dir
            else:
                path = Params.valid_dir

            files = []
            labels = []
            self.cache = []
            # each subdirectory of `path` is one class; its index is used as the label
            for class_label, class_dir in enumerate(glob.glob(path + '/*')):
                for audio_file in glob.glob(class_dir + '/*'):
                    files.append(audio_file)
                    labels.append(class_label)
                    self.cache.append(False)
            self.files = files
            self.labels = labels
            self.target_samples = target_samples
            self.target_sr = target_sr
            self.transform = transform

        def __getitem__(self, index):
            # load and preprocess on first access, then cache the (waveform, label) pair
            if self.cache[index] is False:
                audio = self.files[index]
                waveform, sr = torchaudio.load(audio)
                waveform = reshape_if_necessary(waveform, self.target_samples)
                waveform = resample_if_necessary(waveform, sr, self.target_sr)
                waveform = mix_down_if_necessary(waveform)
                self.cache[index] = waveform, self.labels[index]
            return self.cache[index]

        def __len__(self):
            return len(self.files)


    transform_mel = torchaudio.transforms.MelSpectrogram(sample_rate=44100, n_mels=40)

    def reshape_if_necessary(waveform, target_samples):
        # truncate or zero-pad along the time axis to exactly target_samples
        if waveform.shape[1] > target_samples:
            waveform = waveform[:, 0:target_samples]
        elif waveform.shape[1] < target_samples:
            num_padding = target_samples - waveform.shape[1]
            waveform = torch.nn.functional.pad(waveform, pad=(0, num_padding))
        return waveform

    def resample_if_necessary(waveform, sr, target_sr):
        # resample to the target sample rate if the file's rate differs
        if sr != target_sr:
            waveform = torchaudio.functional.resample(waveform, orig_freq=sr, new_freq=target_sr)
        return waveform

    def mix_down_if_necessary(waveform):
        # average the channels down to mono, keeping shape [1, num_samples]
        if waveform.shape[0] > 1:
            waveform = torch.mean(waveform, dim=0, keepdim=True)
        return waveform
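
With these helpers, a single sample straight from the dataset should come out as a mono waveform of shape [1, target_samples] plus an integer label (a quick check, assuming target_samples = 16000, which matches the width in the error message):

    sample, label = train_data[0]
    print(sample.shape)  # expected: torch.Size([1, 16000])
    print(label)         # integer class index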

The DataLoader calls the default collate function to construct the batch if you are not passing your own, so it might be useful to inspect what happens when you manually pass a list of samples to the collate function. For example, something like:

    torch.utils.data.default_collate([mydataset[i] for i in range(batch_size)])
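
If each sample is a (waveform, label) tuple with the waveform of shape [1, 16000], the collated batch should come out as [batch_size, 1, 16000]. A sketch using the train_data and batch size from your question:

    import torch

    batch = [train_data[i] for i in range(16)]  # 16 = your batch_size
    waveforms, labels = torch.utils.data.default_collate(batch)
    print(waveforms.shape)  # expected: torch.Size([16, 1, 16000])
    print(labels.shape)     # expected: torch.Size([16])

If the shapes already differ here, the problem is in the samples themselves rather than in the DataLoader.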