import os
import torch
import torch.nn as nn
import torchvision
import torchvision.transforms as transforms
from torchvision import datasets, transforms, models
from captum.insights import AttributionVisualizer, Batch
from captum.insights.features import ImageFeature
# Root folder of the image dataset; the evaluation images live in its "Test" subfolder.
data_dir = "./BW"
test_dir = f"{data_dir}/Test"
def get_classes():
    """Return the display names of the classifier's two output classes.

    NOTE(review): the list order presumably has to match the label indices
    produced by ``ImageFolder`` (alphabetically sorted class-folder names) —
    confirm the folders under the test directory are named "B" and "M".
    """
    return ["B", "M"]
# Rebuild the fine-tuned classifier that Captum Insights will explain.
def get_pretrained_model(checkpoint_path="checkpoint_ResnetBW.pt", num_classes=2):
    """Rebuild the fine-tuned ResNet-18 and load its weights from disk.

    Args:
        checkpoint_path: path to a ``state_dict`` saved for this architecture
            (a ResNet-18 whose final ``fc`` layer has ``num_classes`` outputs).
        num_classes: number of outputs of the replaced ``fc`` layer; must match
            the checkpoint (2 here, one per entry of the class list).

    Returns:
        The loaded model, on the CPU and in evaluation mode.
    """
    # No ImageNet weights are downloaded: load_state_dict is strict by default,
    # so every parameter and buffer is overwritten by the checkpoint anyway.
    model = models.resnet18()
    model.fc = nn.Linear(model.fc.in_features, num_classes)
    # map_location lets a checkpoint saved from a GPU load on a CPU-only machine.
    state_dict = torch.load(checkpoint_path, map_location="cpu")
    model.load_state_dict(state_dict)
    # Batch-norm must use its running statistics, not per-batch ones, while the
    # visualizer computes predictions and attributions.
    model.eval()
    # The original version forgot this return, so callers received None.
    return model
def baseline_func(input):
    """Return *input* multiplied by zero, used as the attribution baseline."""
    zero = 0
    return zero * input
def formatted_data_iter(batch_size=4, num_workers=2):
    """Yield the test images as Captum Insights ``Batch`` objects.

    Images are read from ``test_dir`` with ``ImageFolder`` and converted with
    ``ToTensor()`` only. Labels are the class indices ``ImageFolder`` assigns
    to the sub-folders. The generator ends normally once every image has been
    yielded.

    Args:
        batch_size: number of images per yielded ``Batch``.
        num_workers: ``DataLoader`` worker processes (0 loads in-process).

    NOTE(review): ``ToTensor()`` does no resizing, so with ``batch_size > 1``
    the default collation presumably requires all test images to share one
    size — confirm, or add the same resize/crop used during training.
    """
    dataset = datasets.ImageFolder(test_dir, transform=transforms.ToTensor())
    dataloader = torch.utils.data.DataLoader(
        dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers
    )
    # Iterate instead of calling next() inside `while True`: a StopIteration
    # escaping a generator body is turned into RuntimeError (PEP 479). Letting
    # the loop finish ends the generator with an ordinary StopIteration.
    for images, labels in dataloader:
        yield Batch(inputs=images, labels=labels)
# Build the visualizer and render it (inside a notebook for interactive debugging).
# The __main__ guard matters because formatted_data_iter's DataLoader uses worker
# processes: when those are started with "spawn" (the default on Windows), each
# worker re-imports the main module and, without the guard, would rebuild the
# model and start another server. In Jupyter, __name__ is also "__main__", so
# notebook use is unaffected.
if __name__ == "__main__":
    # Maps ToTensor's [0, 1] pixel range to [-1, 1] before images reach the model.
    # NOTE(review): confirm this matches the normalization used during training.
    normalize = transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
    model = get_pretrained_model()
    visualizer = AttributionVisualizer(
        models=[model],
        # Softmax over dim 1 of the model output (the class dimension of the
        # fc layer's (batch, num_classes) logits).
        score_func=lambda o: torch.nn.functional.softmax(o, 1),
        classes=get_classes(),
        features=[
            ImageFeature(
                "Photo",
                baseline_transforms=[baseline_func],
                input_transforms=[normalize],
            )
        ],
        dataset=formatted_data_iter(),
    )
    # Inside a Jupyter notebook this embeds the Insights UI. Run elsewhere, it
    # only prints the server URL and an IFrame object's repr; open the printed
    # http://localhost:<port>/ address in a browser in that case.
    visualizer.render()
# Running this script prints output like:
#
#     Fetch data and view Captum Insights at http://localhost:64328/
#     <IPython.lib.display.IFrame object at 0x0000022B8CEA3248>
#
# The link is like: [the rest of this description is cut off in the original]