Hello everyone,
I am quite new to image classification (I come from the NLP side). I recently followed this tutorial: https://towardsdatascience.com/binary-image-classification-in-pytorch-5adf64f8c781
However, because I want to integrate a GUI and other parts into the program, I am hesitant to use Jupyter. I therefore need to be able to save the model and load it again for later use.
In the tutorial the author used a pretrained model:
model = models.resnet18(pretrained=True)
#freeze all params
for params in model.parameters():
    params.requires_grad = False  # assigning to requires_grad_ would overwrite the method and freeze nothing
#add a new final layer
nr_filters = model.fc.in_features #number of input features of last layer
model.fc = nn.Linear(nr_filters, 1)
model = model.to(device)
I am able to save this model using:
#save best model
if cum_loss <= best_loss:
    best_model_wts = model.state_dict()
    PATH = './brain_tumor.pth'
    torch.save(best_model_wts, PATH)
However, I do not know how to initialise the model when loading it back in. I do not have a model class as in the documentation, and just writing:
model = torch.load("brain_tumor.pth")
does not work, as it leads to errors when I use the model later on.
So I am stuck and cannot get beyond this point.
The .pth file you saved contains only the weights of your model (its state_dict), not the model itself. So change
model = torch.load("brain_tumor.pth")
to
loaded = torch.load("filename.pth")
model.load_state_dict(loaded)
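To illustrate the difference, here is a minimal sketch. The small `nn.Sequential` network and the `build_model` helper are hypothetical stand-ins, not the ResNet from the tutorial, and an in-memory buffer replaces the .pth file so the snippet runs on its own. It shows that `torch.load` gives back a plain dictionary of tensors, which only becomes usable once it is loaded into a freshly constructed model of the same architecture.

```python
import io

import torch
import torch.nn as nn


def build_model():
    # Hypothetical stand-in architecture; in practice this must match
    # the architecture whose weights were saved.
    return nn.Sequential(nn.Linear(8, 4), nn.ReLU(), nn.Linear(4, 1))


torch.manual_seed(0)
trained = build_model()
trained.eval()

# Save only the weights, as torch.save(model.state_dict(), PATH) does.
buffer = io.BytesIO()
torch.save(trained.state_dict(), buffer)
buffer.seek(0)

# torch.load returns the saved object: a dict mapping names to tensors.
loaded = torch.load(buffer, map_location="cpu")
is_plain_dict = isinstance(loaded, dict)
has_forward = hasattr(loaded, "forward")  # a dict cannot be called like a model

# Rebuild the architecture first, then copy the weights into it.
restored = build_model()
restored.load_state_dict(loaded)
restored.eval()

x = torch.randn(2, 8)
with torch.no_grad():
    same_output = torch.allclose(trained(x), restored(x))
```

This is why calling the result of `torch.load` directly as a model fails later on: it has no `forward` method, only the parameter tensors.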
Thank you for your fast answer.
However, if I do not initialise the model, how can I get this line to work:
model.load_state_dict(loaded)
model is not defined at this point.
My “call_model_and_predict” currently looks like this:
import matplotlib.pyplot as plt
import numpy as np
import torch
import torchvision
from torchvision import datasets, models, transforms
#func to show images
def imshow(img):
    img = img / 2 + 0.5 # unnormalize
    npimg = img.numpy()
    plt.imshow(np.transpose(npimg, (1, 2, 0)))
    plt.show()
#directory for testing files
testdir = "/Users/fabian/Documents/QB1 Medizinische Informatik/brain_tumor/test"
test_transforms = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225],
    ),
])
test_data = datasets.ImageFolder(testdir,transform=test_transforms)
testloader = torch.utils.data.DataLoader(test_data, shuffle = True, batch_size=4)
device = "cuda" if torch.cuda.is_available() else "cpu"
dataiter = iter(testloader)
images, labels = next(dataiter)
# print images
imshow(torchvision.utils.make_grid(images))
for j in range(4):
    print(labels[j])
loaded = torch.load("filename.pth")
model.load_state_dict(loaded)
Can’t you just repeat what you did above? Or do you not have access to the model class?