Loading and using a model that I didn't train

Hi guys,

As part of a university assignment, I am trying to run a PyTorch model that I didn't train myself. I have the training code, and I need to display the model's prediction in a simple GUI.
I wrote the following code:

import os
import torch
import cv2
import tkinter as tk
from tkinter import filedialog
import torch.nn.functional as F

cv2.ocl.setUseOpenCL(False)
cv2.setNumThreads(0)

from models.cnn import ShallowCNN
from module.lfcc import LFCC
from DataLoader import AudioDataset
from preprocess.process_audio import get_waveform_alt, get_LFCC

os.environ["MKL_NUM_THREADS"] = "1"
os.environ["NUMEXPR_NUM_THREADS"] = "1"
os.environ["OMP_NUM_THREADS"] = "1"

def pred(video):
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # Load the trained model
    model = ShallowCNN(in_features=1, out_dim=1)
    checkpoint_path = 'best.pt'  # Path to the downloaded checkpoint file
    checkpoint = torch.load(checkpoint_path, map_location=device)

    # Modify the state_dict to match the model architecture
    state_dict = checkpoint['state_dict']
    updated_state_dict = {}
    for k, v in state_dict.items():
        if 'conv1' in k:
            updated_state_dict[k.replace('conv1', 'conv1.weight')] = v[:32, :1]
            updated_state_dict[k.replace('conv1', 'conv1.bias')] = v[:32]
        else:
            updated_state_dict[k] = v

    # Load the updated state dictionary into the model
    model.load_state_dict(updated_state_dict)
    model.to(device)
    model.eval()

    # Create the dataset for the WAV file
    wav_path = video  # Specify the path to the WAV file you want to make predictions on
    dataset = get_LFCC(get_waveform_alt(wav_path))
    # Create a data loader with the dataset
    batch_size = 1
    data_loader = torch.utils.data.DataLoader(dataset, batch_size=batch_size)

    # Iterate over the data loader and get the predictions
    predictions = []
    with torch.no_grad():
        for batch in data_loader:
            print(batch.shape)
            batch_x = batch[0]  # Assuming the tensor is at index 0 in the batch
            if not isinstance(batch_x, torch.Tensor):
                batch_x = torch.tensor(batch_x)  # Convert to a tensor if it's not already
            batch_x = batch_x.to(device)
            batch_out = model(batch_x)
            batch_pred = torch.sigmoid(batch_out).cpu().numpy()
            predictions.append(batch_pred)

    # Process the predictions as desired (e.g., convert to labels, perform thresholding, etc.)
    predictions = [pred.item() for pred in predictions]

    # Print or use the predictions as needed
    print(predictions)

def choose_file():
    file_path = filedialog.askopenfilename(filetypes=[('WAV files', '*.wav')])
    file_label.config(text="Selected file: " + file_path)
    return file_path

def analyze_video():
    video_path = choose_file()
    if video_path:
        result = pred(video_path)
        output_label.config(text=result)

root = tk.Tk()
root.title("Detector")

file_label = tk.Label(root, text="Select an audio file to analyze")
file_label.pack(pady=10)

file_button = tk.Button(root, text="Choose File", command=choose_file)
file_button.pack()

analyze_button = tk.Button(root, text="Analyze Audio", command=analyze_video)
analyze_button.pack(pady=20)

output_label = tk.Label(root, text="")
output_label.pack()
output_label.pack()
root.mainloop()

And I got the following error:
"IndexError: too many indices for tensor of dimension 1"

Honestly, I am still kind of a newbie, so I would appreciate a simple explanation of what I am doing wrong. If you need me to provide more information, just say so.
Thank you for your help.

It's unclear which part of the code fails, but this error is raised when you try to index a tensor with more indices than it has dimensions, as seen here:

import torch

x = torch.randn(10)

# indexing with a single index works
x[0]

# indexing with multiple indices fails
x[0, 0]
# IndexError: too many indices for tensor of dimension 1

I guess v is a one-dimensional tensor, so check its shape and make sure to index it with a single index if that's indeed the case.
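
In your case you could print the shapes of the checkpoint entries before slicing them; a quick sketch, assuming the same best.pt file and the 'state_dict' key from your script:

import torch

# Print the shape of every tensor stored in the checkpoint before slicing it
checkpoint = torch.load('best.pt', map_location='cpu')
for name, v in checkpoint['state_dict'].items():
    print(name, tuple(v.shape))
    # A bias tensor has a single dimension, so slicing it with v[:32, :1]
    # raises "IndexError: too many indices for tensor of dimension 1";
    # only tensors with at least 2 dims accept that slice.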

It seems the loaded checkpoint does not match the model architecture you are currently using.
You would have to make sure the layers in your model have the same shapes as the ones used to create the state_dict.
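
A quick way to see the mismatch is to compare the shapes directly; this sketch reuses ShallowCNN(in_features=1, out_dim=1) and best.pt from your script:

import torch
from models.cnn import ShallowCNN

# Compare the model's parameter shapes against the checkpoint's state_dict
model_state = ShallowCNN(in_features=1, out_dim=1).state_dict()
ckpt_state = torch.load('best.pt', map_location='cpu')['state_dict']

for name, param in model_state.items():
    if name not in ckpt_state:
        print('missing in checkpoint:', name)
    elif param.shape != ckpt_state[name].shape:
        print('shape mismatch for', name, tuple(param.shape), 'vs', tuple(ckpt_state[name].shape))

for name in ckpt_state:
    if name not in model_state:
        print('unexpected key in checkpoint:', name)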

The new issue seems to come from F.interpolate as described in this topic.
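
Without the full stack trace I can only guess, but a common pitfall is that F.interpolate expects a batched, channels-first input, so a raw 2-D feature tensor needs extra leading dimensions before it can be resized. A minimal sketch with made-up shapes (not taken from your model):

import torch
import torch.nn.functional as F

feat = torch.randn(40, 100)         # made-up 2-D (features, frames) tensor

# F.interpolate works on (batch, channels, spatial...) inputs,
# so add the missing leading dimensions before resizing
x = feat.unsqueeze(0).unsqueeze(0)   # shape: (1, 1, 40, 100)
resized = F.interpolate(x, size=(40, 200), mode='bilinear', align_corners=False)
print(resized.shape)                 # torch.Size([1, 1, 40, 200])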