# Efficient computation of per-sample examples

Additional discussion of [feature request] Simple and Efficient way to get gradients of each element of a sum

@tom I hope I interpreted your last post correctly by posting here.

The code to reproduce the numbers I see is available here.
The `main.py` file runs different ways of computing individual gradients and checks/times them.

Your recommendation seems to be along the lines of Goodfellow's derivation, implemented as `goodfellow`, with timing output under `goodf`.
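For readers who haven't seen it: the core of the Goodfellow-style trick for a linear layer is to capture the layer input and the gradient with respect to the layer output, and then form per-sample weight gradients as an outer product. A minimal sketch (names here are illustrative, not from the linked code):

```python
import torch

torch.manual_seed(0)
lin = torch.nn.Linear(2, 3)
inp = torch.randn(4, 2)

out = lin(inp)
out.retain_grad()  # keep the gradient w.r.t. the layer output
out.sum().backward()

# per-sample weight gradients: (batch, out_features, in_features)
per_sample_gw = out.grad.unsqueeze(2) * inp.unsqueeze(1)
# summing over the batch recovers the usual accumulated gradient
print(torch.allclose(per_sample_gw.sum(0), lin.weight.grad))
```

This costs one extra outer product per layer but avoids running one backward per sample.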

I don’t think it is a bug per se, just an inefficient use of things.
I am not sure what you want to look at, but I’m happy to discuss the issue.

I’m never sure what to make of the arXiv paper you link.

So along the lines of Adam’s great suggestion for Jacobians in the bug, you could use

```python
import torch

class LinearWithBatchGradFn(torch.autograd.Function):
    @staticmethod
    def forward(ctx, inp, weight, bias=None):
        ctx.save_for_backward(inp, weight, bias)
        return torch.nn.functional.linear(inp, weight, bias)

    @staticmethod
    def backward(ctx, grad_out):
        inp, weight, bias = ctx.saved_tensors
        # per-sample bias gradient is just grad_out itself
        grad_bias = grad_out if bias is not None else None
        # per-sample weight gradient: outer product of input and output gradient,
        # shape (batch, out_features, in_features)
        return grad_out @ weight, inp.unsqueeze(1) * grad_out.unsqueeze(2), grad_bias

linear_with_batch_grad = LinearWithBatchGradFn.apply
```

Then instead of the usual

```python
weight = torch.randn(3, 2, requires_grad=True)
bias = torch.randn(3, requires_grad=True)
inp = torch.randn(4, 2, requires_grad=True)
a = torch.nn.functional.linear(inp, weight, bias)
gradw = torch.randn(4, 3)  # gradient w.r.t. the output a
gi, gw, gb = torch.autograd.grad((a * gradw).sum(), [inp, weight, bias])
```

you can do

```python
a2 = linear_with_batch_grad(inp, weight, bias)
gi2, gw2, gb2 = torch.autograd.grad((a2 * gradw).sum(), [inp, weight, bias])
```

and have the right thing:

```python
print("grad weight", gw.shape, gw2.shape, torch.allclose(gw2.sum(0), gw))
print("grad bias", gb.shape, gb2.shape, torch.allclose(gb2.sum(0), gb))
print("grad inp stays the same for other layers/networks", gi.shape, gi2.shape, torch.allclose(gi, gi2))
```

I think this should be about as efficient as it gets.

As mentioned there, the convolution backward is a bit trickier if you cannot call the backward directly. One costly workaround would be to redo the forward at backward time to get the backward.

Best regards

Thomas
