TypeError: __init__() missing 1 required positional argument: 'Input'

here is my code:


import torch
import torchaudio
import matplotlib.pyplot as plt
import os
import torch.nn as nn
import soundfile as sf
from pydub import AudioSegment
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using {device} device")


AUDIO_DIR = './data/audio/'
IDX = 1
REUQESTED_FILE = str(IDX) + '.wav'
FilePath = os.path.join(str(AUDIO_DIR), str(REUQESTED_FILE))

def waveformToMelSpecrogram(wavFilePath):
    waveform, sample_rate = torchaudio.load(wavFilePath)
    waveform, sample_rate = sf.read(wavFilePath)
    waveform = AudioSegment.from_mp3(wavFilePath)
    waveform = waveform.set_channels(1)
    waveform = waveform.get_array_of_samples()
    waveform = torch.tensor(waveform, dtype=torch.float)
    waveform = torch.reshape(waveform, (1, waveform.shape[0]))
    Mel_Spectrogram = torchaudio.transforms.MelSpectrogram()(waveform)
    return Mel_Spectrogram


class NN(nn.Module):
    def __init__(self):
        super(NN, self).__init__()
        self.flatten = nn.Flatten()
        self.linear_relu_stack = nn.Sequential(
            nn.Linear(128*1756, 2),
            nn.ReLU()
        )

        def forward(self, x):
            x = self.flatten(x)
            logits = self.linear_relu_stack(x)
            return logits


specgram = waveformToMelSpecrogram(FilePath)
'''print("Shape of spectrogram: {}".format(specgram.size()))

plt.figure()
plt.imshow(specgram.log2()[0,:,:].numpy())
plt.show()
'''
net = NN().to(device)
print(NN())

what did I do wrong? right now I’m just trying to pass a mel spectrogram through a simple net.

Could you post the entire error message including the stacktrace, as I’m currently unsure which object raises the error, please?

I managed to solve it myself, thanks anyway

Even if you solved it yourself, posting something about the solution is much appreciated. Because, google ends up at this question when someone search for similar error message text …