I have a simple forward method with two arguments, each of which is a tensor of indices into an embedding layer:
    def forward(self, arg1, arg2):
        # look up embeddings for both index tensors and sum them
        embed1 = self.embed1(arg1)
        embed2 = self.embed2(arg2)
        out = self.input_layer(embed1 + embed2).squeeze()
        out = self.output_layer(out)
        return out.squeeze().detach().numpy()
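
For context, the rest of the module looks roughly like this. The vocabulary sizes and the 50-dimensional embeddings are my placeholders, not the exact values, and I repeat the forward from above so the snippet runs on its own:

    import torch
    import torch.nn as nn
    from captum.attr import IntegratedGradients

    class MyModel(nn.Module):
        def __init__(self, vocab1_size=1000, vocab2_size=2, embed_dim=50):
            super().__init__()
            # two embedding tables, indexed by arg1 and arg2 respectively
            self.embed1 = nn.Embedding(vocab1_size, embed_dim)
            self.embed2 = nn.Embedding(vocab2_size, embed_dim)
            # dense layers applied to the summed embeddings
            self.input_layer = nn.Linear(embed_dim, embed_dim)
            self.output_layer = nn.Linear(embed_dim, 1)

        def forward(self, arg1, arg2):
            embed1 = self.embed1(arg1)
            embed2 = self.embed2(arg2)
            out = self.input_layer(embed1 + embed2).squeeze()
            out = self.output_layer(out)
            return out.squeeze().detach().numpy()

    model = MyModel()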
The input tensors and the attribution call look like this:

    arg1 = torch.tensor([62, 287, 22, 54, 7, 32, 80, 56, 14, 475])
    arg2 = torch.tensor([1, 0, 0, 0, 0, 0, 1, 1, 1, 1])

    ig = IntegratedGradients(model)
    ig.attribute((arg1, arg2))
I am getting the following error when the code reaches the ig.attribute call:

    Expected tensor for argument #1 'indices' to have one of the following scalar types: Long, Int; but got torch.FloatTensor instead (while checking arguments for embedding)

Anyone have an idea what might be going on?
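
For what it's worth, a plain forward pass with these same tensors runs fine, so the index tensors themselves seem OK; it is only the attribute call that fails. A minimal check, assuming the model sketch above:

    # plain forward pass: the Long index tensors go straight into the
    # embedding lookups, so no dtype error here
    out = model(arg1, arg2)
    print(out.shape)  # (10,) with the sketch above: one output per index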
EDIT:

I added a print statement inside the forward function to print arg1, and I get a 500-element float tensor (which I assume is somehow the 10 indices, each with its 50-dimensional embedding):
    tensor([0.0000e+00, 5.6680e-04, 5.6680e-04, 0.0000e+00, 0.0000e+00, 0.0000e+00,
            5.6680e-04, 5.6680e-04, 0.0000e+00, 5.6680e-04, 0.0000e+00, 2.9840e-03,
            2.9840e-03, 0.0000e+00, 0.0000e+00, 0.0000e+00, 2.9840e-03, 2.9840e-03,
            ...
            0.0000e+00, 9.9702e-01, 9.9702e-01, 0.0000e+00, 0.0000e+00, 0.0000e+00,
            9.9702e-01, 9.9702e-01, 0.0000e+00, 9.9702e-01, 0.0000e+00, 9.9943e-01,
            9.9943e-01, 0.0000e+00, 0.0000e+00, 0.0000e+00, 9.9943e-01, 9.9943e-01,
            0.0000e+00, 9.9943e-01], requires_grad=True) torch.Size([500])

(middle rows elided: the nonzero entries repeat in the same pattern of ten and ramp smoothly from about 5.7e-4 up to about 0.9994)
I guess I don't really understand what Captum is doing here. How should I write my forward method so that it makes sense for attribution?
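
For what it's worth, the direction I have been experimenting with, based on the Captum docs, is LayerIntegratedGradients, which does its interpolation over a layer's output (the continuous embeddings) rather than over the raw integer indices. I am not sure this is the right approach, and the choice of embed1 as the layer is a guess on my part:

    from captum.attr import LayerIntegratedGradients

    # attribute with respect to the output of the embed1 layer; Captum then
    # interpolates in embedding space, where float values are legal
    lig = LayerIntegratedGradients(model, model.embed1)

    # note: for any gradient-based attribution, forward would presumably have
    # to return the tensor itself rather than .detach().numpy()
    attributions = lig.attribute((arg1, arg2))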