Problem with Captum

Captum runs without any errors, but every image is predicted as the first class. Here is the latest version of my code:


import os
from captum.insights import AttributionVisualizer, Batch
from captum.insights.features import ImageFeature
from torchvision import datasets, transforms, models
import torch
import torch.nn as nn
import torchvision
import torchvision.transforms as transforms
# Preprocessing for the evaluation images: resize to 224x224 and convert the
# image to a float tensor. No Normalize step is applied here; the __main__
# block hands a Normalize to Captum as an input transform instead.
# NOTE(review): whatever normalisation is used at inference time must match
# the one used when the model was trained.
test_transforms = transforms.Compose(
    [
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
    ]
)

def get_classes():
    """Return the display names for the model's two output indices.

    Entry ``i`` names output logit ``i``, so the order has to agree with the
    class indices the training/test ``datasets.ImageFolder`` assigned (it
    orders class sub-folders alphabetically).
    """
    return ["B", "M"]


def get_pretrained_model():
    """Build the 2-class ResNet-18 and load its fine-tuned weights.

    Returns:
        A torchvision ResNet-18 whose final ``fc`` layer is replaced by a
        2-output ``nn.Linear``. Weights come from ``models/ResnetBW.pt``,
        resolved relative to this file, and the model is in eval mode.
    """
    # No ImageNet weights are downloaded: load_state_dict (strict by default)
    # overwrites every parameter from the checkpoint below anyway.
    net = models.resnet18()
    net.fc = nn.Linear(net.fc.in_features, 2)

    # Make __file__ absolute *before* taking dirname. Otherwise a relative
    # __file__ gives dirname "" and the path collapses to "/models/...".
    here = os.path.dirname(os.path.abspath(__file__))
    pt_path = os.path.join(here, "models", "ResnetBW.pt")

    # map_location lets a checkpoint saved from a GPU load on a CPU-only host.
    net.load_state_dict(torch.load(pt_path, map_location="cpu"))

    # Inference must run in eval mode. In train mode BatchNorm normalises with
    # per-batch statistics and updates its running stats on every forward
    # pass, which can skew every prediction toward a single class.
    net.eval()
    return net


def baseline_func(input):
    """Return an all-zero baseline with the same shape/type as ``input``.

    Computed as a multiplication (not ``zeros_like``) so the result keeps the
    input's dtype, device and autograd linkage.
    """
    return 0 * input


def formatted_data_iter(root="data/test", batch_size=4):
    """Yield Captum ``Batch`` objects over an ImageFolder test set.

    Args:
        root: Directory laid out as ``root/<class_name>/<image>``; defaults to
            ``"data/test"``.
        batch_size: Number of images per yielded batch (default 4).

    Yields:
        ``Batch(inputs=images, labels=labels)`` for each DataLoader batch, in
        file order (no shuffling). The generator ends normally once the
        dataset is exhausted.
    """
    dataset = datasets.ImageFolder(root=root, transform=test_transforms)
    dataloader = torch.utils.data.DataLoader(
        dataset, batch_size=batch_size, shuffle=False, num_workers=2
    )
    # A plain for-loop, not `while True: next(it)`. Under PEP 479 a
    # StopIteration escaping a generator body is re-raised as RuntimeError,
    # so the old loop crashed once the last batch had been served.
    for images, labels in dataloader:
        yield Batch(inputs=images, labels=labels)


if __name__ == "__main__":
    # ImageNet mean/std, applied by Captum to each input before the forward
    # pass. NOTE(review): this must match the normalisation used when
    # ResnetBW.pt was trained. A mismatch (e.g. training without Normalize, or
    # with different statistics) can make every image land in one class.
    normalize = transforms.Normalize(
        mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
    )

    model = get_pretrained_model()
    # Run inference in eval mode. In train mode BatchNorm uses per-batch
    # statistics and mutates its running stats on every forward pass.
    # eval() is idempotent, so calling it here is always safe.
    model.eval()

    visualizer = AttributionVisualizer(
        models=[model],
        score_func=lambda o: torch.nn.functional.softmax(o, 1),
        # NOTE(review): the label order must match the ImageFolder class
        # indices (sub-folder names sorted alphabetically), or the displayed
        # names will be attached to the wrong logits.
        classes=get_classes(),
        features=[
            ImageFeature(
                "Photo",
                baseline_transforms=[baseline_func],
                input_transforms=[normalize],
            )
        ],
        dataset=formatted_data_iter(),
    )

    visualizer.render(debug=True)

Do you get different results if you don't use Captum?
Could it be that your model simply overfits to class 0?

No, there is no problem with the model. I am testing Captum with images that this model classifies correctly.

If I use the class names 'N' and 'M' instead of 'B' and 'M', the model predicts every image as class 'M'.