Efficient Per-Example Gradient Computations

If you only have Conv2d and Linear layers, you can do this with a single backward pass, using something like this: https://github.com/cybertronai/autograd-hacks/blob/master/autograd_hacks.py#L167
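Here's a minimal sketch of the idea for a single `nn.Linear` layer (the hook names and toy dimensions are illustrative, not taken from autograd_hacks, and it assumes the loss is reduced with `sum()` so no batch-size rescaling is needed): a forward hook captures the layer input `A`, a backward hook captures the gradient w.r.t. the layer output `B`, and the per-example weight gradients are their outer products.

```python
import torch
import torch.nn as nn

layer = nn.Linear(4, 3)
captured = {}

def save_input(module, inputs, output):
    captured["A"] = inputs[0].detach()       # per-example activations, shape (n, in)

def save_grad_output(module, grad_input, grad_output):
    captured["B"] = grad_output[0].detach()  # per-example output grads, shape (n, out)

layer.register_forward_hook(save_input)
layer.register_full_backward_hook(save_grad_output)

x = torch.randn(8, 4)
loss = layer(x).sum()   # sum reduction: grad_output rows are already per-example
loss.backward()

A, B = captured["A"], captured["B"]
grad1 = torch.einsum("ni,nj->nij", B, A)  # per-example weight grads, shape (n, out, in)

# Sanity check: per-example gradients sum to the usual batch gradient.
assert torch.allclose(grad1.sum(dim=0), layer.weight.grad, atol=1e-6)
```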

For gradient norms squared, replace the `grad1 = einsum(...)` line with `torch.sum(B*B, dim=1) * torch.sum(A*A, dim=1)`.
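This works because each per-example gradient is an outer product, and the squared Frobenius norm of an outer product factorizes: ‖b aᵀ‖² = ‖b‖² ‖a‖², so the full (n, out, in) tensor never needs to be materialized. A self-contained sketch, with random tensors standing in for the hooked `A` and `B`:

```python
import torch

n, d_out, d_in = 8, 3, 4
B = torch.randn(n, d_out)  # per-example grads w.r.t. the layer output
A = torch.randn(n, d_in)   # per-example layer inputs

# Per-example squared gradient norms, without forming the outer products.
norms_sq = torch.sum(B * B, dim=1) * torch.sum(A * A, dim=1)  # shape (n,)

# Verify against the explicit per-example gradients.
grad1 = torch.einsum("ni,nj->nij", B, A)
assert torch.allclose(norms_sq, grad1.pow(2).sum(dim=(1, 2)), atol=1e-5)
```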
