Assume we have a batch of data. For each data point in the batch, I would like to get the norm of the gradient of the network’s output w.r.t. the network parameters.
A naive solution is to perform multiple forward and backward passes (one pass per data point), but this is pretty slow. Is there a better solution?
A better solution seems possible, but I don’t see any mention of it in the PyTorch documentation.
Do you mean feeding the data points one by one in a for-loop? I don’t think that will give the right result as-is, because PyTorch accumulates gradients across backward passes unless you zero them between passes.
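For reference, here is a minimal sketch of that naive loop done correctly, on a hypothetical two-layer model with a scalar output (the model, sizes, and data are illustrative assumptions). Using torch.autograd.grad instead of .backward() returns each example’s gradients directly rather than accumulating them into .grad, so nothing needs to be zeroed; calling model.zero_grad() before each backward() would work as well.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
# Hypothetical model and data, for illustration only
model = nn.Sequential(nn.Linear(4, 8), nn.Tanh(), nn.Linear(8, 1))
x = torch.randn(6, 4)
params = list(model.parameters())

norms = []
for i in range(x.shape[0]):
    out = model(x[i:i + 1]).sum()  # scalar output for example i
    # autograd.grad returns fresh gradients; nothing is accumulated into .grad
    grads = torch.autograd.grad(out, params)
    # Global L2 norm over all parameter gradients of this example
    norms.append(torch.sqrt(sum(g.pow(2).sum() for g in grads)))
norms = torch.stack(norms)  # shape (6,)
```

This is still one forward and backward pass per example, which is exactly the slowness the question is trying to avoid.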
If your model has only Conv2d and Linear layers, you could do this with a single backward pass using something like this: https://github.com/cybertronai/autograd-hacks/blob/master/autograd_hacks.py#L167
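For squared norms, replace the “grad1=einsum” line (for Linear layers) with
torch.sum(B*B, dim=1)*torch.sum(A*A, dim=1)
This works because each per-example Linear gradient is the outer product of B[n] and A[n], and the squared Frobenius norm of an outer product is the product of the two vectors’ squared norms.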
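A quick numerical check of that identity, as a sketch. Here A and B are stand-ins constructed directly (A: inputs to a bias-free Linear layer, B: gradients of the loss w.r.t. that layer’s outputs, which is what the formula assumes); they are not the variables from autograd_hacks itself, and the sizes and loss are arbitrary choices.

```python
import torch

torch.manual_seed(0)
n, d_in, d_out = 5, 4, 3
layer = torch.nn.Linear(d_in, d_out, bias=False)

# A: inputs to the Linear layer (one row per example), shape (n, d_in)
A = torch.randn(n, d_in)
out = layer(A)
loss = (out ** 2).sum()  # sum of independent per-example losses
# B: gradient of the loss w.r.t. the layer output, shape (n, d_out)
B, = torch.autograd.grad(loss, out)

# Full per-example weight gradients: outer products B[n] A[n]^T, shape (n, d_out, d_in)
grad1 = torch.einsum('ni,nj->nij', B, A)
full_sq_norms = grad1.pow(2).sum(dim=(1, 2))

# Shortcut: never materializes the (n, d_out, d_in) tensor
shortcut_sq_norms = torch.sum(B*B, dim=1)*torch.sum(A*A, dim=1)
```

The shortcut only needs O(n·(d_in + d_out)) memory instead of O(n·d_in·d_out), which is where the savings come from for wide layers.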
You can compute per-example gradients using torch.autograd.grad. The key trick is to broadcast the weights manually and then differentiate with respect to the broadcasted weight. Here’s a very simple example on a model that is just a Linear layer:
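import torch

B = 4  # batch size
N = 2  # input features
M = 3  # output features
input = torch.randn(B, 1, N)
weights = torch.randn(N, M, requires_grad=True)
# expand() makes a (B, N, M) view without copying; slice [b] is used only by example b
broadcasted_weights = weights.expand(B, N, M)
output = torch.bmm(input, broadcasted_weights)  # (B, 1, M)
loss = output.sum()
# Differentiating w.r.t. the expanded view gives one gradient per example: (B, N, M)
per_example_grads, = torch.autograd.grad(loss, broadcasted_weights)
per_example_norms = per_example_grads.flatten(1).norm(dim=1)  # (B,)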
I don’t know whether you can do this conveniently with torch.nn modules; it might be worth thinking about how to expose this in the API.
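Note that this doesn’t work if your network contains operations that implicitly broadcast their weights over the batch but don’t accept an explicitly broadcasted (per-example) weight. Convolution is an example: F.conv2d takes a single weight tensor shared by the whole batch.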
A simple API would be a helper function that takes an nn.Module and returns a version that implements the trick above, i.e. model = batch_expand(model, N), where N is the batch size.
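A minimal sketch of what such a helper could look like, restricted to an nn.Sequential of Linear and elementwise layers (so convolutions are excluded, per the caveat about broadcasting). batch_expand is the hypothetical name from the post; its exact signature here, and the choice to return the expanded tensors alongside the output, are assumptions of this sketch rather than an existing PyTorch API.

```python
import torch
import torch.nn as nn


def batch_expand(model: nn.Sequential, batch_size: int):
    """Hypothetical helper: returns a forward function that runs `model` with
    per-example expanded views of every Linear weight and bias, so gradients
    can be taken per example.

    Assumes `model` is an nn.Sequential of nn.Linear layers and elementwise
    layers (e.g. nn.Tanh, nn.ReLU); anything else is not handled.
    """
    def forward(x):
        h = x.unsqueeze(1)  # (B, 1, in_features), the layout torch.bmm expects
        expanded = []       # tensors to differentiate with respect to
        for layer in model:
            if isinstance(layer, nn.Linear):
                # (in, out) weight broadcast to (B, in, out) as a stride-0 view
                w = layer.weight.t().expand(batch_size, -1, -1)
                h = torch.bmm(h, w)
                expanded.append(w)
                if layer.bias is not None:
                    b = layer.bias.expand(batch_size, 1, -1)  # (B, 1, out)
                    h = h + b
                    expanded.append(b)
            else:
                h = layer(h)  # elementwise layer, shape-agnostic
        return h.squeeze(1), expanded

    return forward


torch.manual_seed(0)
B = 6
model = nn.Sequential(nn.Linear(4, 8), nn.Tanh(), nn.Linear(8, 1))
x = torch.randn(B, 4)

out, expanded = batch_expand(model, B)(x)
# Summing over the batch is safe: slice [b] of each expanded tensor only
# influences example b, so its gradient is example b's gradient.
# (Weight gradients come out transposed vs. layer.weight; norms are unaffected.)
grads = torch.autograd.grad(out.sum(), expanded)
norms = torch.sqrt(sum(g.pow(2).flatten(1).sum(dim=1) for g in grads))  # (B,)
```

One backward pass yields all B norms. Note that autograd.grad still materializes a full (B, ...) gradient for every parameter, which is relevant to the memory question.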
The autograd memory usage of the new model has to be reasonable; otherwise there’s no benefit over running the original model with batch size 1.
This seems to be the case in the example above (Colab measurement: “forum - replication memory - simplified”).