Disable in-place correctness version check / any other workaround

I realise the title sounds dangerous.

I’d like to do something akin to the following:

import torch

a = torch.rand(1, requires_grad=True)
b = torch.rand(2, requires_grad=True)
c = torch.rand(3, requires_grad=True)
x = torch.rand(1, requires_grad=True)
y = torch.empty(4)
y[0] = x
y[1] = y[:1].dot(a)
y[2] = y[:2].dot(b)
y[3] = y[:3].dot(c)
y.sum().backward()

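If a, b and c don’t require gradients, this works fine. However, if they do, backward() throws a RuntimeError ("one of the variables needed for gradient computation has been modified by an inplace operation"), because slices of y are needed to calculate the gradients with respect to a, b and c, and y has been modified in place.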

Except it hasn’t, at least not where it matters! The only elements that were modified are ones that don’t affect the gradient calculation. Not that I expect the correctness checks to be able to detect that level of detail.
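A minimal sketch of why the check fires anyway (it reads `_version`, an internal attribute, purely for illustration): the version counter belongs to the whole tensor and is shared by views such as `y[:1]`, so writing only `y[1]` still invalidates the saved `y[:1]`.

```python
import torch

a = torch.rand(1, requires_grad=True)
y = torch.zeros(2)
y[0] = 1.0
out = y[:1].dot(a)        # dot saves y[:1], a view of y, to compute a's gradient
version_before = y._version
y[1] = out                # touches only y[1], but bumps the counter that y[:1] shares
version_after = y._version

try:
    y.sum().backward()
    raised = False
except RuntimeError as err:
    raised = "modified by an inplace operation" in str(err)

print(version_before, version_after, raised)
```

The check is per tensor (per version counter), not per element, which is exactly the granularity problem described above.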

The obvious workaround is to split up y into lots of little tensors, but that would make each dot product slower, as we’d have to drop back into Python to do the dot product. This is inside a hot loop that I really want to be as optimised as possible.

Are there any clever workarounds that I might be able to pull, or any way I can turn off the hand-holding of the version tracker? (I already tried modifying the version after each assignment but apparently it’s read-only.)

Hi,

Yes, this is an annoying limitation of the in-place detection.
Going around the version counter is the only case where we still have to use .data. Unfortunately, .data also has the side effect of detaching the Tensor, so no gradient can flow back through that computation.
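A small sketch of both effects of `.data` (again reading the internal `_version` attribute only for illustration): the write through `y.data` leaves `y._version` unchanged, so the saved `y[:1]` still passes its check, but the written value is invisible to autograd.

```python
import torch

x = torch.rand((), requires_grad=True)
a = torch.rand(1, requires_grad=True)
y = torch.zeros(2)
y[0] = x                   # recorded by autograd
out = y[:1].dot(a)         # saves y[:1] for a's gradient
version_before = y._version
y.data[1] = out.detach()   # y.data has its own version counter, so y._version stays put
version_after = y._version

# The saved y[:1] still passes its version check...
(grad_a,) = torch.autograd.grad(out, a)
# ...but y[1] was written outside autograd, so y.sum() has no path back to a
y.sum().backward()
print(version_before == version_after, a.grad)
```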

The only way around that I can see is to wrap it in a custom Function, where you use .data in the forward to hide the version changes.
But that means you have to write a custom backward for it :confused:

Note that a cleaner API to stop sharing the version counter while preserving autograd is on our todo list for the .data deprecation task here: https://github.com/pytorch/pytorch/issues/30987
But we haven’t had time to get around to it yet.

Just because doing taxes and accounting makes me want to hold a little “who knows the most evil autograd tricks” competition with @albanD and @ptrblck, here is my entry:

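import torch

class DangerousDot(torch.autograd.Function):
    @staticmethod
    def forward(ctx, a, b):
        # Stash the inputs on ctx directly instead of ctx.save_for_backward(a, b),
        # so autograd never checks their version counters
        ctx._a = a
        ctx._b = b
        return a.dot(b)

    @staticmethod
    def backward(ctx, grad_out):
        # d(a.b)/da = b and d(a.b)/db = a
        return grad_out * ctx._b, grad_out * ctx._a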


a = torch.randn(10, requires_grad=True, dtype=torch.double)
b = torch.randn(10, requires_grad=True, dtype=torch.double)
torch.autograd.gradcheck(DangerousDot.apply, (a, b))

a = torch.rand(1, requires_grad=True)
b = torch.rand(2, requires_grad=True)
c = torch.rand(3, requires_grad=True)
x = torch.rand(1, requires_grad=True)
y = torch.empty(4)
y[0] = x
y[1] = DangerousDot.apply(y[:1], a)
y[2] = DangerousDot.apply(y[:2], b)
y[3] = DangerousDot.apply(y[:3], c)
y.sum().backward()

Explanation: the parts responsible for the version check in autograd.Functions are ctx.save_for_backward(...) and ctx.saved_tensors. Avoiding them and assigning a and b to ctx directly bypasses the checks.
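To make the explanation concrete, here is a minimal sketch with two hypothetical Functions, `CheckedScale` and `UncheckedScale`, that compute the same `a * b` and differ only in how they keep `b`. Modifying `b` in place after the forward makes the first one raise, while the second silently computes a gradient from the new value.

```python
import torch

class CheckedScale(torch.autograd.Function):
    @staticmethod
    def forward(ctx, a, b):
        ctx.save_for_backward(b)          # version of b is recorded here
        return a * b

    @staticmethod
    def backward(ctx, grad_out):
        b, = ctx.saved_tensors            # raises if b changed since forward
        return grad_out * b, None         # no gradient for b in this sketch

class UncheckedScale(torch.autograd.Function):
    @staticmethod
    def forward(ctx, a, b):
        ctx.b = b                         # plain attribute: no version tracking
        return a * b

    @staticmethod
    def backward(ctx, grad_out):
        return grad_out * ctx.b, None     # uses whatever b holds *now*

def run(fn):
    a = torch.ones(3, requires_grad=True)
    b = torch.full((3,), 2.0)
    out = fn.apply(a, b)
    b.add_(1.0)                           # in-place edit after forward: b is now 3.0
    out.sum().backward()
    return a.grad

try:
    run(CheckedScale)
    checked_raised = False
except RuntimeError:
    checked_raised = True

# No error, but the gradient uses b == 3.0 instead of the correct 2.0
unchecked_grad = run(UncheckedScale)
```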

Strictly don’t use this at home: it’s evil and dangerous!

Best regards

Thomas


who knows the most evil autograd tricks competition

:smiley:

Note as well that Tom’s code will create an indestructible reference cycle and leak memory, on top of being very wrong :smiley:
The original purpose of save_for_backward() is actually to avoid this cycle :wink:

I would suggest this NotSoDangerousDot:

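import torch

class NotSoDangerousDot(torch.autograd.Function):
    @staticmethod
    def forward(ctx, a, b):
        prod = a * b
        # b is saved the normal way, so its version is still checked
        ctx.save_for_backward(b)
        # prod is a fresh intermediate that nothing else modifies in place,
        # so keeping it on ctx instead of a is safe
        ctx.prod = prod
        return prod.sum()

    @staticmethod
    def backward(ctx, grad_out):
        b, = ctx.saved_tensors
        prod = ctx.prod
        # d/da = b; d/db = a, recovered as prod / b (NaN wherever b is 0)
        return grad_out * b, grad_out * (prod / b)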

This might be slightly slower, but it should be correct in all ways! (Cheating the autograd version checker is bad! haha)
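One way to sanity-check NotSoDangerousDot on the question’s example is to compare it against the “lots of little tensors” version from the question. The helper names `build_y_in_place` and `build_y_reference` are made up for this sketch, and the weights are shifted away from zero because `prod / b` is undefined where `b` is 0.

```python
import torch

class NotSoDangerousDot(torch.autograd.Function):
    @staticmethod
    def forward(ctx, a, b):
        prod = a * b
        ctx.save_for_backward(b)
        ctx.prod = prod
        return prod.sum()

    @staticmethod
    def backward(ctx, grad_out):
        b, = ctx.saved_tensors
        return grad_out * b, grad_out * (ctx.prod / b)

def build_y_in_place(x, ws):
    # the question's loop, with NotSoDangerousDot in place of .dot
    y = torch.empty(len(ws) + 1, dtype=x.dtype)
    y[0] = x
    for i, w in enumerate(ws, start=1):
        y[i] = NotSoDangerousDot.apply(y[:i], w)
    return y

def build_y_reference(x, ws):
    # the "lots of little tensors" version from the question
    entries = [x]
    for w in ws:
        entries.append(torch.stack(entries).dot(w))
    return torch.stack(entries)

torch.manual_seed(0)
x = torch.rand((), dtype=torch.double, requires_grad=True)
# shift the weights away from 0, since prod / b is NaN wherever b == 0
ws = [(torch.rand(n, dtype=torch.double) + 0.5).requires_grad_() for n in (1, 2, 3)]

y = build_y_in_place(x, ws)
grads = torch.autograd.grad(y.sum(), [x, *ws])
y_ref = build_y_reference(x, ws)
grads_ref = torch.autograd.grad(y_ref.sum(), [x, *ws])

# gradcheck on the Function alone, again with b away from 0
u = torch.rand(5, dtype=torch.double, requires_grad=True)
v = (torch.rand(5, dtype=torch.double) + 0.5).requires_grad_()
ok = torch.autograd.gradcheck(NotSoDangerousDot.apply, (u, v))
```

Both versions should give the same values and gradients; only the in-place one keeps y as a single tensor.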

Haha, thanks both! I ended up going down a very similar line to what you’ve just described, so here’s my entry to the competition:

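import torch

class _AssignInAVerySafeManner(torch.autograd.Function):
    @staticmethod
    def forward(ctx, scratch, value, index):
        ctx.index = index
        # Writing through .data skips the version-counter bump on scratch
        scratch.data[index] = value
        return scratch

    @staticmethod
    def backward(ctx, grad_scratch):
        # The old content of scratch[index] was overwritten, so it gets zero
        # gradient; clone first rather than modifying the incoming gradient in place
        grad_prev = grad_scratch.clone()
        grad_prev[ctx.index] = 0
        return grad_prev, grad_scratch[ctx.index], None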


a = torch.rand(1, requires_grad=True)
b = torch.rand(2, requires_grad=True)
c = torch.rand(3, requires_grad=True)
x = torch.rand((), requires_grad=True)
y = torch.empty(4)

y = _AssignInAVerySafeManner.apply(y, x, 0)
for i, letter in zip((1, 2, 3), (a, b, c)):
    y_ = y[:i].dot(letter)
    y = _AssignInAVerySafeManner.apply(y, y_, i)
y.sum().backward()

This does use the to-be-deprecated .data, though. Do you have any idea roughly when that’s likely to disappear?


Touché! But you get NaNs if b contains a 0.

More seriously, I’d probably just throw in some cloning.
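One reading of the cloning suggestion (my interpretation, not necessarily what was meant) is to clone the prefix before each dot product, so that dot saves a private copy instead of a view of y. The function names here are hypothetical, and the check compares against the “lots of little tensors” version from the question.

```python
import torch

def build_y_cloned(x, ws):
    # x: 0-dim tensor; ws[k] has length k + 1 (the question's a, b, c)
    y = torch.empty(len(ws) + 1, dtype=x.dtype)
    y[0] = x
    for i, w in enumerate(ws, start=1):
        # clone() hands dot its own copy of the prefix, so the in-place writes
        # to y never touch a tensor that autograd saved for backward
        y[i] = y[:i].clone().dot(w)
    return y

def build_y_reference(x, ws):
    # the "lots of little tensors" version from the question, for comparison
    entries = [x]
    for w in ws:
        entries.append(torch.stack(entries).dot(w))
    return torch.stack(entries)

torch.manual_seed(0)
x = torch.rand((), dtype=torch.double, requires_grad=True)
ws = [torch.rand(n, dtype=torch.double, requires_grad=True) for n in (1, 2, 3)]

y = build_y_cloned(x, ws)
grads = torch.autograd.grad(y.sum(), [x, *ws])  # no version-counter error
y_ref = build_y_reference(x, ws)
grads_ref = torch.autograd.grad(y_ref.sum(), [x, *ws])
```

The extra cost is one copy of the prefix per step, with no .data and no custom backward; whether that beats splitting y up in a hot loop would need measuring.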

Best regards

Thomas
