Hi, I have trained a model on my dataset with PyTorch and saved it as a .pth file. Now I want to pass a test image to the model to classify whether it shows a meningioma or a glioma tumor. Please guide me on what I need to pass (from my nn.Module, my training or test code, and parameters) so that the test image is classified accurately. I have loaded the model, but I am not sure what to do next.
I have seen an example here: [How to load and use a pretained PyTorch InceptionV3 model to classify an image](https://stackoverflow.com/questions/53844826/how-to-load-and-use-a-pretained-pytorch-inceptionv3-model-to-classify-an-image)
# load the model
import torch
from PIL import Image
from torchvision import transforms
model = torch.load("iNat_2018_InceptionV3.pth.tar", map_location='cpu')
# try to get it to classify an image
imsize = 256
loader = transforms.Compose([transforms.Resize(imsize), transforms.ToTensor()])  # Resize is the current name of the deprecated transforms.Scale
def image_loader(image_name):
    """load image, returns CPU tensor with a batch dimension"""
    image = Image.open(image_name)
    image = loader(image).float()
    image = image.unsqueeze(0)  # add batch dimension: [1, C, H, W]
    return image.cpu()  # assumes that you're using CPU
image = image_loader("test-image.jpg")
You have to create the model, load the trained parameters, load the test image, preprocess it exactly as you preprocessed your validation data, and then classify it.
Here is some (pseudo-)code:
import torch
from PIL import Image
from torchvision import transforms
model = MyModel()  # Initialize model (the same class you used for training)
model.load_state_dict(torch.load(PATH_TO_MODEL))  # Load pretrained parameters
model.eval()  # Set to eval mode to change behavior of Dropout, BatchNorm
transform = transforms.Compose(...)  # Same as for your validation data, e.g. Resize, ToTensor, Normalize, ...
img = Image.open(PATH_TO_IMAGE)  # Load image as PIL.Image
x = transform(img)  # Preprocess image
x = x.unsqueeze(0)  # Add batch dimension
with torch.no_grad():  # No gradients are needed for inference
    output = model(x)  # Forward pass
pred = torch.argmax(output, 1)  # Get predicted class if multi-class classification
print('Image predicted as ', pred.item())
Let me know if you get stuck somewhere.
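To see why the answer calls model.eval() and why the forward pass is usually wrapped in torch.no_grad(), here is a minimal sketch with a made-up toy model (the layer sizes and the Dropout layer are assumptions for illustration, not the asker's network). In eval mode Dropout is disabled, so repeated forward passes on the same input give identical outputs, and softmax turns the logits into class probabilities.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Hypothetical toy classifier with Dropout, standing in for MyModel.
model = nn.Sequential(
    nn.Linear(8, 16),
    nn.ReLU(),
    nn.Dropout(p=0.5),  # randomly zeroes activations, but only in train mode
    nn.Linear(16, 3),
)
x = torch.randn(1, 8)  # one sample with a batch dimension of 1

model.train()  # training mode: Dropout is active
with torch.no_grad():
    train_out_1 = model(x)
    train_out_2 = model(x)

model.eval()  # evaluation mode: Dropout does nothing (BatchNorm would use running stats)
with torch.no_grad():  # no autograd graph is built during inference
    eval_out_1 = model(x)
    eval_out_2 = model(x)

probs = torch.softmax(eval_out_1, dim=1)  # class probabilities, shape [1, 3]
pred = torch.argmax(eval_out_1, dim=1).item()  # predicted class as a plain int
print("train-mode outputs differ:", not torch.equal(train_out_1, train_out_2))
print("eval-mode outputs identical:", torch.equal(eval_out_1, eval_out_2))
print("predicted class:", pred, "probabilities:", probs)
```

Forgetting model.eval() is a common reason why the same image gets different predictions on different runs.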
This is how I changed my code according to your (pseudo-)code:
```
from __future__ import division
import torch
torch.manual_seed(0)
import torchvision
import torchvision.transforms as transforms
from torch.autograd import Variable
from PIL import Image, ImageOps, ImageEnhance
model=torch.load('last_brain1.pth')
#print(mod)
model = model() # Initialize model
model=torch.load('last_brain1.pth') # Load pretrained parameters
model.eval() # Set to eval mode to change behavior of Dropout, BatchNorm
transform = transforms.Compose(
    [transforms.ToTensor(),
     transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]) # Same as for your validation data, e.g. Resize, ToTensor, Normalize, ...
img = Image.open(r'C:\dataset\FYP\210.png') # Load image as PIL.Image (raw string, otherwise '\210' is read as an escape sequence)
x = transform(img) # Preprocess image
x = x.unsqueeze(0) # Add batch dimension
output = model(x) # Forward pass
pred = torch.argmax(output, 1) # Get predicted class if multi-class classification
print('Image predicted as ', pred)
```
Here is the testing part of my model's script:
#testing
running_loss = 0.0
correct = 0
total = 0
itr = 0
model.eval()
with torch.no_grad():  # no gradients are needed for evaluation
    for images, labels in testloader:
        #CUDA=torch.cuda.is_available()
        #if CUDA:
        #    images=images.cuda()
        #    labels=labels.cuda()
        outputs = model(images)
        batch_loss = criterion(outputs, labels)
        running_loss += batch_loss.item()  # separate name, so the running total is not overwritten
        _, predicted = torch.max(outputs, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()
        itr += 1
testloss.append(running_loss / itr)
testaccuracy.append(100 * correct / len(testset))
# itrloss and trainset come from the training loop, which is not shown here;
# `correct` counts test predictions, so the second line is not a real training accuracy
print('training loss:%f' % (itrloss / itr))
print('training accuracy:%f %%' % (100 * correct / len(trainset)))
print('test loss:%f' % (running_loss / itr))
print('test accuracy:%f %%' % (100 * correct / len(testset)))
print('Accuracy of the network on the test images: %d %%' % (
    100 * correct / total))
class_correct = list(0 for i in range(3))
class_total = list(0 for i in range(3))
with torch.no_grad():
    for data in testloader:
        images, labels = data
        outputs = model(images)
        _, predicted = torch.max(outputs, 1)
        c = (predicted == labels).squeeze()
        #class_total=[]
        for i in range(labels.size(0)):
            label = labels[i]
            class_correct[label] += c[i].item()
            class_total[label] += 1
for i in range(3):
    print('Accuracy of %5s : %2f %%' % (classes[i], 100 * class_correct[i] / class_total[i]))
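The script above measures accuracy over the whole test set, which is a different task from classifying one image. For the per-class numbers, the `.squeeze()` indexing breaks when a batch holds a single image (c becomes a 0-dim tensor and c[i] raises an IndexError). A sketch that avoids this by counting with torch.bincount over predictions and labels might look like this; the helper name per_class_accuracy and the sample tensors are made up for illustration.

```python
import torch

def per_class_accuracy(predicted, labels, num_classes):
    """Return a list with the accuracy (in %) of each class.

    predicted, labels: 1-D integer tensors of the same length.
    """
    correct = predicted == labels
    # Samples per true class, and correct predictions per true class.
    total_per_class = torch.bincount(labels, minlength=num_classes)
    correct_per_class = torch.bincount(labels[correct], minlength=num_classes)
    acc = []
    for c in range(num_classes):
        if total_per_class[c] == 0:
            acc.append(float("nan"))  # class absent from this set
        else:
            acc.append(100.0 * correct_per_class[c].item() / total_per_class[c].item())
    return acc

# Made-up predictions and labels, only to show the bookkeeping.
labels = torch.tensor([0, 0, 1, 1, 2, 2])
predicted = torch.tensor([0, 1, 1, 1, 2, 0])
print(per_class_accuracy(predicted, labels, 3))  # [50.0, 100.0, 50.0]
```

In the test loop, you would append each batch's predicted and labels to two lists, then call torch.cat on them once after the loop.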
After changing my code according to your pseudo-code, I am stuck here.
Hi MisBah-Awan,
You replaced the model with the loaded dictionary. I think you should change line 20 (the second model=torch.load('last_brain1.pth') line) like this:
model = model()
model.load_state_dict(torch.load('last_brain1.pth')) #line to change
model.eval()
Let me know if you’re still stuck.
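To make the reply concrete: assuming last_brain1.pth was written with torch.save(model.state_dict(), ...), torch.load returns a dictionary of tensors, not a model, so it can neither be called nor used for a forward pass. This sketch uses a hypothetical TinyNet and an in-memory buffer instead of the real file, and shows the three steps: build the architecture, load the parameters, switch to eval mode.

```python
import io
import torch
import torch.nn as nn

# Hypothetical stand-in for the trained network.
class TinyNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(4, 3)

    def forward(self, x):
        return self.fc(x)

trained = TinyNet()
buffer = io.BytesIO()  # in-memory file standing in for 'last_brain1.pth'
torch.save(trained.state_dict(), buffer)  # saves only the parameters
buffer.seek(0)

loaded = torch.load(buffer)
print(type(loaded).__name__)  # OrderedDict
print(list(loaded.keys()))    # ['fc.weight', 'fc.bias']
# loaded() would raise a TypeError: a dict is not callable and is not a model.

model = TinyNet()              # 1) build the same architecture as in training
model.load_state_dict(loaded)  # 2) copy the saved parameters into it
model.eval()                   # 3) switch to evaluation mode
```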
Line 16 is also wrong. You need to create your model first and then load the weights into it, like this:
model = TheModelClass(*args, **kwargs)
model.load_state_dict(torch.load(PATH))
model.eval()
Sorry, I don't understand what "TheModelClass" refers to in
model = TheModelClass(*args, **kwargs)
Do you mean this class from my code, the one I trained and saved to the .pth file?
import torch
import torch.nn as nn
import torch.nn.functional as F
class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(3, 32, 3)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(32, 64, 3)
        self.fc1 = nn.Linear(64*14*14, 30)
        self.fc2 = nn.Linear(30, 3)
        #self.softmax = nn.Softmax(dim=1)
    def forward(self, x):
        x = self.pool(F.relu(self.conv1(x)))  # output size for a 64x64 input: [batch_size, 32, 31, 31]
        x = self.pool(F.relu(self.conv2(x)))  # output size: [batch_size, 64, 14, 14]
        x = x.view(-1, 64*14*14)  # output size: [batch_size, 64*14*14]
        x = F.relu(self.fc1(x))
        x = self.fc2(x)
        #x=self.softmax(x)
        return x
x = torch.randn(1, 3, 64, 64)  # (batch size or number of images, channels RGB, height, width)
model = Net()
output = model(x)
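The fc1 input size 64*14*14 only matches 64x64 input images, so the test image has to be resized to the size used in training (64x64 is an assumption based on fc1 and the torch.randn(1, 3, 64, 64) line). A larger image does not necessarily raise an error: for 512x512, the feature map is 64x126x126, which x.view(-1, 64*14*14) silently reshapes into 81 rows, so one image yields 81 "predictions". This sketch repeats the Net definition and checks both sizes with random inputs.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class Net(nn.Module):
    # Same layers as the Net in the post.
    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(3, 32, 3)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(32, 64, 3)
        self.fc1 = nn.Linear(64 * 14 * 14, 30)
        self.fc2 = nn.Linear(30, 3)

    def forward(self, x):
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = x.view(-1, 64 * 14 * 14)
        x = F.relu(self.fc1(x))
        return self.fc2(x)

def conv_pool_size(size):
    """Spatial size after one 3x3 conv (no padding) followed by a 2x2 max-pool."""
    return (size - 3 + 1) // 2

model = Net().eval()
with torch.no_grad():
    for side in (64, 512):
        s = conv_pool_size(conv_pool_size(side))
        out = model(torch.randn(1, 3, side, side))
        print(side, "->", s, "x", s, "feature map, output shape", tuple(out.shape))
# 64 -> 14 x 14 feature map, output shape (1, 3)
# 512 -> 126 x 126 feature map, output shape (81, 3)
```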
Yes, exactly! You have to use the same model class as the one you used for training.
If you used Net(), your code should look like this:
model = Net()
model.load_state_dict(torch.load('last_brain1.pth'))
model.eval()
This returns the output of the whole network instead of taking the given image and classifying it: it prints the accuracy over the whole test set, not the prediction for the single image I gave it. I think I'm missing some module from the file.
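A likely cause of this symptom, assuming Net is imported from the training script (e.g. `from brain_train import Net`, where the file name brain_train is hypothetical): importing a Python file executes all of its top-level code, including the training and testing loops, which then print the dataset accuracy. Keeping the loops in a function behind `if __name__ == "__main__":` (or moving Net into its own file) lets you import the class without running them. The sketch writes such a script to a temporary directory, imports it, and checks that the loops did not run.

```python
import importlib
import sys
import tempfile
import textwrap
from pathlib import Path

# Hypothetical training script: only the imports, the flag and the class
# definition run on import; the loops live in main() behind the guard.
TRAINING_SCRIPT = textwrap.dedent('''
    import torch.nn as nn

    LOOPS_RAN = False

    class Net(nn.Module):
        def __init__(self):
            super().__init__()
            self.fc = nn.Linear(4, 3)

        def forward(self, x):
            return self.fc(x)

    def main():
        global LOOPS_RAN
        LOOPS_RAN = True  # placeholder for the training and testing loops

    if __name__ == "__main__":
        main()
''')

with tempfile.TemporaryDirectory() as tmp:
    Path(tmp, "brain_train.py").write_text(TRAINING_SCRIPT)
    sys.path.insert(0, tmp)
    importlib.invalidate_caches()  # make sure the new file is found
    try:
        brain_train = importlib.import_module("brain_train")  # like: from brain_train import Net
    finally:
        sys.path.remove(tmp)

model = brain_train.Net()
print("loops ran on import:", brain_train.LOOPS_RAN)  # False
```

Running `python brain_train.py` directly still trains and tests as before, because then `__name__` is `"__main__"`.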
You did not save only the state_dict. Therefore, you have to load it via model.load_state_dict(torch.load('last_brain1.pth')["state_dict"])
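If the checkpoint was saved as a dictionary that wraps the weights (the key "state_dict" comes from the reply; the "epoch" entry is a made-up example of extra data), passing the outer dictionary straight to load_state_dict fails because its keys are not parameter names. Here is a small helper that accepts either format, sketched with a hypothetical TinyNet and an in-memory buffer:

```python
import io
import torch
import torch.nn as nn

# Hypothetical stand-in for Net.
class TinyNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(4, 3)

    def forward(self, x):
        return self.fc(x)

def load_weights(model, checkpoint):
    """Load either a bare state_dict or a checkpoint dict with a 'state_dict' key."""
    if isinstance(checkpoint, dict) and "state_dict" in checkpoint:
        checkpoint = checkpoint["state_dict"]
    model.load_state_dict(checkpoint)
    return model

# A checkpoint saved as a dictionary with extra entries (assumed format).
trained = TinyNet()
buffer = io.BytesIO()  # in-memory file standing in for 'last_brain1.pth'
torch.save({"epoch": 10, "state_dict": trained.state_dict()}, buffer)
buffer.seek(0)

checkpoint = torch.load(buffer)
print(sorted(checkpoint.keys()))  # ['epoch', 'state_dict']
model = load_weights(TinyNet(), checkpoint).eval()
```

Printing checkpoint.keys() on the real file is a quick way to see which format was saved.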
Why not? Can you post the error message? Ideally, copy-paste it as text instead of posting a screenshot.