I realise the title sounds dangerous.
I’d like to do something akin to the following:
```python
import torch

a = torch.rand(1, requires_grad=True)
b = torch.rand(2, requires_grad=True)
c = torch.rand(3, requires_grad=True)
x = torch.rand(1, requires_grad=True)

y = torch.empty(4)
y[0] = x
y[1] = y[:1].dot(a)
y[2] = y[:2].dot(b)
y[3] = y[:3].dot(c)
y.sum().backward()
```
If a, b and c don’t require gradients then this works fine. However, if they do, then an error is thrown, because y is needed to calculate the gradients wrt a, b and c, and it’s been modified in-place.
Except - it hasn’t! The only bits that have been modified are the bits that don’t affect the gradient calculation. Not that I expect the correctness checks to be able to detect that level of detail.
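To spell out which bits I mean, here is the same code again, annotated with the slice each dot product saves for backward and the elements written after it (error message abbreviated):

```python
y = torch.empty(4)
y[0] = x             # writes element 0; nothing has been saved for backward yet
y[1] = y[:1].dot(a)  # saves y[:1] (element 0) for the grad wrt a; writes element 1
y[2] = y[:2].dot(b)  # saves y[:2] (elements 0-1) for the grad wrt b; writes element 2
y[3] = y[:3].dot(c)  # saves y[:3] (elements 0-2) for the grad wrt c; writes element 3
y.sum().backward()   # RuntimeError: one of the variables needed for gradient
                     # computation has been modified by an inplace operation
# Every element written after a given dot product lies outside the slice that
# dot product saved, so the saved values are untouched. But views share their
# base's version counter, so autograd sees the whole of y as modified.
```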
The obvious workaround is to split up y into lots of little tensors, but that would make each dot product slower, as we’d have to drop back into Python to do the dot product. This is inside a hot loop that I really want to be as optimised as possible.
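For concreteness, the split-up version I have in mind is roughly the following sketch (using torch.stack to assemble each operand; the exact assembly isn’t the point):

```python
# Sketch of the "lots of little tensors" workaround: there are no in-place
# writes, so autograd has no complaints, but every step now pays for a
# torch.stack plus the surrounding Python glue inside the hot loop.
y0 = x.squeeze()
y1 = torch.stack([y0]).dot(a)
y2 = torch.stack([y0, y1]).dot(b)
y3 = torch.stack([y0, y1, y2]).dot(c)
(y0 + y1 + y2 + y3).backward()
```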
Are there any clever workarounds that I might be able to pull, or any way I can turn off the hand-holding of the version tracker? (I already tried overwriting the version counter after each assignment, but apparently it’s read-only.)
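For reference, the counter in question is exposed as Tensor._version, and what I tried was roughly this:

```python
t = torch.empty(4)
t[0] = 1.0         # any in-place write bumps the version counter
print(t._version)  # 1
t._version = 0     # AttributeError: the property has no setter
```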