Efficient Per-Example Gradient Computations

Assume we have a batch of data. For each data point in the batch, I would like to compute the norm of the gradient of the network's output with respect to the network parameters.

A naive solution is to perform a separate forward and backward pass per data point, but this is pretty slow. Is there a better solution?
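For concreteness, the naive loop would look something like this (a minimal sketch with a placeholder model and data; the one detail that matters is zeroing the gradients before each pass, since PyTorch accumulates them):

```python
import torch

# Placeholder network with a scalar output per example; any model works.
model = torch.nn.Linear(3, 1)
xs = torch.randn(5, 3)  # a batch of 5 data points

grad_norms = []
for x in xs:
    model.zero_grad()              # gradients otherwise accumulate across passes
    model(x).squeeze().backward()  # one backward pass per data point
    flat = torch.cat([p.grad.flatten() for p in model.parameters()])
    grad_norms.append(flat.norm())
```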

A better solution seems possible, but I don't see any mention of one in the PyTorch documentation.

https://arxiv.org/abs/1510.01799


Hi,

> A naive solution is to perform multiple forward and backward passes (one pass for one data point), but this is pretty slow.

Do you mean feeding the data points one by one in a for-loop? I don't think that will be right as written, because PyTorch accumulates gradients across backward passes unless you zero them between iterations.

Thanks.

If you just have Conv2d and Linear layers, you could do this in a single backward pass using something like this – https://github.com/cybertronai/autograd-hacks/blob/master/autograd_hacks.py#L167
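Roughly, the trick there is: a forward hook saves each layer's input activations A, a backward hook saves the gradient B arriving at the layer's output, and for a Linear layer the per-example weight gradient is then the outer product B[n] A[n]^T. Here is a minimal single-layer sketch of that idea (not the linked code itself; it assumes the loss is a plain sum over the batch so that the rows of B correspond to individual examples):

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
layer = nn.Linear(3, 2, bias=False)
x = torch.randn(5, 3)  # batch of 5 examples

cache = {}
layer.register_forward_hook(lambda m, inp, out: cache.update(A=inp[0].detach()))
layer.register_full_backward_hook(lambda m, gin, gout: cache.update(B=gout[0].detach()))

layer(x).sum().backward()  # loss must reduce over the batch by summation

A, B = cache["A"], cache["B"]             # A: (N, in), B: (N, out)
grad1 = torch.einsum("ni,nj->nij", B, A)  # per-example weight gradients

# Sanity check: per-example gradients sum to the usual batch gradient.
assert torch.allclose(grad1.sum(dim=0), layer.weight.grad, atol=1e-6)
```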

For squared norms, replace the “grad1 = einsum” line with torch.sum(B*B, dim=1) * torch.sum(A*A, dim=1).
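This works because each per-example gradient is the rank-one outer product B[n] A[n]^T, and the squared Frobenius norm of a rank-one matrix factorizes into the product of the two squared vector norms. A quick numerical check of the identity (with made-up A and B):

```python
import torch

torch.manual_seed(0)
A = torch.randn(5, 3)  # per-example layer inputs
B = torch.randn(5, 2)  # per-example gradients w.r.t. the layer output

grad1 = torch.einsum("ni,nj->nij", B, A)  # per-example weight gradients
sq_norms = torch.sum(B * B, dim=1) * torch.sum(A * A, dim=1)

# Factorized squared norms match the squared Frobenius norms of grad1,
# without ever materializing the (N, out, in) tensor in the real computation.
assert torch.allclose(sq_norms, grad1.flatten(1).pow(2).sum(dim=1))
```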
