Normal inference works fine, but Captum (IntegratedGradients) throws a size mismatch

I am trying to use Captum for the first time and would like to use integrated gradients to analyze my model.

Normal inference works fine, but trying to use integrated gradients with Captum leads to an error:

RuntimeError: size mismatch, m1: [1 x 9800], m2: [196 x 512] at /opt/conda/conda-bld/pytorch_1579022027550/work/aten/src/TH/generic/THTensorMath.cpp:136

I have no idea why the size is 1 x 9800.
The tensor I use has a single dimension with 196 elements, and the output of the first model layer has 512 elements.
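The 9800 can be reproduced without Captum. This is a sketch, assuming Captum's default of `n_steps=50` and that it concatenates the interpolated inputs along the first (batch) dimension so all integration steps go through a single forward pass. With a 1-D input, that first dimension is the feature dimension itself, so 50 × 196 = 9800 features end up in one flat vector:

```python
import torch

n_steps = 50                        # assumed: Captum's default number of integration steps
sample = torch.rand(196)            # 1-D input, as in the question
baseline = torch.zeros_like(sample)

# Evenly spaced alphas for simplicity; only the number of steps matters for the shape.
alphas = torch.linspace(0, 1, n_steps)

# Stack the interpolated inputs along dim 0, the dimension treated as the batch.
scaled = torch.cat([baseline + a * (sample - baseline) for a in alphas], dim=0)
print(scaled.shape)    # torch.Size([9800]); nn.Linear sees this as a single 1 x 9800 row

# With a leading batch dimension, each step becomes its own row instead.
batched = sample.unsqueeze(0)       # shape [1, 196]
batched_base = torch.zeros_like(batched)
scaled_b = torch.cat([batched_base + a * (batched - batched_base) for a in alphas], dim=0)
print(scaled_b.shape)  # torch.Size([50, 196]); matches nn.Linear(196, 512)
```

Captum's default integration method places the steps differently from `linspace`, but that only changes the values, not the shapes.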

This is the code I tried:

```python
# Initial imports
from captum.attr import IntegratedGradients
from captum.attr import LayerConductance
from captum.attr import NeuronConductance

import torch
import torch.nn as nn

# Define model
class MLP(nn.Module):
    def __init__(self):
        super().__init__()  # required before registering submodules
        self.model = nn.Sequential(
            # Input layer
            nn.Linear(196, 512),
            # ReLU activation
            nn.ReLU(),
            # Hidden layer
            nn.Linear(512, 512),
            # ReLU activation
            nn.ReLU(),
            # Output layer
            nn.Linear(512, 12),
        )

    def forward(self, x):
        # Forward pass
        return self.model(x)

# Prepare data and model
sample = torch.rand(196)
label = 7
model = MLP()

# Normal inference (works fine):
score = model(sample)
prob = nn.functional.softmax(score, dim=0)
y_pred = prob.argmax()
print("Predicted class {} with probability {}. True label is: {}".format(y_pred, prob[y_pred], label))

# Usage of captum (does not work)
ig = IntegratedGradients(model)
attr, delta = ig.attribute(sample, target=label, return_convergence_delta=True)  # ERROR!
attr = attr.detach().numpy()
```

Does anyone have an idea where the wrong dimension comes from and how to fix it?


I edited the first post to provide a full minimal working example.

Any help would be appreciated 🙂

I think you need to add a leading batch dimension to the input. Additionally, the internal_batch_size argument may be needed to limit CUDA memory requirements.


Thank you very much!
I added the line




and now I am not getting the error anymore 🙂