Vectorized autograd.grad for class activation maps

I have a neural model f that produces a score per class. I’d like to compute something like Class Activation Maps by computing the gradient of each output with respect to the input (I don’t need any gradients with respect to the weights of f).

Is there a batched way to compute gradients with respect to every output value? I found torch.autograd.functional.jacobian (PyTorch 1.10.0 documentation), but it seems that it will run the forward pass many times as well. Is that true? Does calling it like below make sense?

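import torch
import torch.nn as nn

model = nn.Linear(10, 20).requires_grad_(False)  # no gradients needed for the weights
x = torch.zeros(5, 10).requires_grad_(True)
y = model(x)

print(x.shape)  # torch.Size([5, 10])
print(y.shape)  # torch.Size([5, 20])

# jacobian() calls the function on a detached version of the input, so the function
# must actually use its argument; `lambda x: y` ignores it and returns a zero Jacobian.
grads = torch.autograd.functional.jacobian(model, x, vectorize=True)
print(grads.shape)  # torch.Size([5, 20, 5, 10])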
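On whether jacobian runs the forward pass many times: in the default reverse-mode strategy, my reading of the implementation is that it calls the function once and then runs one backward pass per output element, or a single batched backward pass when vectorize=True. A quick empirical check is to count calls. The counter `calls` and the wrapper `counted_forward` are hypothetical helpers for illustration; this assumes a PyTorch version whose jacobian accepts vectorize, as 1.10 does.

```python
import torch
import torch.nn as nn

model = nn.Linear(10, 20).requires_grad_(False)
x = torch.zeros(5, 10)

calls = 0  # hypothetical counter for forward invocations


def counted_forward(inp):
    # Count how many times jacobian() invokes the function it differentiates.
    global calls
    calls += 1
    return model(inp)


grads = torch.autograd.functional.jacobian(counted_forward, x, vectorize=True)
print(calls)        # number of forward calls made by jacobian()
print(grads.shape)  # torch.Size([5, 20, 5, 10])
```

The expensive part of jacobian is therefore the 100 vector-Jacobian products (one per output element), not repeated forward passes.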

Why does grads.shape above contain the batch dimension twice? Ideally, I’d like to get something like [5, 10, 20], where 5 is the batch dimension.

I understand it may make sense when the outputs depend on other elements of the batch (e.g., with batch norm), but when I don’t need this dependence, is there a way to break it? Should I sum over one of these batch dimensions?


This is because torch.autograd.functional.jacobian differentiates every element of the output with respect to every element of the input, treating the batch dimension as just another differentiable dimension. The result therefore has shape output.shape + input.shape, i.e. [5, 20] + [5, 10], which is why you get the batch dim twice. You can then take the diagonal over the pair of batch dimensions to get the per-sample grads you want. This can be done via torch.einsum:

grads = torch.einsum("bibj->bij", grads)  # torch.Size([5, 20, 10])
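A minimal sketch of the diagonal extraction on the question's nn.Linear example, assuming random inputs (the seed and sizes are arbitrary). For a linear layer y = x W^T + b, the per-sample Jacobian dy[b, i]/dx[b, j] is just W[i, j], which gives an easy correctness check. The einsum result is laid out as [batch, class, input]; the final permute to [5, 10, 20], the layout asked for in the question, is my addition.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
model = nn.Linear(10, 20).requires_grad_(False)
x = torch.randn(5, 10)

# Full Jacobian: output.shape + input.shape = [5, 20] + [5, 10].
full = torch.autograd.functional.jacobian(model, x, vectorize=True)

# Repeating the subscript b takes the diagonal over the two batch dims,
# keeping only d y[b, i] / d x[b, j] (same sample on both sides).
per_sample = torch.einsum("bibj->bij", full)  # [5, 20, 10]
cam_grads = per_sample.permute(0, 2, 1)       # [5, 10, 20]: batch, input feature, class

# For y = x @ W.T + bias, d y[b, i] / d x[b, j] = W[i, j] for every sample b.
expected = model.weight.expand(5, -1, -1)     # [5, 20, 10]
print(per_sample.shape, cam_grads.shape)
print(torch.allclose(per_sample, expected))   # True
```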

Should I sum over one of these batch dimensions?

Not really. If your function is actually independent across the samples in a batch, then all the off-“diagonal” elements will be 0.
But if you have batch norm or something like that, they won’t be 0, and you may not want to use them.
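To see the off-diagonal claim concretely, the sketch below zeroes the diagonal blocks of the full Jacobian for the question's nn.Linear model and checks that nothing remains. It then shows a related trick (my addition, not the weight-expansion approach from the reply that follows): summing the outputs over the batch before differentiating. That yields a [classes, batch, input] Jacobian equal to the per-sample diagonal whenever samples don't interact, and it needs one backward row per class (20) instead of one per output element (100). This assumes a model with no cross-sample dependence; batch norm in training mode, for instance, would break the equality.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
model = nn.Linear(10, 20).requires_grad_(False)
x = torch.randn(5, 10)

full = torch.autograd.functional.jacobian(model, x, vectorize=True)  # [5, 20, 5, 10]
per_sample = torch.einsum("bibj->bij", full)                         # [5, 20, 10]

# Zero out the same-sample (b == b') blocks; what remains are cross-sample terms.
mask = torch.eye(5, dtype=torch.bool)[:, None, :, None]              # [5, 1, 5, 1]
off_diag = full.masked_fill(mask, 0.0)
print(off_diag.abs().max().item())  # 0.0 for a per-sample model such as nn.Linear

# Sum the outputs over the batch first: d(sum_b y[b, i]) / d x[b', j]
# = sum_b J[b, i, b', j], which equals the diagonal block when the
# off-diagonal blocks are zero.
summed = torch.autograd.functional.jacobian(
    lambda inp: model(inp).sum(0), x, vectorize=True
)                                                                    # [20, 5, 10]
print(torch.allclose(summed.permute(1, 0, 2), per_sample))           # True
```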

If you’re happy with summing, you can get that result faster by expanding your weights to have a batch dimension and computing the gradients with respect to these expanded weights.