Expected tensor for argument #1 'indices' to have one of the following scalar types: Long, Int; but got torch.FloatTensor instead (while checking arguments for embedding)

I have a simple forward method with two arguments each of which is a tensor of indices for embeddings:

def forward(self, arg1, arg2):
    embed1 = self.embed1(arg1)
    embed2 = self.embed2(arg2)
    out = self.input_layer(embed1 + embed2).squeeze()
    out = self.output_layer(out)
    return out.squeeze().detach().numpy()

arg1 = torch.tensor([62, 287, 22, 54, 7, 32, 80, 56, 14, 475])
arg2 = torch.tensor([1, 0, 0, 0, 0, 0, 1, 1, 1, 1])
ig = IntegratedGradients(model)
ig.attribute((arg1, arg2))

I am getting the following error when my code tries to run the ig.attribute function:

Expected tensor for argument #1 'indices' to have one of the following scalar types: Long, Int; but got torch.FloatTensor instead (while checking arguments for embedding)

Anyone have an idea what might be going on?

EDIT:

I added a print statement inside the forward function to print arg1, and I get a 500-element tensor (which I assume is somehow the 10 indices, each with its 50-dimensional embedding):

       tensor([0.0000e+00, 5.6680e-04, 5.6680e-04, 0.0000e+00, 0.0000e+00, 0.0000e+00,
        5.6680e-04, 5.6680e-04, 0.0000e+00, 5.6680e-04, 0.0000e+00, 2.9840e-03,
        2.9840e-03, 0.0000e+00, 0.0000e+00, 0.0000e+00, 2.9840e-03, 2.9840e-03,
        0.0000e+00, 2.9840e-03, 0.0000e+00, 7.3230e-03, 7.3230e-03, 0.0000e+00,
        0.0000e+00, 0.0000e+00, 7.3230e-03, 7.3230e-03, 0.0000e+00, 7.3230e-03,
        0.0000e+00, 1.3568e-02, 1.3568e-02, 0.0000e+00, 0.0000e+00, 0.0000e+00,
        1.3568e-02, 1.3568e-02, 0.0000e+00, 1.3568e-02, 0.0000e+00, 2.1695e-02,
        2.1695e-02, 0.0000e+00, 0.0000e+00, 0.0000e+00, 2.1695e-02, 2.1695e-02,
        0.0000e+00, 2.1695e-02, 0.0000e+00, 3.1672e-02, 3.1672e-02, 0.0000e+00,
        0.0000e+00, 0.0000e+00, 3.1672e-02, 3.1672e-02, 0.0000e+00, 3.1672e-02,
        0.0000e+00, 4.3461e-02, 4.3461e-02, 0.0000e+00, 0.0000e+00, 0.0000e+00,
        4.3461e-02, 4.3461e-02, 0.0000e+00, 4.3461e-02, 0.0000e+00, 5.7016e-02,
        5.7016e-02, 0.0000e+00, 0.0000e+00, 0.0000e+00, 5.7016e-02, 5.7016e-02,
        0.0000e+00, 5.7016e-02, 0.0000e+00, 7.2285e-02, 7.2285e-02, 0.0000e+00,
        0.0000e+00, 0.0000e+00, 7.2285e-02, 7.2285e-02, 0.0000e+00, 7.2285e-02,
        0.0000e+00, 8.9209e-02, 8.9209e-02, 0.0000e+00, 0.0000e+00, 0.0000e+00,
        8.9209e-02, 8.9209e-02, 0.0000e+00, 8.9209e-02, 0.0000e+00, 1.0772e-01,
        1.0772e-01, 0.0000e+00, 0.0000e+00, 0.0000e+00, 1.0772e-01, 1.0772e-01,
        0.0000e+00, 1.0772e-01, 0.0000e+00, 1.2775e-01, 1.2775e-01, 0.0000e+00,
        0.0000e+00, 0.0000e+00, 1.2775e-01, 1.2775e-01, 0.0000e+00, 1.2775e-01,
        0.0000e+00, 1.4922e-01, 1.4922e-01, 0.0000e+00, 0.0000e+00, 0.0000e+00,
        1.4922e-01, 1.4922e-01, 0.0000e+00, 1.4922e-01, 0.0000e+00, 1.7205e-01,
        1.7205e-01, 0.0000e+00, 0.0000e+00, 0.0000e+00, 1.7205e-01, 1.7205e-01,
        0.0000e+00, 1.7205e-01, 0.0000e+00, 1.9615e-01, 1.9615e-01, 0.0000e+00,
        0.0000e+00, 0.0000e+00, 1.9615e-01, 1.9615e-01, 0.0000e+00, 1.9615e-01,
        0.0000e+00, 2.2142e-01, 2.2142e-01, 0.0000e+00, 0.0000e+00, 0.0000e+00,
        2.2142e-01, 2.2142e-01, 0.0000e+00, 2.2142e-01, 0.0000e+00, 2.4777e-01,
        2.4777e-01, 0.0000e+00, 0.0000e+00, 0.0000e+00, 2.4777e-01, 2.4777e-01,
        0.0000e+00, 2.4777e-01, 0.0000e+00, 2.7510e-01, 2.7510e-01, 0.0000e+00,
        0.0000e+00, 0.0000e+00, 2.7510e-01, 2.7510e-01, 0.0000e+00, 2.7510e-01,
        0.0000e+00, 3.0329e-01, 3.0329e-01, 0.0000e+00, 0.0000e+00, 0.0000e+00,
        3.0329e-01, 3.0329e-01, 0.0000e+00, 3.0329e-01, 0.0000e+00, 3.3225e-01,
        3.3225e-01, 0.0000e+00, 0.0000e+00, 0.0000e+00, 3.3225e-01, 3.3225e-01,
        0.0000e+00, 3.3225e-01, 0.0000e+00, 3.6186e-01, 3.6186e-01, 0.0000e+00,
        0.0000e+00, 0.0000e+00, 3.6186e-01, 3.6186e-01, 0.0000e+00, 3.6186e-01,
        0.0000e+00, 3.9200e-01, 3.9200e-01, 0.0000e+00, 0.0000e+00, 0.0000e+00,
        3.9200e-01, 3.9200e-01, 0.0000e+00, 3.9200e-01, 0.0000e+00, 4.2255e-01,
        4.2255e-01, 0.0000e+00, 0.0000e+00, 0.0000e+00, 4.2255e-01, 4.2255e-01,
        0.0000e+00, 4.2255e-01, 0.0000e+00, 4.5341e-01, 4.5341e-01, 0.0000e+00,
        0.0000e+00, 0.0000e+00, 4.5341e-01, 4.5341e-01, 0.0000e+00, 4.5341e-01,
        0.0000e+00, 4.8445e-01, 4.8445e-01, 0.0000e+00, 0.0000e+00, 0.0000e+00,
        4.8445e-01, 4.8445e-01, 0.0000e+00, 4.8445e-01, 0.0000e+00, 5.1555e-01,
        5.1555e-01, 0.0000e+00, 0.0000e+00, 0.0000e+00, 5.1555e-01, 5.1555e-01,
        0.0000e+00, 5.1555e-01, 0.0000e+00, 5.4659e-01, 5.4659e-01, 0.0000e+00,
        0.0000e+00, 0.0000e+00, 5.4659e-01, 5.4659e-01, 0.0000e+00, 5.4659e-01,
        0.0000e+00, 5.7745e-01, 5.7745e-01, 0.0000e+00, 0.0000e+00, 0.0000e+00,
        5.7745e-01, 5.7745e-01, 0.0000e+00, 5.7745e-01, 0.0000e+00, 6.0800e-01,
        6.0800e-01, 0.0000e+00, 0.0000e+00, 0.0000e+00, 6.0800e-01, 6.0800e-01,
        0.0000e+00, 6.0800e-01, 0.0000e+00, 6.3814e-01, 6.3814e-01, 0.0000e+00,
        0.0000e+00, 0.0000e+00, 6.3814e-01, 6.3814e-01, 0.0000e+00, 6.3814e-01,
        0.0000e+00, 6.6775e-01, 6.6775e-01, 0.0000e+00, 0.0000e+00, 0.0000e+00,
        6.6775e-01, 6.6775e-01, 0.0000e+00, 6.6775e-01, 0.0000e+00, 6.9671e-01,
        6.9671e-01, 0.0000e+00, 0.0000e+00, 0.0000e+00, 6.9671e-01, 6.9671e-01,
        0.0000e+00, 6.9671e-01, 0.0000e+00, 7.2490e-01, 7.2490e-01, 0.0000e+00,
        0.0000e+00, 0.0000e+00, 7.2490e-01, 7.2490e-01, 0.0000e+00, 7.2490e-01,
        0.0000e+00, 7.5223e-01, 7.5223e-01, 0.0000e+00, 0.0000e+00, 0.0000e+00,
        7.5223e-01, 7.5223e-01, 0.0000e+00, 7.5223e-01, 0.0000e+00, 7.7858e-01,
        7.7858e-01, 0.0000e+00, 0.0000e+00, 0.0000e+00, 7.7858e-01, 7.7858e-01,
        0.0000e+00, 7.7858e-01, 0.0000e+00, 8.0385e-01, 8.0385e-01, 0.0000e+00,
        0.0000e+00, 0.0000e+00, 8.0385e-01, 8.0385e-01, 0.0000e+00, 8.0385e-01,
        0.0000e+00, 8.2795e-01, 8.2795e-01, 0.0000e+00, 0.0000e+00, 0.0000e+00,
        8.2795e-01, 8.2795e-01, 0.0000e+00, 8.2795e-01, 0.0000e+00, 8.5078e-01,
        8.5078e-01, 0.0000e+00, 0.0000e+00, 0.0000e+00, 8.5078e-01, 8.5078e-01,
        0.0000e+00, 8.5078e-01, 0.0000e+00, 8.7225e-01, 8.7225e-01, 0.0000e+00,
        0.0000e+00, 0.0000e+00, 8.7225e-01, 8.7225e-01, 0.0000e+00, 8.7225e-01,
        0.0000e+00, 8.9228e-01, 8.9228e-01, 0.0000e+00, 0.0000e+00, 0.0000e+00,
        8.9228e-01, 8.9228e-01, 0.0000e+00, 8.9228e-01, 0.0000e+00, 9.1079e-01,
        9.1079e-01, 0.0000e+00, 0.0000e+00, 0.0000e+00, 9.1079e-01, 9.1079e-01,
        0.0000e+00, 9.1079e-01, 0.0000e+00, 9.2771e-01, 9.2771e-01, 0.0000e+00,
        0.0000e+00, 0.0000e+00, 9.2771e-01, 9.2771e-01, 0.0000e+00, 9.2771e-01,
        0.0000e+00, 9.4298e-01, 9.4298e-01, 0.0000e+00, 0.0000e+00, 0.0000e+00,
        9.4298e-01, 9.4298e-01, 0.0000e+00, 9.4298e-01, 0.0000e+00, 9.5654e-01,
        9.5654e-01, 0.0000e+00, 0.0000e+00, 0.0000e+00, 9.5654e-01, 9.5654e-01,
        0.0000e+00, 9.5654e-01, 0.0000e+00, 9.6833e-01, 9.6833e-01, 0.0000e+00,
        0.0000e+00, 0.0000e+00, 9.6833e-01, 9.6833e-01, 0.0000e+00, 9.6833e-01,
        0.0000e+00, 9.7831e-01, 9.7831e-01, 0.0000e+00, 0.0000e+00, 0.0000e+00,
        9.7831e-01, 9.7831e-01, 0.0000e+00, 9.7831e-01, 0.0000e+00, 9.8643e-01,
        9.8643e-01, 0.0000e+00, 0.0000e+00, 0.0000e+00, 9.8643e-01, 9.8643e-01,
        0.0000e+00, 9.8643e-01, 0.0000e+00, 9.9268e-01, 9.9268e-01, 0.0000e+00,
        0.0000e+00, 0.0000e+00, 9.9268e-01, 9.9268e-01, 0.0000e+00, 9.9268e-01,
        0.0000e+00, 9.9702e-01, 9.9702e-01, 0.0000e+00, 0.0000e+00, 0.0000e+00,
        9.9702e-01, 9.9702e-01, 0.0000e+00, 9.9702e-01, 0.0000e+00, 9.9943e-01,
        9.9943e-01, 0.0000e+00, 0.0000e+00, 0.0000e+00, 9.9943e-01, 9.9943e-01,
        0.0000e+00, 9.9943e-01], requires_grad=True) torch.Size([500])

I guess I don’t really understand what Captum is doing here. How should I write my forward method so that attribution makes sense?

I really don't understand what is happening, even with the simple toy model from the documentation:

import torch
import torch.nn as nn
from captum.attr import IntegratedGradients


class ToyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.lin1 = nn.Linear(3, 3)
        self.relu = nn.ReLU()
        self.lin2 = nn.Linear(3, 2)

        # initialize weights and biases
        self.lin1.weight = nn.Parameter(torch.arange(-4.0, 5.0).view(3, 3))
        self.lin1.bias = nn.Parameter(torch.zeros(1,3))
        self.lin2.weight = nn.Parameter(torch.arange(-3.0, 3.0).view(2, 3))
        self.lin2.bias = nn.Parameter(torch.ones(1,2))

    def forward(self, input):
        print(input, input.shape)
        return self.lin2(self.relu(self.lin1(input)))

if __name__=="__main__":
    model = ToyModel()
    model.eval()

    input = torch.rand(2, 3)
    baseline = torch.zeros(2, 3)   

    ig = IntegratedGradients(model)
    print(input)
    attributions, delta = ig.attribute(input, baseline, target=0, return_convergence_delta=True)
    print('IG Attributions:', attributions)
    print('Convergence Delta:', delta)

I added a print statement inside the forward method; here is the output of running it:

tensor([[0.2980, 0.8050, 0.3707],
        [0.9142, 0.3521, 0.9029]])
tensor([[1.6890e-04, 4.5625e-04, 2.1012e-04],
        [5.1817e-04, 1.9954e-04, 5.1177e-04],
        [8.8923e-04, 2.4020e-03, 1.1062e-03],
        [2.7280e-03, 1.0505e-03, 2.6943e-03],
        [2.1822e-03, 5.8947e-03, 2.7147e-03],
        [6.6947e-03, 2.5781e-03, 6.6120e-03],
        [4.0432e-03, 1.0921e-02, 5.0298e-03],
        [1.2404e-02, 4.7766e-03, 1.2250e-02],
        [6.4649e-03, 1.7463e-02, 8.0424e-03],
        [1.9833e-02, 7.6376e-03, 1.9588e-02],
        [9.4381e-03, 2.5494e-02, 1.1741e-02],
        [2.8954e-02, 1.1150e-02, 2.8597e-02],
        [1.2951e-02, 3.4984e-02, 1.6111e-02],
        [3.9732e-02, 1.5301e-02, 3.9241e-02],
        [1.6991e-02, 4.5895e-02, 2.1137e-02],
        [5.2124e-02, 2.0073e-02, 5.1480e-02],
        [2.1541e-02, 5.8186e-02, 2.6797e-02],
        [6.6083e-02, 2.5448e-02, 6.5267e-02],
        [2.6584e-02, 7.1809e-02, 3.3071e-02],
        [8.1555e-02, 3.1406e-02, 8.0547e-02],
        [3.2101e-02, 8.6711e-02, 3.9934e-02],
        [9.8480e-02, 3.7924e-02, 9.7263e-02],
        [3.8070e-02, 1.0284e-01, 4.7360e-02],
        [1.1679e-01, 4.4976e-02, 1.1535e-01],
        [4.4468e-02, 1.2012e-01, 5.5319e-02],
        [1.3642e-01, 5.2535e-02, 1.3474e-01],
        [5.1271e-02, 1.3849e-01, 6.3782e-02],
        [1.5729e-01, 6.0572e-02, 1.5535e-01],
        [5.8452e-02, 1.5789e-01, 7.2715e-02],
        [1.7932e-01, 6.9055e-02, 1.7710e-01],
        [6.5983e-02, 1.7823e-01, 8.2083e-02],
        [2.0242e-01, 7.7952e-02, 1.9992e-01],
        [7.3835e-02, 1.9944e-01, 9.1852e-02],
        [2.2651e-01, 8.7229e-02, 2.2371e-01],
        [8.1978e-02, 2.2144e-01, 1.0198e-01],
        [2.5149e-01, 9.6849e-02, 2.4839e-01],
        [9.0381e-02, 2.4414e-01, 1.1243e-01],
        [2.7727e-01, 1.0678e-01, 2.7385e-01],
        [9.9010e-02, 2.6745e-01, 1.2317e-01],
        [3.0374e-01, 1.1697e-01, 2.9999e-01],
        [1.0783e-01, 2.9128e-01, 1.3414e-01],
        [3.3081e-01, 1.2739e-01, 3.2672e-01],
        [1.1681e-01, 3.1554e-01, 1.4532e-01],
        [3.5836e-01, 1.3800e-01, 3.5394e-01],
        [1.2592e-01, 3.4014e-01, 1.5665e-01],
        [3.8630e-01, 1.4876e-01, 3.8153e-01],
        [1.3512e-01, 3.6498e-01, 1.6809e-01],
        [4.1451e-01, 1.5963e-01, 4.0939e-01],
        [1.4437e-01, 3.8996e-01, 1.7959e-01],
        [4.4289e-01, 1.7055e-01, 4.3741e-01],
        [1.5363e-01, 4.1499e-01, 1.9112e-01],
        [4.7132e-01, 1.8150e-01, 4.6549e-01],
        [1.6288e-01, 4.3998e-01, 2.0263e-01],
        [4.9969e-01, 1.9243e-01, 4.9352e-01],
        [1.7208e-01, 4.6482e-01, 2.1407e-01],
        [5.2790e-01, 2.0329e-01, 5.2138e-01],
        [1.8118e-01, 4.8942e-01, 2.2539e-01],
        [5.5584e-01, 2.1405e-01, 5.4897e-01],
        [1.9017e-01, 5.1368e-01, 2.3657e-01],
        [5.8339e-01, 2.2466e-01, 5.7619e-01],
        [1.9899e-01, 5.3751e-01, 2.4754e-01],
        [6.1046e-01, 2.3508e-01, 6.0292e-01],
        [2.0762e-01, 5.6082e-01, 2.5828e-01],
        [6.3693e-01, 2.4528e-01, 6.2906e-01],
        [2.1602e-01, 5.8351e-01, 2.6873e-01],
        [6.6271e-01, 2.5521e-01, 6.5452e-01],
        [2.2416e-01, 6.0551e-01, 2.7886e-01],
        [6.8769e-01, 2.6483e-01, 6.7919e-01],
        [2.3202e-01, 6.2672e-01, 2.8863e-01],
        [7.1178e-01, 2.7410e-01, 7.0299e-01],
        [2.3955e-01, 6.4706e-01, 2.9800e-01],
        [7.3488e-01, 2.8300e-01, 7.2580e-01],
        [2.4673e-01, 6.6646e-01, 3.0693e-01],
        [7.5691e-01, 2.9148e-01, 7.4756e-01],
        [2.5353e-01, 6.8484e-01, 3.1539e-01],
        [7.7778e-01, 2.9952e-01, 7.6817e-01],
        [2.5993e-01, 7.0212e-01, 3.2335e-01],
        [7.9741e-01, 3.0708e-01, 7.8756e-01],
        [2.6590e-01, 7.1824e-01, 3.3078e-01],
        [8.1572e-01, 3.1413e-01, 8.0564e-01],
        [2.7141e-01, 7.3315e-01, 3.3764e-01],
        [8.3265e-01, 3.2065e-01, 8.2236e-01],
        [2.7646e-01, 7.4677e-01, 3.4392e-01],
        [8.4812e-01, 3.2661e-01, 8.3764e-01],
        [2.8101e-01, 7.5906e-01, 3.4958e-01],
        [8.6208e-01, 3.3198e-01, 8.5143e-01],
        [2.8505e-01, 7.6997e-01, 3.5460e-01],
        [8.7447e-01, 3.3675e-01, 8.6367e-01],
        [2.8856e-01, 7.7946e-01, 3.5897e-01],
        [8.8525e-01, 3.4090e-01, 8.7431e-01],
        [2.9153e-01, 7.8749e-01, 3.6267e-01],
        [8.9437e-01, 3.4442e-01, 8.8332e-01],
        [2.9396e-01, 7.9403e-01, 3.6568e-01],
        [9.0180e-01, 3.4728e-01, 8.9066e-01],
        [2.9582e-01, 7.9906e-01, 3.6800e-01],
        [9.0751e-01, 3.4948e-01, 8.9630e-01],
        [2.9711e-01, 8.0255e-01, 3.6961e-01],
        [9.1147e-01, 3.5100e-01, 9.0021e-01],
        [2.9783e-01, 8.0450e-01, 3.7050e-01],
        [9.1368e-01, 3.5185e-01, 9.0240e-01]], requires_grad=True) torch.Size([100, 3])
tensor([[0., 0., 0.],
        [0., 0., 0.]]) torch.Size([2, 3])
tensor([[0.2980, 0.8050, 0.3707],
        [0.9142, 0.3521, 0.9029]]) torch.Size([2, 3])
IG Attributions: tensor([[ 0.0000, -2.4149, -2.2243],
        [-1.8284, -1.0562, -3.6116]], dtype=torch.float64)
Convergence Delta: tensor([-3.5763e-07,  5.9605e-08], dtype=torch.float64)

Where the heck is this 100x3 tensor coming from?

I have encountered the same problem.

The 100x3 tensor comes from the n_steps argument of IntegratedGradients.attribute.
It defaults to 50, so with your 2x3 input you get 50 interpolation steps × a batch of 2 = a 100x3 tensor.
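Roughly speaking (a hedged sketch, not Captum's actual internals; the exact step positions depend on the chosen integration method), attribute evaluates your forward on inputs scaled along the straight line from the baseline to the input, stacked into a single batch:

import torch

def scale_inputs(inputs, baseline, n_steps=50):
    # Illustrative helper: one batch block per interpolation step.
    alphas = torch.linspace(1.0 / n_steps, 1.0, n_steps)
    scaled = [baseline + a * (inputs - baseline) for a in alphas]
    return torch.cat(scaled, dim=0)  # shape: (n_steps * batch, features)

inputs = torch.rand(2, 3)
baseline = torch.zeros(2, 3)
print(scale_inputs(inputs, baseline).shape)  # torch.Size([100, 3])

This is also why the print inside forward shows interpolated fractions of your original inputs rather than the inputs themselves.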

I encountered the same error; using additional_forward_args worked for me:

ig = IntegratedGradients(model)
ig.attribute(arg1, additional_forward_args=arg2)
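Note that this only protects arg2: additional_forward_args are passed to forward unchanged, while the first positional input is still interpolated into floats. If arg1 is itself a tensor of embedding indices, a common pattern (a hedged sketch, assuming the module exposes the embedding as model.embed1 and that forward returns a tensor rather than a detached NumPy array) is to attribute at the embedding layer with LayerIntegratedGradients, so the integer indices are never interpolated:

import torch
from captum.attr import LayerIntegratedGradients

arg1 = torch.tensor([62, 287, 22, 54, 7, 32, 80, 56, 14, 475])
arg2 = torch.tensor([1, 0, 0, 0, 0, 0, 1, 1, 1, 1])

# Attributions are computed with respect to the output of model.embed1;
# both index tensors reach forward() with their integer dtype intact.
lig = LayerIntegratedGradients(model, model.embed1)
attributions = lig.attribute(arg1, additional_forward_args=arg2)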