I am trying to use the `wav2vec2_base` model for audio feature extraction, and I am getting the error in the title with this code:
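```python
import torch
from torch import nn
from overrides import overrides  # decorator from the `overrides` package
from scipy.io import wavfile
from torchaudio.models import wav2vec2_base  # assumed source of wav2vec2_base

class Identity(nn.Module):
    @overrides
    def forward(self, input_):
        return input_

model = wav2vec2_base(num_out=32)
model.load_state_dict(torch.load("wav2vec2-base-960h.pt"))
model.encoder = Identity()

sample_rate, samples = wavfile.read(wav_file)
# Add a batch dimension: (num_frames,) -> (1, num_frames)
samples = samples.reshape((1, -1))
samples = torch.from_numpy(samples)
#samples = samples.type(torch.short)
print("Start")
samples = samples.float()
output = model(samples)
```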
I have tried to trim the code down to the relevant parts, but there is not much beyond these. What is wrong with my forward function? Thanks for any help!
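For completeness, here is a minimal sketch of the input-preparation step on its own. It assumes a mono 16-bit WAV file (as returned by `scipy.io.wavfile.read`) and assumes the model expects a `float32` tensor of shape `(1, num_frames)` with values scaled to roughly [-1, 1], which is the convention `torchaudio.load` uses by default. `to_batched_waveform` is a hypothetical helper name, not part of any library.

```python
import numpy as np
import torch


def to_batched_waveform(samples: np.ndarray) -> torch.Tensor:
    """Hypothetical helper: turn a mono waveform from wavfile.read into a
    float32 tensor of shape (1, num_frames), i.e. a batch of one signal."""
    if samples.ndim != 1:
        # Stereo files come back as (num_frames, channels); not handled here.
        raise ValueError("expected a mono (1-D) signal")
    if samples.dtype == np.int16:
        # Assumption: scale 16-bit PCM into roughly [-1.0, 1.0)
        samples = samples.astype(np.float32) / 32768.0
    else:
        samples = samples.astype(np.float32)
    # reshape(1, -1) adds the batch dimension without hard-coding the length
    return torch.from_numpy(samples).reshape(1, -1)


wave = to_batched_waveform(np.array([0, 16384, -32768], dtype=np.int16))
print(wave.shape)  # torch.Size([1, 3])
```

Using `reshape(1, -1)` lets NumPy/PyTorch infer the frame count, which avoids passing the whole `samples.shape` tuple as a single dimension.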