Assume we have a batch of data. For each data point in the batch, I would like to get the norm of the gradient of the network’s output w.r.t. the network parameters.

A naive solution is to perform multiple forward and backward passes (one pass for one data point), but this is pretty slow. Is there a better solution?

A better solution seems possible (see https://arxiv.org/abs/1510.01799), but I don’t see any mention of it in the PyTorch documentation.

Hi,

> A naive solution is to perform multiple forward and backward passes (one pass for one data point), but this is pretty slow.

Do you mean feeding the data points one by one within a for-loop? I don’t think this will be right, because PyTorch accumulates gradients across backward passes.
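The accumulation concern can be checked directly. Below is a minimal sketch, assuming a hypothetical toy model and a random batch (neither is from the original question). It compares a loop that never resets `.grad` with one that calls `zero_grad()` before each backward pass.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Hypothetical toy model and batch, used only for illustration.
model = nn.Sequential(nn.Linear(5, 8), nn.Tanh(), nn.Linear(8, 1))
batch = torch.randn(4, 5)


def total_grad_norm(module):
    """L2 norm of all parameter gradients, viewed as one flat vector."""
    return torch.sqrt(sum((p.grad ** 2).sum() for p in module.parameters()))


# Without resetting, .grad holds the running sum over the examples seen so far.
model.zero_grad()
accumulated = []
for x in batch:
    model(x.unsqueeze(0)).sum().backward()
    accumulated.append(total_grad_norm(model).item())

# Resetting .grad before each backward pass gives true per-example norms.
per_example = []
for x in batch:
    model.zero_grad()
    model(x.unsqueeze(0)).sum().backward()
    per_example.append(total_grad_norm(model).item())

print("accumulated:", accumulated)
print("per-example:", per_example)
```

The two lists agree only on the first example. So the naive loop is correct as long as gradients are reset on every iteration (or `torch.autograd.grad` is used, which does not write to `.grad`); it is just slow.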

Thanks.

If you just have conv2d + linear layers, you could do this in a single backward pass using something like this: https://github.com/cybertronai/autograd-hacks/blob/master/autograd_hacks.py#L167

For squared norms, replace the “grad1=einsum” line with `torch.sum(B*B, dim=1)*torch.sum(A*A, dim=1)`.
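To see why the product form gives the same number, here is a small sketch for a single linear layer. It assumes `A` holds the layer inputs (one row per example) and `B` holds the gradients backpropagated to the layer output; the layer, shapes, and loss are hypothetical. The per-example weight gradient is then the outer product of `B[k]` and `A[k]`, whose squared Frobenius norm factors into the two squared vector norms.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

layer = nn.Linear(3, 2)      # hypothetical layer
A = torch.randn(4, 3)        # layer inputs, one row per example
out = layer(A)
loss = out.pow(2).sum()      # any loss that is a sum over examples

# Gradient of the loss w.r.t. the layer output, one row per example.
(B,) = torch.autograd.grad(loss, out)

# Explicit route: materialise every per-example weight gradient, then square and sum.
grad1 = torch.einsum("ni,nj->nij", B, A)              # (4, 2, 3)
norms_sq_explicit = grad1.pow(2).sum(dim=(1, 2))      # (4,)

# Factored route: never builds the (4, 2, 3) tensor.
norms_sq_factored = torch.sum(B * B, dim=1) * torch.sum(A * A, dim=1)

print(torch.allclose(norms_sq_explicit, norms_sq_factored))
```

This covers only the weight of a linear layer. The bias gradient for example `k` is `B[k]` itself, so its squared norm would be `torch.sum(B * B, dim=1)`.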

You can compute per-example gradients using `torch.autograd.grad`. The key trick is to manually broadcast the weights yourself, and then differentiate with respect to the broadcasted weight. Here’s a very simple example on a model that is just a single linear layer:

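```python
import torch

torch.manual_seed(0)

B = 4  # batch size
N = 2  # input features
M = 3  # output features

input = torch.randn(B, 1, N)
weight = torch.randn(N, M, requires_grad=True)

# Manually broadcast the weight so each example gets its own (virtual) copy.
expanded_weight = weight.expand(B, N, M)

output = input.bmm(expanded_weight)  # (B, 1, M)
loss = output.sum()

# Differentiate w.r.t. the broadcasted weight: one gradient per example.
(per_example_grad,) = torch.autograd.grad(loss, [expanded_weight])  # (B, N, M)
per_example_norm = per_example_grad.flatten(start_dim=1).norm(dim=1)  # (B,)
```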

I don’t know if you can conveniently do this with `torch.nn` modules; it might be worth thinking about how to expose this in the API.

Note that this doesn’t work if your network has operations that implicitly broadcast their weights over the batch but don’t accept an already-broadcasted weight. Convolution is an example of this.
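For networks with convolutions, one alternative is `torch.func`, available in PyTorch 2.0 and later. This is a different technique from the weight-broadcasting trick: it vectorises a single-example gradient function over the batch with `vmap`. A minimal sketch, assuming a hypothetical small conv net and random inputs:

```python
import torch
import torch.nn as nn
from torch.func import functional_call, grad, vmap

torch.manual_seed(0)

# Hypothetical conv net; layers that couple examples in a batch
# (e.g. BatchNorm in training mode) are deliberately avoided here.
model = nn.Sequential(
    nn.Conv2d(1, 2, kernel_size=3),
    nn.ReLU(),
    nn.Flatten(),
    nn.Linear(2 * 6 * 6, 1),
)
batch = torch.randn(4, 1, 8, 8)

# Detached copies of the parameters, passed to the model functionally.
params = {name: p.detach() for name, p in model.named_parameters()}


def scalar_output(params, x):
    """Summed network output for a single example x of shape (1, 8, 8)."""
    return functional_call(model, params, (x.unsqueeze(0),)).sum()


# Map over the batch dimension of x only; the parameters are shared (None).
per_example_grads = vmap(grad(scalar_output), in_dims=(None, 0))(params, batch)

# Each entry has shape (batch, *param.shape); combine into one norm per example.
per_example_norms = torch.sqrt(
    sum(g.flatten(start_dim=1).pow(2).sum(dim=1) for g in per_example_grads.values())
)
print(per_example_norms)
```

Like the broadcasting trick, this materialises a full gradient per example, so memory grows with batch size times parameter count.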

A simple API would be a helper function that takes an `nn.Module` and returns a version that implements the trick above, i.e. `model = batch_expand(model, N)`.

Autograd memory usage in the new model has to be reasonable; otherwise there’s no benefit over running the original model with batch size 1.

This seems to be the case in the example above (see the Colab measurement notebook “forum - replication memory - simplified”).
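A rough sketch of what such a helper could look like, restricted to `nn.Linear` layers inside an `nn.Sequential`. Everything here is hypothetical: `batch_expand` and `BatchExpandedLinear` are illustrative names, not an existing PyTorch API, and the sketch reads the batch size from the input instead of taking `N` as an argument.

```python
import torch
import torch.nn as nn


class BatchExpandedLinear(nn.Module):
    """Hypothetical wrapper: a Linear whose weight is broadcast per example."""

    def __init__(self, linear: nn.Linear):
        super().__init__()
        self.linear = linear
        self.expanded_weight = None  # refreshed on every forward pass

    def forward(self, x):  # x: (B, in_features)
        w = self.linear.weight
        # expand() returns a view, so no parameter copy is made here.
        self.expanded_weight = w.expand(x.shape[0], *w.shape)
        out = torch.bmm(self.expanded_weight, x.unsqueeze(2)).squeeze(2)
        if self.linear.bias is not None:
            out = out + self.linear.bias
        return out


def batch_expand(model: nn.Sequential) -> nn.Sequential:
    """Hypothetical helper: wrap every nn.Linear in a Sequential model."""
    return nn.Sequential(
        *[BatchExpandedLinear(m) if isinstance(m, nn.Linear) else m for m in model]
    )


torch.manual_seed(0)
model = batch_expand(nn.Sequential(nn.Linear(5, 8), nn.Tanh(), nn.Linear(8, 1)))
x = torch.randn(4, 5)

loss = model(x).sum()
expanded = [m.expanded_weight for m in model if isinstance(m, BatchExpandedLinear)]
grads = torch.autograd.grad(loss, expanded)  # each: (B, out_features, in_features)

# Per-example norm over all weights (biases are not covered by this sketch).
weight_norms = torch.sqrt(sum(g.flatten(start_dim=1).pow(2).sum(dim=1) for g in grads))
print(weight_norms)
```

On the memory side, the forward pass only creates views of the weights, but the gradients returned for the expanded weights are `B` times the size of the parameters. That cost is inherent to materialising per-example gradients.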
