Error calculating integrated gradients with captum

Hi there, I’m trying to perform model interpretability with captum but running into an error. Specifically, it says:

I’m not certain how to resolve this. Here’s the definition of my model, for reference:

class dvib(nn.Module):
    """CNN + bidirectional GRU encoder feeding a variational bottleneck classifier.

    The PSSM sequence is summarised by a convolution followed by a 2-layer
    bidirectional GRU; that summary is concatenated with the FEGS feature
    vector, mapped to the mean / std of a ``k``-dimensional Gaussian latent,
    and a sample from it is decoded into two sigmoid outputs.

    Args:
        k: Size of the latent (bottleneck) vector.
        out_channels: Number of convolution filters, which is also the GRU
            input size.
        hidden_size: Hidden size of each GRU direction/layer.
    """

    def __init__(self, k, out_channels, hidden_size):
        super().__init__()
        # A (1, 20) kernel spans the whole last axis of the input, so each
        # sequence position is reduced to `out_channels` values.
        self.conv = torch.nn.Conv2d(
            in_channels=1,
            out_channels=out_channels,
            kernel_size=(1, 20),
        )
        self.rnn = torch.nn.GRU(
            input_size=out_channels,
            hidden_size=hidden_size,
            num_layers=2,
            bidirectional=True,
            batch_first=True,
            dropout=0.2,
        )
        # NOTE(review): fc1 is not used by forward(); it is kept so that
        # existing checkpoints (state_dict keys) still load.
        self.fc1 = nn.Linear(hidden_size * 4, hidden_size * 4)
        # 4 * hidden_size GRU summary features + 578 FEGS features.
        self.enc_mean = nn.Linear(hidden_size * 4 + 578, k)
        self.enc_std = nn.Linear(hidden_size * 4 + 578, k)
        self.dec = nn.Linear(k, 2)
        for layer in (self.fc1, self.enc_mean, self.enc_std, self.dec):
            nn.init.constant_(layer.bias, 0.0)

    def cnn_gru(self, x, lens):
        """Encode a padded batch of sequences into one vector per sample.

        Args:
            x: Padded input; the code requires a last dimension of 20 so the
                convolution collapses it to size 1 (presumably
                ``(batch, seq_len, 20)`` -- confirm against the data loader).
            lens: True (unpadded) length of every sequence in the batch, as
                accepted by ``pack_padded_sequence``.

        Returns:
            Tensor of shape ``(batch, 4 * hidden_size)``: the final hidden
            states of both directions of both GRU layers, concatenated.
        """
        x = x.unsqueeze(1)               # add the single input channel
        x = torch.relu(self.conv(x))     # -> (batch, out_channels, seq_len, 1)
        x = x.squeeze(3).permute(0, 2, 1)  # -> (batch, seq_len, out_channels)
        # Packing lets the GRU skip padded positions; enforce_sorted=False
        # means the batch does not have to be ordered by length.
        packed = pack_padded_sequence(
            x, lens, batch_first=True, enforce_sorted=False
        )
        _, hidden = self.rnn(packed)
        # hidden is (num_layers * 2, batch, hidden_size); join all four final
        # states along the feature axis.
        return torch.cat(
            [hidden[-1], hidden[-2], hidden[-3], hidden[-4]], dim=1
        )

    def forward(self, pssm, lengths, FEGS):
        """Run the full model.

        Args:
            pssm: Padded PSSM batch passed to :meth:`cnn_gru`.
            lengths: Sequence lengths matching ``pssm``.
            FEGS: Per-sample feature matrix with 578 columns (fixed by the
                input size of the encoder layers).

        Returns:
            Tuple ``(outputs, enc_mean, enc_std, latent)`` where ``outputs``
            holds two sigmoid scores per sample and the remaining entries are
            the latent mean, standard deviation, and the sampled latent.
        """
        cnn_vectors = self.cnn_gru(pssm, lengths)
        feature_vec = torch.cat([cnn_vectors, FEGS], dim=1)
        enc_mean = self.enc_mean(feature_vec)
        # softplus keeps the std positive; the -5 shift starts it near zero.
        enc_std = f.softplus(self.enc_std(feature_vec) - 5)
        # Reparameterisation: fresh Gaussian noise is drawn on every call, so
        # the outputs are stochastic even in eval mode.
        eps = torch.randn_like(enc_std)
        latent = enc_mean + enc_std * eps
        outputs = torch.sigmoid(self.dec(latent))
        return outputs, enc_mean, enc_std, latent

I load pretrained weights into the model as well, prior to passing it to captum with the relevant arguments:

# IntegratedGradients must be given the forward callable itself, not the
# result of calling the model. The attributed input has to be the first
# positional argument and the callable must return a single tensor, so wrap
# the model: put FEGS first and keep only `outputs` from the returned tuple.
def fegs_forward(FEGS, pssm, lengths):
    """Return only the class scores, with FEGS as the leading argument."""
    return model(pssm, lengths, FEGS)[0]


ig = IntegratedGradients(fegs_forward)
# The remaining model inputs go in additional_forward_args. Captum repeats
# tensor arguments for every interpolation step, so test_len_small should be
# a tensor (not a Python list) -- TODO confirm.
# `target` selects which of the two output columns to explain; index 1 is
# assumed to be the class of interest -- change it if that is wrong.
# NOTE: the model samples noise inside forward(), so attributions will vary
# slightly between runs unless the random seed is fixed.
attr = ig.attribute(
    test_FEGS_small,
    additional_forward_args=(test_pssm_small, test_len_small),
    target=1,
    n_steps=5,
)