TypeError: __init__() missing 1 required positional argument: 'Input'

Here is my code:


import torch
import torchaudio
import matplotlib.pyplot as plt
import os
import torch.nn as nn
import soundfile as sf
from pydub import AudioSegment
# Pick the GPU when PyTorch can see one, otherwise run on the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
print("Using {} device".format(device))


# Input clip location: files live in AUDIO_DIR and are named "<IDX>.wav".
AUDIO_DIR = './data/audio/'
IDX = 1
REUQESTED_FILE = f"{IDX}.wav"  # (sic) spelling kept; other code may use this name
FilePath = os.path.join(AUDIO_DIR, REUQESTED_FILE)

def waveformToMelSpecrogram(wavFilePath):
    """Load an audio file and return its mel spectrogram.

    The file is decoded once with ``torchaudio.load``. Multi-channel audio is
    down-mixed to mono by averaging the channels. The file's own sample rate
    is passed to ``MelSpectrogram`` so the mel filterbank covers the file's
    actual frequency range rather than the transform's 16 kHz default.

    Args:
        wavFilePath: Path to the audio file (this script builds ``.wav`` paths).

    Returns:
        A float tensor of shape ``(1, n_mels, n_frames)``. With the
        ``MelSpectrogram`` defaults (``n_mels=128``, ``hop_length=200``,
        ``center=True``), ``n_frames = n_samples // 200 + 1``. Amplitudes are
        torchaudio's default normalized float samples, not raw integer PCM.
    """
    # Decode once; torchaudio returns (channels, samples) plus the file's rate.
    waveform, sample_rate = torchaudio.load(wavFilePath)
    # Average the channels to mono, keeping the leading axis -> (1, samples).
    waveform = waveform.mean(dim=0, keepdim=True)
    mel_transform = torchaudio.transforms.MelSpectrogram(sample_rate=sample_rate)
    Mel_Spectrogram = mel_transform(waveform)
    return Mel_Spectrogram


class NN(nn.Module):
    """Single linear layer (followed by ReLU) over a flattened input.

    Args:
        in_features: Number of values per example after flattening every
            dimension except the first. Defaults to ``128 * 1756``, presumably
            128 mel bins x 1756 frames of the spectrogram (TODO confirm
            against the clip length).
        num_classes: Number of output units. Defaults to 2.
    """

    def __init__(self, in_features=128 * 1756, num_classes=2):
        super().__init__()
        self.flatten = nn.Flatten()
        # NOTE(review): a ReLU after the final Linear clamps every negative
        # score to 0, which is unusual for logits fed to a loss such as
        # CrossEntropyLoss. It is kept here to preserve the architecture;
        # consider removing it.
        self.linear_relu_stack = nn.Sequential(
            nn.Linear(in_features, num_classes),
            nn.ReLU(),
        )

    # forward must be a class-level method: nn.Module.__call__ dispatches to
    # self.forward, and a def nested inside __init__ is just a dead local.
    def forward(self, x):
        """Flatten ``x`` from dim 1 onward and return the output scores."""
        x = self.flatten(x)
        logits = self.linear_relu_stack(x)
        return logits


specgram = waveformToMelSpecrogram(FilePath)
# Uncomment to inspect the spectrogram:
# print("Shape of spectrogram: {}".format(specgram.size()))
#
# plt.figure()
# plt.imshow(specgram.log2()[0, :, :].numpy())
# plt.show()

net = NN().to(device)
# Print the model that was actually created and moved to `device`, rather
# than constructing (and printing) a second, unrelated instance.
print(net)

What did I do wrong? Right now I'm just trying to pass a mel spectrogram through a simple network.

Could you please post the entire error message, including the stack trace? I'm currently unsure which object raises the error.

I managed to solve it myself, thanks anyway.