Compute the gradients of all logits with respect to input

Given a neural network classifier with 10 classes (the final-layer logits have dimension 10), is there a way to compute the gradient of each logit w.r.t. the input in one backward pass?

The only way I know to do this is by setting up a for-loop like this:

input_grad = torch.zeros(10, input_size)
logit = NN(x)
for i in range(10):
    if x.grad is not None:
        x.grad.zero_()          # clear the gradient left over from the previous logit
    logit[i].backward(retain_graph=True)
    input_grad[i] = x.grad      # gradient of logit i w.r.t. the input

However, this for-loop is much slower than a single backprop. Another minor question: is it correct that the above for-loop has roughly 10 times the computational cost of a single backprop, if the network is large enough?
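A rough way to check that would be to time both versions, something like the sketch below (just a sketch, assuming NN and x are defined as above with x.requires_grad_(True) set):

import time
import torch

# Rough timing sketch (the forward pass is included in both timings);
# on GPU, also call torch.cuda.synchronize() before reading each timer.
start = time.perf_counter()
logit = NN(x)
for i in range(10):
    if x.grad is not None:
        x.grad.zero_()
    logit[i].backward(retain_graph=True)
loop_time = time.perf_counter() - start

start = time.perf_counter()
logit = NN(x)
logit.sum().backward()   # one backward pass through the whole network
single_time = time.perf_counter() - start

print(loop_time / single_time)   # roughly how many times slower the loop is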

Could you try the following code and see if you get the same values?

logit = NN(x)
input_grad, = torch.autograd.grad(logit, x, torch.ones_like(logit)) #note the comma after input_grad

You could try something like this to compare the two directly:

input_grad = torch.zeros(10, input_size)
logit = NN(x)
for i in range(10):
    if x.grad is not None:
        x.grad.zero_()
    logit[i].backward(retain_graph=True)
    input_grad[i] = x.grad

logit_all = NN(x)
input_grad_all, = torch.autograd.grad(logit_all, x, torch.ones_like(logit_all))

print(torch.allclose(input_grad, input_grad_all))

Thank you for the reply!

Let’s say x is of size 5 * 3 * 32 * 32 (5 is the batch size) and num_classes = 10.
I think in your example, input_grad_all has the same size as the input x (both of them are 5 * 3 * 32 * 32). However, I would like to compute the gradients of each logit w.r.t. the input. That is, I am expecting the gradient to be of size 10 * 5 * 3 * 32 * 32.
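To make the shape I'm expecting concrete, the loop version for this batched case would look roughly like this (just a sketch; it assumes NN maps a 5 * 3 * 32 * 32 batch to 5 * 10 logits and that samples in the batch don't interact, e.g. no batch norm in training mode):

import torch

num_classes = 10
x = torch.randn(5, 3, 32, 32, requires_grad=True)
logits = NN(x)                                   # shape (5, 10)

input_grad = torch.zeros(num_classes, *x.shape)  # shape (10, 5, 3, 32, 32)
for c in range(num_classes):
    # Summing class c's logit over the batch is fine here: sample j's logit only
    # depends on sample j's input, so each sample's gradient stays in its own slice.
    grad_c, = torch.autograd.grad(logits[:, c].sum(), x, retain_graph=True)
    input_grad[c] = grad_c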

Could you share a minimal reproducible example of NN? Just so there’s a complete example I can use to debug your problem.
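For example, a made-up stand-in like the one below would already be enough to run against; it is obviously not your actual model, just something that matches the 5 * 3 * 32 * 32 input and 10-class output you described:

import torch
import torch.nn as nn

# Hypothetical stand-in model, not the poster's actual network.
NN = nn.Sequential(
    nn.Conv2d(3, 8, kernel_size=3, padding=1),
    nn.ReLU(),
    nn.Flatten(),
    nn.Linear(8 * 32 * 32, 10),
)

x = torch.randn(5, 3, 32, 32, requires_grad=True)
print(NN(x).shape)   # torch.Size([5, 10])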