Importance of NN input features

I have a simple NN for binary classification:

self.layer1 = nn.Linear(44, 128)
self.layer2 = nn.Linear(128, 512)
self.layer3 = nn.Linear(512, 1024)
self.layer4 = nn.Linear(1024, 512)
self.layer5 = nn.Linear(512, 128)
self.layer6 = nn.Linear(128, 40)
self.layer7 = nn.Linear(40, 1)

The network takes 44 input features (some of them zero) and outputs a classification score. Training behaves sensibly and reaches a low loss (<0.1) with good classification power.

I would like to check which of those 44 features are most important, i.e. how much each one contributed to minimising the loss function during training. I’ve heard of methods that rank input variables according to their relative contribution.

Can someone please explain how to implement this in PyTorch? A minimal code snippet which performs some computation based on trained_net.parameters() perhaps?

Many thanks in advance!
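
One model-agnostic option that might fit here is permutation importance: shuffle one input column at a time on held-out data and measure how much the loss gets worse. Strictly speaking, this measures how much the trained network relies on each feature, not its contribution during training. Below is a minimal sketch, not a definitive implementation. It assumes the network outputs raw logits (hence `nn.BCEWithLogitsLoss`); `permutation_importance`, `demo_net`, `X_val` and `y_val` are hypothetical names, and a small random stand-in replaces the trained network and the validation data.

```python
import torch
import torch.nn as nn


def permutation_importance(model, X, y, loss_fn, n_repeats=5, seed=0):
    """Mean increase in loss when each input column is shuffled."""
    gen = torch.Generator().manual_seed(seed)
    model.eval()
    with torch.no_grad():
        base_loss = loss_fn(model(X), y).item()
        importances = torch.zeros(X.shape[1])
        for j in range(X.shape[1]):
            for _ in range(n_repeats):
                X_perm = X.clone()
                # Shuffle column j only; this breaks its link to the target
                perm = torch.randperm(X.shape[0], generator=gen)
                X_perm[:, j] = X[perm, j]
                importances[j] += loss_fn(model(X_perm), y).item() - base_loss
    return importances / n_repeats


# Stand-ins (assumptions): swap in trained_net and real validation tensors.
torch.manual_seed(0)
demo_net = nn.Sequential(nn.Linear(44, 16), nn.ReLU(), nn.Linear(16, 1))
X_val = torch.randn(256, 44)
y_val = torch.randint(0, 2, (256, 1)).float()

scores = permutation_importance(demo_net, X_val, y_val, nn.BCEWithLogitsLoss())
ranking = torch.argsort(scores, descending=True)  # most important first
```

A score near zero suggests the network barely uses that feature. Strongly correlated features can share or mask each other's importance, so the ranking is best read as a rough guide.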

As a follow-up comment: I came across the Cross Validated question "Deep learning : How do I know which variables are important?" and thought that something along the lines of:

import torch

# Flatten every parameter of the first layer (weight and bias) into one vector
params = torch.cat([param.view(-1) for param in trained_net.layer1.parameters()])
params_list = params.tolist()

# Accumulate the flattened values into 45 buckets by position modulo 45
sum_of_weights = {el: 0 for el in range(45)}
for i, param in enumerate(params_list):
    sum_of_weights[i % 45] += param

would be a roughly valid implementation. This results in:

sum_of_weights

{0: 2.1881065031702747,
 1: -0.4291043917473871,
 2: -3.1886899566743523,
 3: -2.0855790052446537,
 4: 0.8537990134755091,
 5: -3.7174374876194634,
 6: -0.44004582360503264,
 7: -4.356606524073868,
 8: -1.028046389692463,
 9: -3.9685422069451306,
 10: -5.608152595435968,
 11: -3.4701311255339533,
 12: -1.8252374697985942,
 13: -0.7488467986695468,
 14: -8.28944860829506,
 15: -3.4464564417139627,
 16: 4.449272771860706,
 17: -4.691695329251161,
 18: -1.7501768926740624,
 19: 0.37166432512458414,
 20: -10.504940060782246,
 21: -0.6779894242063165,
 22: -6.289113475300837,
 23: -1.3348277652694378,
 24: -3.505854993709363,
 25: -0.9910737440804951,
 26: -12.102533261415374,
 27: -1.0398883295420092,
 28: -5.629827822733205,
 29: -6.627877121747588,
 30: -2.635934035963146,
 31: -2.0738988643570337,
 32: -3.8363308065308956,
 33: -2.3784317194658797,
 34: -1.377102676546201,
 35: -3.0660866306789103,
 36: 4.508294189465232,
 37: 2.8147964420095377,
 38: 0.7389443085994571,
 39: 2.061781210359186,
 40: 5.054453145865409,
 41: -5.041792882340815,
 42: -6.194136218487984,
 43: 0.17481421308184508,
 44: -9.710554008881445}

Any thoughts? Can some sense be made out of this?
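
One caveat about the snippet above: `nn.Linear(44, 128)` stores its weight with shape `(128, 44)` (one row per hidden unit, one column per input feature) plus a separate bias of length 128. Flattening both and bucketing modulo 45 therefore does not line the buckets up with the 44 inputs, and signed weights can cancel each other out. A minimal sketch of a per-feature aggregation follows, assuming the goal is the total absolute first-layer weight attached to each input; `first_layer_weight_scores` and `demo_layer1` are hypothetical names, and the randomly initialised layer is only a stand-in for `trained_net.layer1`.

```python
import torch
import torch.nn as nn


def first_layer_weight_scores(layer):
    """Sum of absolute weights attached to each input feature of a Linear layer."""
    # layer.weight has shape (out_features, in_features); the bias is left out
    # because it is not tied to any particular input feature.
    return layer.weight.detach().abs().sum(dim=0)


torch.manual_seed(0)
demo_layer1 = nn.Linear(44, 128)  # stand-in (assumption) for trained_net.layer1

weight_scores = first_layer_weight_scores(demo_layer1)
sum_of_abs_weights = {i: s for i, s in enumerate(weight_scores.tolist())}
```

Even with the indexing aligned, first-layer weights are only a rough proxy: they depend on how each feature is scaled and ignore what the six later layers do with the hidden activations.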