I’m never sure what to make of the arXiv paper you link.
So along the lines of Adam’s great suggestion for Jacobians in the bug report, you could use the following:
```python
import torch

class LinearWithBatchGradFn(torch.autograd.Function):
    @staticmethod
    def forward(ctx, inp, weight, bias=None):
        ctx.save_for_backward(inp, weight, bias)
        return torch.nn.functional.linear(inp, weight, bias)

    @staticmethod
    def backward(ctx, grad_out):
        inp, weight, bias = ctx.saved_tensors
        # keep the batch dimension: one bias gradient per sample
        grad_bias = grad_out if bias is not None else None
        # grad w.r.t. inp has the usual shape; grad w.r.t. weight is the
        # per-sample outer product of grad_out[i] and inp[i], shape (N, out, in)
        return grad_out @ weight, inp.unsqueeze(1) * grad_out.unsqueeze(2), grad_bias

linear_with_batch_grad = LinearWithBatchGradFn.apply
```
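The trick is that backward returns the gradient for inp in its usual shape, so autograd can keep propagating through earlier layers, while the gradients for weight and bias keep their batch dimension.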
Then instead of the usual
```python
weight = torch.randn(3, 2, requires_grad=True)
bias = torch.randn(3, requires_grad=True)
inp = torch.randn(4, 2, requires_grad=True)

a = torch.nn.functional.linear(inp, weight, bias)
grad_a = torch.randn(4, 3)  # gradient w.r.t. the output a
gi, gw, gb = torch.autograd.grad((a * grad_a).sum(), [inp, weight, bias])
```
you can do
```python
a2 = linear_with_batch_grad(inp, weight, bias)
gi2, gw2, gb2 = torch.autograd.grad((a2 * grad_a).sum(), [inp, weight, bias])
```
and get the per-sample gradients you want:
print("grad weight", gw.shape, gw2.shape, torch.allclose(gw2.sum(0), gw))
print("grad bias", gb.shape, gb2.shape, torch.allclose(gb2.sum(0), gb))
print("grad inp stays the same for other layers networks", gi.shape, gi2.shape, torch.allclose(gi, gi2))
I think this should be about as efficient as it gets.
As mentioned, the backward of convolution is a bit more tricky if you cannot call the backward functions directly. One costly workaround is to redo the forward at backward time to get the per-sample gradients.
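For illustration, here is a minimal sketch of that workaround for a plain 2d convolution (assuming stride 1, no padding, no groups; the name Conv2dWithBatchGradFn and the per-sample loop are just one way to write it):

```python
class Conv2dWithBatchGradFn(torch.autograd.Function):
    @staticmethod
    def forward(ctx, inp, weight, bias=None):
        ctx.save_for_backward(inp, weight, bias)
        return torch.nn.functional.conv2d(inp, weight, bias)

    @staticmethod
    def backward(ctx, grad_out):
        inp, weight, bias = ctx.saved_tensors
        # the input gradient is available as a regular function
        grad_inp = torch.nn.grad.conv2d_input(inp.shape, weight, grad_out)
        # costly part: redo the forward per sample, with autograd enabled,
        # to get one weight gradient per sample
        grad_weight = []
        with torch.enable_grad():
            for i in range(inp.shape[0]):
                w = weight.detach().requires_grad_()
                # bias omitted: it does not influence the weight gradient
                out_i = torch.nn.functional.conv2d(inp[i:i + 1].detach(), w)
                grad_weight.append(
                    torch.autograd.grad(out_i, w, grad_out[i:i + 1])[0]
                )
        grad_weight = torch.stack(grad_weight)
        # per-sample bias gradient: sum grad_out over the spatial dimensions
        grad_bias = grad_out.sum(dim=(2, 3)) if bias is not None else None
        return grad_inp, grad_weight, grad_bias
```

Only the weight gradient needs the loop over the batch; the input and bias gradients come out per-sample by construction.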
Best regards
Thomas