from fastai.vision import *
from torchvision import datasets, transforms, models
import numpy as np
import torch
from torch import nn
from torch import optim
import torch.nn.functional as F
from torch.autograd import Variable
def conv_block(ni, nf, size=3, stride=1):
    """Build a Conv2d -> BatchNorm2d -> LeakyReLU(0.1) stack.

    Args:
        ni: Number of input channels.
        nf: Number of output channels.
        size: Convolution kernel size.
        stride: Convolution stride.

    Returns:
        An ``nn.Sequential`` holding the three layers in that order.
    """
    # Kernels of size 2 or less are padded as if they were 3x3, so a
    # size=1 block uses padding 1 and enlarges the feature map by 2 pixels
    # per spatial dimension.  The layer sizes further down rely on this.
    padded_as = size if size > 2 else 3
    pad = (padded_as - 1) // 2

    conv = nn.Conv2d(ni, nf, kernel_size=size, stride=stride,
                     padding=pad, bias=False)
    norm = nn.BatchNorm2d(nf)
    activation = nn.LeakyReLU(negative_slope=0.1, inplace=True)
    return nn.Sequential(conv, norm, activation)
def triple_conv(ni, nf):
    """Chain three ``conv_block`` stages: ni -> nf, nf -> ni (size 1), ni -> nf.

    Args:
        ni: Number of input channels.
        nf: Number of output channels.

    Returns:
        An ``nn.Sequential`` of the three blocks.
    """
    widen = conv_block(ni, nf)
    narrow = conv_block(nf, ni, size=1)
    widen_again = conv_block(ni, nf)
    return nn.Sequential(widen, narrow, widen_again)
def maxpooling():
    """Return a max-pooling layer with a 2x2 window and stride 2."""
    return nn.MaxPool2d(kernel_size=2, stride=2)
# Sequential classifier ending in two output logits.
# The position of every layer determines its key prefix in
# model.state_dict() ("0.0.weight", "4.1.0.weight", ...), so the order below
# must stay in step with the checkpoint that is loaded afterwards.
model = nn.Sequential(
    conv_block(3, 8),
    maxpooling(),
    conv_block(8, 16),
    maxpooling(),
    triple_conv(16, 32),
    maxpooling(),
    triple_conv(32, 64),
    maxpooling(),
    triple_conv(64, 128),
    maxpooling(),
    triple_conv(128, 256),
    conv_block(256, 128, size=1),
    conv_block(128, 256),
    # conv_layer and Flatten are not defined in this file; presumably they
    # come from the fastai star import -- verify.
    conv_layer(256, 2),
    Flatten(),
    # 338 == 2 * 13 * 13: presumably a 2-channel 13x13 map for 256x256
    # inputs (every size=1 conv_block pads by 1 and so grows the map by 2).
    # TODO confirm against conv_layer's padding.
    nn.Linear(338, 2)
)
import warnings

# Load the trained weights onto the CPU.
checkpoint = torch.load('trained_modelCov-N.pth', map_location=torch.device('cpu'))
# A file written by fastai's Learner.save() is a dict of the form
# {'model': state_dict, 'opt': ...}.  Filtering that outer dict by layer name
# would match nothing and silently leave the network with random weights, so
# unwrap the real state dict first.
if isinstance(checkpoint, dict) and isinstance(checkpoint.get('model'), dict):
    checkpoint = checkpoint['model']
model_dict = model.state_dict()
# 1. filter out unnecessary keys
pretrained_dict = {k: v for k, v in checkpoint.items() if k in model_dict}
if not pretrained_dict:
    # Nothing matched: continuing would evaluate an untrained model.
    raise RuntimeError(
        "No parameter in 'trained_modelCov-N.pth' matches the model; "
        "checkpoint keys look like: " + ", ".join(map(str, list(checkpoint)[:5]))
    )
missing = sorted(set(model_dict) - set(pretrained_dict))
if missing:
    # Partial loads remain allowed, but are no longer silent.
    warnings.warn(
        "%d model tensors were not found in the checkpoint and keep their "
        "initial values, e.g. %s" % (len(missing), missing[:5])
    )
# 2. overwrite entries in the existing state dict
model_dict.update(pretrained_dict)
# 3. load the new state dict
model.load_state_dict(model_dict)
# Inference mode: BatchNorm uses its running statistics.
model.eval()
# Root folder handed to datasets.ImageFolder further down.
data_dir = '/content/gdrive/My Drive/DCN(data)/Train/'

# Preprocessing applied to every image before it reaches the model:
# resize to 256x256, then convert to a tensor.
# NOTE(review): mean/std normalisation is deliberately not applied here.  If
# the network was trained on normalised inputs, the same statistics have to
# be added to this pipeline -- confirm against the training code.
test_transforms = transforms.Compose(
    [
        transforms.Resize(size=(256, 256)),
        transforms.ToTensor(),
    ]
)
def predict_image(image):
    """Classify one image with the module-level ``model``.

    Args:
        image: An image accepted by ``test_transforms`` (the script passes
            PIL images).

    Returns:
        The index of the highest-scoring output as a Python ``int``.
    """
    # Preprocess and add a leading batch dimension of size 1.
    batch = test_transforms(image).float().unsqueeze(0)
    # Run on whichever device holds the model's weights.  No global
    # ``device`` is defined in this file, and the checkpoint is loaded on CPU.
    batch = batch.to(next(model.parameters()).device)
    # Inference only: skip building the autograd graph.
    with torch.no_grad():
        output = model(batch)
    return int(output.argmax(dim=1).item())
def get_random_images(num):
    """Return one batch of randomly chosen images with their labels.

    Args:
        num: Number of images to draw; also used as the batch size.

    Returns:
        A ``(images, labels)`` pair taken from the first batch of a
        ``DataLoader`` over ``datasets.ImageFolder(data_dir)`` with
        ``test_transforms`` applied.
    """
    from torch.utils.data.sampler import SubsetRandomSampler

    data = datasets.ImageFolder(data_dir, transform=test_transforms)
    indices = list(range(len(data)))
    np.random.shuffle(indices)
    sampler = SubsetRandomSampler(indices[:num])
    loader = torch.utils.data.DataLoader(data, sampler=sampler, batch_size=num)
    # Use the built-in next(): the iterator's .next() method is not
    # available on newer PyTorch releases.
    images, labels = next(iter(loader))
    return images, labels
# Build the dataset once more to obtain the class-name list for the titles.
data = datasets.ImageFolder(data_dir, transform=test_transforms)
classes = data.classes
indices = list(range(len(data)))  # not used below in this file
to_pil = transforms.ToPILImage()

# Predict ten random images and show them side by side in a single row.
images, labels = get_random_images(10)
fig = plt.figure(figsize=(40, 50))
for position, (tensor, label) in enumerate(zip(images, labels), start=1):
    image = to_pil(tensor)
    index = predict_image(image)
    sub = fig.add_subplot(1, len(images), position)
    # Title is "<predicted class>:<True/False>", where the flag says whether
    # the prediction equals the stored label.
    res = int(label) == index
    sub.set_title(f"{classes[index]}:{res}")
    print(index)
    sub.axis('off')
    sub.imshow(image, cmap='binary')
plt.show()
# Problem: the model always predicts class 0 for every image.