Efficient computation of per-sample examples

Additional discussion of [feature request] Simple and Efficient way to get gradients of each element of a sum

@tom I hope I interpreted your last post correctly by posting here.

The code to reproduce the numbers I see is available here.
The main.py file runs different ways of computing individual gradients and checks/times them.

Your recommendation seems to be along the line of Goodfellow’s derivation, implemented as goodfellow, timing output under goodf.
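For readers who haven't seen it, the Goodfellow-style computation for a linear layer boils down to keeping the gradient w.r.t. the layer *output* with its batch dimension and forming a per-example outer product with the input. A minimal sketch of that idea (my own illustration, not the benchmark code):

```python
import torch

# Goodfellow-style per-example weight gradients for one linear layer:
# grad_W[n] = g_out[n] (outer) inp[n], where g_out is dloss/d(layer output).
lin = torch.nn.Linear(2, 3)
inp = torch.randn(4, 2)
out = lin(inp)
loss = out.pow(2).sum()

# gradient w.r.t. the layer output, batch dimension kept
(g_out,) = torch.autograd.grad(loss, out, retain_graph=True)

# per-example weight gradients: (N, out_features, in_features)
per_example_gw = g_out.unsqueeze(2) * inp.unsqueeze(1)

# their sum over the batch matches the usual autograd gradient
loss.backward()
print(torch.allclose(per_example_gw.sum(0), lin.weight.grad))
```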

I don’t think it is a bug per se, just an inefficient use of things.
I am not sure what you want to look at but I’m happy to discuss the issue :slight_smile:

I’m never sure what to make of the arXiv paper you link.

So along the lines of Adam’s great suggestion for Jacobians in the bug report, you could use

import torch

class LinearWithBatchGradFn(torch.autograd.Function):
    @staticmethod
    def forward(ctx, inp, weight, bias=None):
        ctx.save_for_backward(inp, weight, bias)
        return torch.nn.functional.linear(inp, weight, bias)

    @staticmethod
    def backward(ctx, grad_out):
        inp, weight, bias = ctx.saved_tensors
        # bias gradient with the batch dimension kept
        grad_bias = grad_out if bias is not None else None
        # per-sample weight gradient: (N, out_features, in_features)
        return grad_out @ weight, inp.unsqueeze(1) * grad_out.unsqueeze(2), grad_bias

linear_with_batch_grad = LinearWithBatchGradFn.apply

Then instead of the usual

weight = torch.randn(3,2, requires_grad=True)
bias = torch.randn(3, requires_grad=True)
inp = torch.randn(4,2, requires_grad=True)
a = torch.nn.functional.linear(inp, weight, bias)
gradw = torch.randn(4,3)  # stand-in upstream gradient for the output a
gi, gw, gb = torch.autograd.grad((a*gradw).sum(), [inp, weight, bias])

you can do

a2 = linear_with_batch_grad(inp, weight, bias)
gi2, gw2, gb2 = torch.autograd.grad((a2*gradw).sum(), [inp, weight, bias])

and have the right thing:

print("grad weight", gw.shape, gw2.shape, torch.allclose(gw2.sum(0), gw))
print("grad bias", gb.shape, gb2.shape, torch.allclose(gb2.sum(0), gb))
print("grad inp stays the same (as needed by earlier layers)", gi.shape, gi2.shape, torch.allclose(gi, gi2))

I think this should be about as efficient as it gets.

As mentioned, the convolution backward is a bit trickier if you cannot call the backward functions directly. One costly workaround would be to run the forward computation again at backward time to derive the backward.
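For the record, one way to get per-sample convolution weight gradients without calling the backward functions directly is to unfold the input into patches and contract them with the output gradient; it is memory-hungry but easy to check. A rough sketch (shapes and names are my own, not from the repo):

```python
import torch
import torch.nn.functional as F

N, C_in, C_out, H, W, K = 4, 3, 5, 8, 8, 3
inp = torch.randn(N, C_in, H, W, requires_grad=True)
weight = torch.randn(C_out, C_in, K, K, requires_grad=True)

out = F.conv2d(inp, weight, padding=1)
grad_out = torch.randn_like(out)

# reference: the summed weight gradient from autograd
(gw,) = torch.autograd.grad((out * grad_out).sum(), [weight])

# per-sample version: unfold input into patches, contract with grad_out
cols = F.unfold(inp, K, padding=1)        # (N, C_in*K*K, L)
g = grad_out.reshape(N, C_out, -1)        # (N, C_out, L)
gw_per_sample = torch.einsum('ncl,nkl->nck', g, cols).reshape(N, C_out, C_in, K, K)

# the per-sample gradients sum to the usual one
print(torch.allclose(gw_per_sample.sum(0), gw, atol=1e-4))
```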

Best regards