Grad-CAM visualisation on a modified EfficientNet_b0

Hello, I am currently working on my thesis with medical images, and I want to add Grad-CAM visualisations of my model's outputs. I am using a pretrained EfficientNet_b0 from the timm library with 'features_only=True':

import timm
import torch
from torch import nn


class EfficientNet(nn.Module):
    def __init__(self):
        super().__init__()
        # Base model: pretrained EfficientNet-B0 without its classifier head.
        # With features_only=True, timm returns a list of feature maps (one per stage).
        self.feature_extractor = timm.create_model('efficientnet_b0', pretrained=True, features_only=True)
        # Number of channels in the last feature map = input size of both classifiers.
        in_features = self.feature_extractor.feature_info.channels()[-1]

        # Classifier for the diagnosis labels.
        self.classification_label = nn.Sequential(
            nn.Linear(in_features, 1280),
            nn.ReLU(True),
            nn.Dropout(),
            nn.Linear(1280, 1280),
            nn.ReLU(True),
            nn.Dropout(),
            nn.Linear(1280, 4),  # the labels are 4
        )

        # Classifier for the bad-quality reasons.
        self.classification_reason = nn.Sequential(
            nn.Linear(in_features, 1280),
            nn.ReLU(True),
            nn.Dropout(),
            nn.Linear(1280, 1280),
            nn.ReLU(True),
            nn.Dropout(),
            nn.Linear(1280, 3),  # the reasons are 3
        )

        # Global average pooling + flatten: (B, C, h, w) -> (B, C).
        self.flat_gap = nn.Sequential(
            nn.AdaptiveAvgPool2d(1),
            nn.Flatten(),
        )

    def forward(self, x):
        # Pool the last feature map and feed the same vector to both heads.
        features = self.flat_gap(self.feature_extractor(x)[-1])
        label = self.classification_label(features)
        reason = self.classification_reason(features)
        return label, reason

My model takes the last feature map of the backbone, applies global average pooling, and passes the pooled features to two classifiers: one for the diagnosis (label) of each image, and one for the reason the image is of bad quality (bad_light, blurry, low_resolution).

After training, I want to load the saved weights and show a Grad-CAM visualisation for the second classifier (the bad-quality reasons), with a separate heatmap for each reason. Can someone help me with the implementation?

Any help will be appreciated. Thank you in advance :slight_smile: