Load a model without initialising it

Hello everyone,

I am quite new to image classification (I come from the NLP side). I recently followed this tutorial: https://towardsdatascience.com/binary-image-classification-in-pytorch-5adf64f8c781

However, because I want to integrate a GUI and other parts into the program, I would rather not use Jupyter. Therefore I need to be able to save the model and load it again later.

In the tutorial, the author used a pretrained model:

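  import torch
  import torch.nn as nn
  from torchvision import models

  device = "cuda" if torch.cuda.is_available() else "cpu"

  # newer torchvision versions spell this weights=models.ResNet18_Weights.DEFAULT
  model = models.resnet18(pretrained=True)

  #freeze all params
  for params in model.parameters():
    params.requires_grad = False  # requires_grad_ is a method; assigning to it froze nothing

  #add a new final layer
  nr_filters = model.fc.in_features  #number of input features of last layer
  model.fc = nn.Linear(nr_filters, 1)

  model = model.to(device)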

I am able to save this model using:

#save best model
if cum_loss <= best_loss:
    best_model_wts = model.state_dict()
    PATH = './brain_tumor.pth'
    torch.save(best_model_wts, PATH)
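Writing the weights straight to disk, as above, is safe. If you instead keep `best_model_wts` in memory and only save or restore it after training has finished, keep in mind that `state_dict()` returns tensors that share storage with the live parameters, so later optimizer steps silently overwrite the "best" copy. A small sketch with a toy `nn.Linear` standing in for the ResNet (an assumption for brevity) shows why `copy.deepcopy` is the usual fix:

```python
import copy

import torch
import torch.nn as nn

model = nn.Linear(4, 1)  # toy stand-in for the ResNet

shared = model.state_dict()                   # tensors share storage with the model
snapshot = copy.deepcopy(model.state_dict())  # independent copy of the values

# simulate an optimizer step that updates the weights in place
with torch.no_grad():
    model.weight.add_(1.0)

print(torch.equal(shared["weight"], model.weight))    # True: followed the update
print(torch.equal(snapshot["weight"], model.weight))  # False: kept the old values
```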

However, I do not know how to initialise the model when loading it. I do not have a model class like the one in the documentation, and just writing:

model = torch.load("brain_tumor.pth")

does not work: it raises errors as soon as I use the model later on.

So I am stuck and cannot get past this point.
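The errors come from the fact that `torch.save(model.state_dict(), PATH)` stores a dictionary of tensors, so `torch.load` hands that dictionary back rather than a network you can call. A quick sketch, with a toy `nn.Linear` and a temporary file standing in for `brain_tumor.pth`:

```python
import os
import tempfile

import torch
import torch.nn as nn

model = nn.Linear(4, 1)  # toy stand-in for the trained ResNet

with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "weights.pth")
    torch.save(model.state_dict(), path)
    loaded = torch.load(path)

print(isinstance(loaded, dict))       # True: a mapping of names to tensors
print(list(loaded.keys()))            # ['weight', 'bias']
print(isinstance(loaded, nn.Module))  # False: nothing to call forward() or .eval() on
```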

The .pth file you saved contains the weights of your model (its state_dict), not the model itself. So change

model = torch.load("brain_tumor.pth")

to

loaded = torch.load("filename.pth")
model.load_state_dict(loaded)

Thank you for your fast answer.

However, if I do not initialise the model, how can I get this line to work?

model.load_state_dict(loaded)

The variable model is not defined at this point.

My “call_model_and_predict” currently looks like this:

import matplotlib.pyplot as plt
import numpy as np
import torch
import torchvision
from torchvision import datasets, models, transforms

# ImageNet statistics, matching the Normalize step below
MEAN = [0.485, 0.456, 0.406]
STD = [0.229, 0.224, 0.225]


#func to show images
def imshow(img):
    # undo Normalize(mean=MEAN, std=STD) per channel before plotting
    img = img * torch.tensor(STD).view(3, 1, 1) + torch.tensor(MEAN).view(3, 1, 1)
    npimg = img.clamp(0, 1).numpy()
    plt.imshow(np.transpose(npimg, (1, 2, 0)))
    plt.show()


#directory for testing files
testdir = "/Users/fabian/Documents/QB1 Medizinische Informatik/brain_tumor/test"

test_transforms = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=MEAN, std=STD),
])

test_data = datasets.ImageFolder(testdir, transform=test_transforms)

testloader = torch.utils.data.DataLoader(test_data, shuffle=True, batch_size=4)

device = "cuda" if torch.cuda.is_available() else "cpu"

dataiter = iter(testloader)
images, labels = next(dataiter)

# print images
imshow(torchvision.utils.make_grid(images))
for j in range(4):
    print(labels[j])


loaded = torch.load("filename.pth")
model.load_state_dict(loaded)  # NameError: name 'model' is not defined
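Once `model` exists, the rest of the script still needs the network in evaluation mode and a way to turn its single output into a class. A sketch of a hypothetical `predict` helper, assuming the model was trained with a one-unit head and `BCEWithLogitsLoss` so that the output is a raw logit (check this against your training code):

```python
import torch
import torch.nn as nn


def predict(model: nn.Module, images: torch.Tensor, device: str) -> torch.Tensor:
    """Hypothetical helper: return 0/1 class indices for a batch of images."""
    model.eval()           # BatchNorm uses running stats, dropout is disabled
    with torch.no_grad():  # no gradient bookkeeping during inference
        logits = model(images.to(device)).squeeze(1)  # shape (batch,)
        probs = torch.sigmoid(logits)                 # assumes logit output
    return (probs > 0.5).long().cpu()


# usage with a toy stand-in model; in the script, pass the loaded ResNet instead
toy = nn.Sequential(nn.Flatten(), nn.Linear(3 * 8 * 8, 1))
batch = torch.randn(4, 3, 8, 8)
preds = predict(toy, batch, "cpu")
print(preds.shape)  # torch.Size([4])
```

The predicted indices can be compared with `labels` from the loader, and `test_data.classes` maps them back to the folder names that `ImageFolder` found.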

Can’t you just repeat what you did above, i.e. build the resnet18 with the new final layer again and then call load_state_dict on it? Or do you not have access to the model class?
