I want to do a sensitivity analysis of a PyTorch nn regression model. I understand how it works for a simple function like this:
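import torch

x1 = torch.randn(1, requires_grad=True)
x2 = torch.randn(1, requires_grad=True)
# u = u(x1, x2), here a known closed-form function
u = 3 * x1 ** 3 - x2 ** 2
# Get both partial derivatives in one call; calling grad(u, ...) twice would fail
# because the graph is freed after the first call
dx1, dx2 = torch.autograd.grad(u, (x1, x2))
# Compare sensitivity magnitudes
print(dx1.abs() > dx2.abs())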
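As a sanity check on this toy example, the gradients autograd returns can be compared with the hand-derived partials du/dx1 = 9 * x1**2 and du/dx2 = -2 * x2. A minimal sketch:

```python
import torch

x1 = torch.randn(1, requires_grad=True)
x2 = torch.randn(1, requires_grad=True)
u = 3 * x1 ** 3 - x2 ** 2

# Both partial derivatives from a single backward pass
dx1, dx2 = torch.autograd.grad(u, (x1, x2))

# Hand-derived partials: du/dx1 = 9 * x1^2, du/dx2 = -2 * x2
expected_dx1 = 9 * x1.detach() ** 2
expected_dx2 = -2 * x2.detach()

print(torch.allclose(dx1, expected_dx1), torch.allclose(dx2, expected_dx2))
```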
But how do I reference the input features (x1, x2) and u (the model's output) if I am using an nn.Module like this:
import torch
import torch.nn as nn

class Model(nn.Module):  # class name assumed; the original snippet omits it
    ## initialize the layers
    def __init__(self, x_means, x_deviations, y_means, y_deviations):
        super().__init__()
        self.x_means = x_means
        self.x_deviations = x_deviations
        self.y_means = y_means
        self.y_deviations = y_deviations
        self.linear1 = nn.Linear(4, 2)

    ## perform inference
    def forward(self, x):
        x = (x - self.x_means) / self.x_deviations  # standardize inputs
        y_scaled = self.linear1(x)
        y_descaled = y_scaled * self.y_deviations + self.y_means  # undo target scaling
        return y_descaled, y_scaled
Is it something like this?
u = model()
dx1 = torch.autograd.grad(u, u.layers.input.x1)
dx2 = torch.autograd.grad(u, u.layers.input.x2)
Can someone help me with this?
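One way to approach this, sketched under assumptions: the module does not need to expose its inputs, because the input features are just the tensor you pass to forward. Creating that tensor with requires_grad=True lets torch.autograd.grad return du/dx for every feature. The class name ScaledRegression and the normalization statistics below are hypothetical placeholders.

```python
import torch
import torch.nn as nn

class ScaledRegression(nn.Module):  # hypothetical name for the module in the question
    def __init__(self, x_means, x_deviations, y_means, y_deviations):
        super().__init__()
        self.x_means = x_means
        self.x_deviations = x_deviations
        self.y_means = y_means
        self.y_deviations = y_deviations
        self.linear1 = nn.Linear(4, 2)

    def forward(self, x):
        x = (x - self.x_means) / self.x_deviations
        y_scaled = self.linear1(x)
        y_descaled = y_scaled * self.y_deviations + self.y_means
        return y_descaled, y_scaled

torch.manual_seed(0)
# Placeholder normalization statistics (assumed values: 4 input features, 2 targets)
model = ScaledRegression(
    x_means=torch.zeros(4), x_deviations=torch.tensor([1.0, 2.0, 0.5, 4.0]),
    y_means=torch.zeros(2), y_deviations=torch.tensor([3.0, 1.0]),
)

# The "input features" are the tensor you pass in: make it a leaf that tracks gradients
x = torch.randn(8, 4, requires_grad=True)  # 8 samples, features x1..x4
y_descaled, _ = model(x)

# Gradient of output 0 w.r.t. every input feature, per sample.
# Summing over the batch is safe because each sample's output depends only on its own row.
(grad_out0,) = torch.autograd.grad(y_descaled[:, 0].sum(), x)
print(grad_out0.shape)  # torch.Size([8, 4]): du/dx1 ... du/dx4 for each sample
```

For this purely linear model every row of grad_out0 equals linear1.weight[0] * y_deviations[0] / x_deviations, so the sensitivities do not depend on the sample; for a nonlinear network they would.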
PyTorch modules do not store an .input attribute. I assume x1 and x2 are treated as trainable parameters in the first example? If so, they would correspond to the model's parameters, e.g. the weight of a linear layer.
Thanks for your reply. I need the derivatives with respect to each feature, such as du/dx1 and du/dx2, so I need to access the input tensor from the computational graph. Is there another way to do that?
u.linear1.weight refers to the weights, but I need what comes before that layer (i.e. the input).
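If the gradient with respect to what linear1 actually receives (the standardized features) is wanted, one option is a forward hook, which sees a layer's inputs during the forward pass. The sketch below uses a simplified, hypothetical version of the question's module (output de-scaling omitted) with assumed statistics; it captures linear1's input and differentiates with respect to both it and the raw features.

```python
import torch
import torch.nn as nn

class ScaledRegression(nn.Module):  # hypothetical, simplified stand-in for the question's module
    def __init__(self, x_means, x_deviations):
        super().__init__()
        self.x_means = x_means
        self.x_deviations = x_deviations
        self.linear1 = nn.Linear(4, 2)

    def forward(self, x):
        return self.linear1((x - self.x_means) / self.x_deviations)

torch.manual_seed(0)
model = ScaledRegression(torch.zeros(4), torch.tensor([1.0, 2.0, 0.5, 4.0]))  # assumed stats

captured = {}

def save_linear1_input(module, args, output):
    # args is the tuple of positional inputs that linear1.forward received
    captured["x_scaled"] = args[0]

handle = model.linear1.register_forward_hook(save_linear1_input)
x = torch.randn(8, 4, requires_grad=True)  # raw features must track gradients
y = model(x)
handle.remove()

# Derivatives w.r.t. the standardized features (an intermediate graph node) and the raw features
grad_scaled, grad_raw = torch.autograd.grad(y[:, 0].sum(), (captured["x_scaled"], x))
```

By the chain rule the two results differ only by the factor x_deviations, so the choice depends on whether sensitivities are wanted in raw or in standardized units.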
I’m unsure what your exact use case is, but maybe torch.autograd.grad is what you are looking for.
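To mirror the x1/x2 example directly with torch.autograd.grad, each feature can be its own leaf tensor that is stacked into the model input; grad then returns du/dx1, du/dx2, ... as separate tensors. The small nn.Sequential network here is a hypothetical stand-in for the real model.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
# Hypothetical nonlinear regression model with 4 inputs and 2 outputs
model = nn.Sequential(nn.Linear(4, 8), nn.Tanh(), nn.Linear(8, 2))

# One leaf tensor per feature, as in the x1/x2 example; each holds a batch of 5 values
x1, x2, x3, x4 = (torch.randn(5, requires_grad=True) for _ in range(4))
x = torch.stack([x1, x2, x3, x4], dim=1)  # shape (5, 4)

u = model(x)[:, 0]  # first output for each sample, shape (5,)
du_dx1, du_dx2, du_dx3, du_dx4 = torch.autograd.grad(u.sum(), (x1, x2, x3, x4))
print(du_dx1.shape)  # torch.Size([5])
```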
Thank you. Yes, torch.autograd is what I have been using. The use case is feature ranking: in theory, you can rank features by the derivatives of the function with respect to each input. In classic linear regression this is called sensitivity analysis. I just haven't seen it done for neural nets in PyTorch.
The goal is to know which features are most important for the regression model prediction.
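A common way to turn these derivatives into a feature ranking is to average their absolute values over a dataset. The sketch below assumes a trained model and an input matrix X (both random placeholders here); gradient_sensitivity is a hypothetical helper, not a PyTorch API.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
# Hypothetical trained model and dataset; replace with your own
model = nn.Sequential(nn.Linear(4, 8), nn.Tanh(), nn.Linear(8, 1))
X = torch.randn(200, 4)

def gradient_sensitivity(model, X, output_index=0):
    """Mean absolute d(output)/d(feature) over the rows of X."""
    X = X.clone().requires_grad_(True)
    out = model(X)[:, output_index]
    (grads,) = torch.autograd.grad(out.sum(), X)
    return grads.abs().mean(dim=0)  # one score per feature

scores = gradient_sensitivity(model, X)
ranking = torch.argsort(scores, descending=True)
print(scores)
print(ranking)  # feature indices, most sensitive first
```

Raw-unit derivatives depend on each feature's scale, so one common adjustment is to multiply the scores by X.std(dim=0) before ranking.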
Would this approach of looping through the inputs to calculate the gradients (or using jacobian) work for your use case?
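For a single sample, torch.autograd.functional.jacobian returns all output-by-input derivatives at once, and looping over the outputs with torch.autograd.grad (with retain_graph=True so the graph survives repeated calls) gives the same matrix. The network is again a hypothetical stand-in for the real model.

```python
import torch
import torch.nn as nn
from torch.autograd.functional import jacobian

torch.manual_seed(0)
model = nn.Sequential(nn.Linear(4, 8), nn.Tanh(), nn.Linear(8, 2))  # hypothetical stand-in
x = torch.randn(4)  # a single sample with 4 features

# Full Jacobian d(outputs)/d(inputs) for one sample: shape (2, 4)
J = jacobian(model, x)

# Same result by looping over the outputs with torch.autograd.grad
x_leaf = x.clone().requires_grad_(True)
y = model(x_leaf)
rows = [torch.autograd.grad(y[k], x_leaf, retain_graph=True)[0] for k in range(y.shape[0])]
J_loop = torch.stack(rows)

print(J.shape, torch.allclose(J, J_loop))
```

Row k of J holds the sensitivities of output k to each of the four features, which is the per-feature information needed for the ranking.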