Feature Importance of a PyTorch Autoencoder

I need to get, from my PyTorch autoencoder, the importance it assigns to each input variable. I am working with a tabular dataset, not images.

My autoencoder is as follows:

import torch

class AE(torch.nn.Module):
    def __init__(self, input_size, hidden_layer, latent_layer):
        super().__init__()  # required before registering submodules

        # Two stacked Linear layers, no activation in between
        self.encoder = torch.nn.Sequential(
            torch.nn.Linear(input_size, hidden_layer),
            torch.nn.Linear(hidden_layer, latent_layer),
        )

        self.decoder = torch.nn.Sequential(
            torch.nn.Linear(latent_layer, hidden_layer),
            torch.nn.Linear(hidden_layer, input_size),
        )

    def forward(self, x):
        encoded = self.encoder(x)        # input_size -> latent_layer
        decoded = self.decoder(encoded)  # latent_layer -> input_size
        return decoded
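Because neither the encoder nor the decoder contains an activation function, this particular network collapses to a single affine map, reconstruction = W·x + b, where W is the product of the four weight matrices. Column j of W then describes exactly how input j feeds every output. The following is a minimal sketch under that assumption; `end_to_end_weight` and `linear_importance` are hypothetical helper names, and using the L2 norm of each column as the importance score is one choice among several.

```python
import torch


class AE(torch.nn.Module):
    # Same architecture as in the question: four Linear layers, no activations
    def __init__(self, input_size, hidden_layer, latent_layer):
        super().__init__()
        self.encoder = torch.nn.Sequential(
            torch.nn.Linear(input_size, hidden_layer),
            torch.nn.Linear(hidden_layer, latent_layer),
        )
        self.decoder = torch.nn.Sequential(
            torch.nn.Linear(latent_layer, hidden_layer),
            torch.nn.Linear(hidden_layer, input_size),
        )

    def forward(self, x):
        return self.decoder(self.encoder(x))


@torch.no_grad()
def end_to_end_weight(model):
    # Relies on the Linear layers being registered in forward order, as they are here
    layers = [m for m in model.modules() if isinstance(m, torch.nn.Linear)]
    W = torch.eye(layers[0].in_features, device=layers[0].weight.device)
    for layer in layers:
        W = layer.weight @ W  # accumulates W = W4 @ W3 @ W2 @ W1
    return W  # W[i, j] = d(output i) / d(input j)


def linear_importance(model):
    # L2 norm of column j: overall influence of input j on the whole reconstruction
    return end_to_end_weight(model).norm(dim=0)
```

With the trained model this would be `linear_importance(model)`, giving one score per gene. The shortcut stops being valid as soon as a nonlinearity such as `torch.nn.ReLU()` is inserted between layers, and the scores are only comparable across genes if the inputs are on similar scales (for example, standardized).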

To keep the question short, I omit the training code; I simply call the following function to get my trained model:

average_loss, model, train_losses, test_losses = fullAE(
    batch_size=128, input_size=genes_tensor.shape[1],
    learning_rate=0.0001, weight_decay=0, epochs=50, verbose=False,
    dataset=genes_tensor, betas_value=(0.9, 0.999),
    train_dataset=genes_tensor_train, test_dataset=genes_tensor_test)

Here, “model” is a trained instance of the autoencoder above, created as:

model = AE(input_size=input_size, hidden_layer=int(input_size * 0.75), latent_layer=int(input_size * 0.5)).to(device)

Now I need to get the importance that this model assigns to each input variable of my original “genes_tensor” dataset, but I don’t know how. After some research, I found a way to do it with the shap library:

import shap

# genes_tensor (all 950 samples) is passed as the background dataset
e = shap.DeepExplainer(model, genes_tensor)

shap_values = e.shap_values(


The problems with this approach are:
1) I don’t know whether what I am doing is actually correct.
2) It takes forever to finish. The dataset contains 950 samples; I tried explaining only 1 sample, and even that takes a long time. The result for a single sample is as follows:

I have seen that there are other options for obtaining the importance of the input variables, such as Captum, but Captum only seems to support importance for neural networks with a single output neuron, whereas mine has many.
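For reference, Captum's attribution methods accept a `target` argument that selects one output index, and they wrap an arbitrary forward function, so a multi-output model can be reduced to one scalar per sample (for example its reconstruction error) before attribution. The sketch below applies that idea with Integrated Gradients written directly in plain PyTorch, so it depends on neither Captum nor SHAP. It is a sketch under stated assumptions: `reconstruction_error`, `integrated_gradients` and `global_importance` are hypothetical names, the all-zeros baseline and 50 steps are arbitrary choices, and averaging absolute attributions over samples is just one way to get a global ranking.

```python
import torch


def reconstruction_error(model, x):
    # One scalar per sample: mean squared reconstruction error over features
    return ((model(x) - x) ** 2).mean(dim=1)


def integrated_gradients(f, x, baseline=None, steps=50):
    """Integrated Gradients for a function f mapping (n, d) -> (n,)."""
    if baseline is None:
        baseline = torch.zeros_like(x)  # assumed reference point
    diff = x - baseline
    grad_sum = torch.zeros_like(x)
    for k in range(steps):
        alpha = (k + 0.5) / steps  # midpoint rule along the straight-line path
        point = (baseline + alpha * diff).detach().requires_grad_(True)
        # Summing over samples is fine because row i of f depends only on row i of x
        (grad,) = torch.autograd.grad(f(point).sum(), point)
        grad_sum += grad
    return diff * grad_sum / steps  # (n, d) per-sample, per-feature attributions


def global_importance(model, x, steps=50):
    model.eval()  # disables dropout; BatchNorm would use running statistics
    attr = integrated_gradients(lambda z: reconstruction_error(model, z), x, steps=steps)
    return attr.abs().mean(dim=0)  # one score per input feature
```

With the trained model this could be called as `global_importance(model, genes_tensor.to(device))`. The cost is `steps` forward and backward passes over the data, independent of how many outputs the decoder has. Summing over the batch before `torch.autograd.grad` is only valid when each sample's output depends on its own row alone.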

The approaches for AEs and VAEs that I have found on GitHub do not work for me, since they are tailored to specific cases, almost always images, for example:

Is my SHAP implementation correct? Should I continue with SHAP even though it takes ages, or should I switch to another method?
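If a cheaper, model-agnostic baseline is acceptable, permutation importance is another option (a different technique from SHAP): shuffle one input column across samples and measure how much the reconstruction error rises. This is a minimal sketch that assumes the error is always measured against the original, unshuffled data; `permutation_importance` is a hypothetical helper name and `n_repeats=5` is arbitrary. It needs only n_features × n_repeats forward passes and no gradients.

```python
import torch


@torch.no_grad()
def permutation_importance(model, x, n_repeats=5, seed=0):
    """Rise in mean reconstruction MSE when one input column is shuffled."""
    model.eval()
    gen = torch.Generator().manual_seed(seed)  # CPU generator for reproducible shuffles
    base_error = ((model(x) - x) ** 2).mean()
    scores = torch.zeros(x.shape[1])
    for j in range(x.shape[1]):
        for _ in range(n_repeats):
            idx = torch.randperm(x.shape[0], generator=gen).to(x.device)
            x_perm = x.clone()
            x_perm[:, j] = x[idx, j]  # break the link between feature j and the rest
            # Error is measured against the ORIGINAL x (assumption, see lead-in)
            perm_error = ((model(x_perm) - x) ** 2).mean()
            scores[j] += (perm_error - base_error).item() / n_repeats
    return scores  # larger = the model relies more on that input
```

Strongly correlated genes can mask each other with this method, since shuffling one of them leaves its information available through the others.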

Thank you very much in advance.

Best regards.