How do we interpret a multivariate forecasting model with Captum?

Dear PyTorch community,

I am unable to compute feature importances for my multivariate forecasting model with Captum.

Here is the model architecture:

import torch.nn as nn

class MultivariateLSTM(nn.Module):

    def __init__(self, input_size, hidden_size, num_layers, output_size):
        super(MultivariateLSTM, self).__init__()
        self.hidden_size = hidden_size
        self.lstm = nn.LSTM(input_size = input_size, hidden_size = hidden_size, num_layers = num_layers, batch_first = True, dropout = 0.2)
        self.fc1 = nn.Linear(hidden_size, 8)
        self.fc2 = nn.Linear(8, output_size)
        self.relu = nn.ReLU()
        # self.relu = SmeLU(6)

    def forward(self, x):
        # x: (batch, seq_len, features)
        out, _ = self.lstm(x)          # (batch, seq_len, hidden_size)
        out = self.relu(out)
        out = self.fc1(out[:, -1, :])  # keep only the last time step
        out = self.relu(out)
        out = self.fc2(out)            # (batch, output_size)

        return out
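To see why Captum complains, it helps to check what `forward` returns. The sketch below rebuilds the same architecture with random weights and a random batch shaped like the data in this post; `hidden_size=32` and `num_layers=2` are assumed values, since the post does not state them. The output has six columns, one per forecast step, so there is no single scalar to differentiate. Picking one column, which is what Captum's `target` argument does, yields a gradient with the same shape as the input.

```python
import torch
import torch.nn as nn


class MultivariateLSTM(nn.Module):
    # Same layers as the model in the post (shared ReLU module included).
    def __init__(self, input_size, hidden_size, num_layers, output_size):
        super().__init__()
        self.lstm = nn.LSTM(input_size=input_size, hidden_size=hidden_size,
                            num_layers=num_layers, batch_first=True, dropout=0.2)
        self.fc1 = nn.Linear(hidden_size, 8)
        self.fc2 = nn.Linear(8, output_size)
        self.relu = nn.ReLU()

    def forward(self, x):
        out, _ = self.lstm(x)                      # (batch, seq_len, hidden_size)
        out = self.relu(out)
        out = self.relu(self.fc1(out[:, -1, :]))   # last time step only
        return self.fc2(out)                       # (batch, output_size)


torch.manual_seed(0)
# hidden_size=32 and num_layers=2 are assumptions; the post does not give them.
model = MultivariateLSTM(input_size=3, hidden_size=32, num_layers=2, output_size=6).eval()

x = torch.randn(4, 12, 3, requires_grad=True)  # (samples, seq_len, features)
out = model(x)
print(out.shape)  # torch.Size([4, 6]): six outputs per sample

# Backpropagating from a non-scalar output without choosing a direction fails.
try:
    out.backward()
except RuntimeError as err:
    print("RuntimeError:", err)

# Selecting one forecast step k (what Captum's target=k does) works.
k = 0
(grad,) = torch.autograd.grad(out[:, k].sum(), x)
print(grad.shape)  # torch.Size([4, 12, 3]), same shape as the input
```

Summing `out[:, k]` over the batch still gives per-sample gradients here, because each sample's output depends only on its own input.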

Here are the shapes of the training inputs and targets:

>>> train_X_tensor.shape, train_y_tensor.shape
(torch.Size([10339, 12, 3]), torch.Size([10339, 6]))

Note that the input follows the (samples, seq_len, features) convention used in forecasting models: each sample has 12 time steps of 3 features, and the target has 6 values per sample.
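Captum's attribution methods return a tensor with the same (samples, seq_len, features) shape as the input. One common way to reduce it to a single score per feature, assumed here rather than prescribed by Captum, is to average absolute attributions over samples and time steps; averaging over samples and features instead gives one score per time step. The tensor below is a made-up stand-in for real attributions, and `feature_names` is a hypothetical list.

```python
import torch

# Stand-in attribution tensor shaped (samples, seq_len, features).
# Feature 2 gets twice the magnitude so the result is easy to check.
attr = torch.ones(5, 12, 3)
attr[:, :, 2] = -2.0

# Mean absolute attribution per input feature: shape (features,)
per_feature = attr.abs().mean(dim=(0, 1))
# Mean absolute attribution per time step: shape (seq_len,)
per_step = attr.abs().mean(dim=(0, 2))

feature_names = ["feat_0", "feat_1", "feat_2"]  # hypothetical names
for name, score in zip(feature_names, per_feature.tolist()):
    print(f"{name}: {score:.3f}")
# feat_0: 1.000
# feat_1: 1.000
# feat_2: 2.000
```

Taking the absolute value ranks features by magnitude but discards the sign; averaging the raw attributions instead shows whether a feature tends to push the forecast up or down.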

Here is the code I use to compute the attributions:

from captum.attr import DeepLift, FeatureAblation, GradientShap, IntegratedGradients, NoiseTunnel

ig = IntegratedGradients(model)
ig_nt = NoiseTunnel(ig)
dl = DeepLift(model)
gs = GradientShap(model)
fa = FeatureAblation(model)

ig_attr_test = ig.attribute(test_X_tensor, n_steps=50)
ig_nt_attr_test = ig_nt.attribute(test_X_tensor)
dl_attr_test = dl.attribute(test_X_tensor)
gs_attr_test = gs.attribute(test_X_tensor, train_X_tensor)  # training set as the baseline distribution
fa_attr_test = fa.attribute(test_X_tensor)

This is the error I get:

----> 7 ig_attr_test = ig.attribute(test_X_tensor, n_steps=50)
AssertionError: Target not provided when necessary, cannot take gradient with respect to multiple outputs.

Does anyone have experience using Captum with forecasting models? I would greatly appreciate it if you could share your thoughts and experience.

Thank you 🙂