How to allow in-place modification of saved tensors from forward pass in backward pass?

I know saved tensors should not be modified because they could be used multiple times.
However, when we save the output of some layer, once the backward pass reaches that layer, there should not be any other layer that still needs it (right?), because the backward pass has already gone through those layers. That is, the output of A is used by later layers B, C, etc., so B, C, etc. should not modify it in place; only A should be able to.
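
For illustration, here is a minimal example of the kind of error I mean (exp is just a stand-in for any op that saves a tensor for its backward pass):

import torch

x = torch.randn(3, requires_grad=True)
y = x.exp()          # exp() saves its output y to compute the gradient
y += 1               # in-place change to the saved tensor
y.sum().backward()
# RuntimeError: one of the variables needed for gradient computation
# has been modified by an inplace operation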

The problem is that PyTorch won't allow any modification, period.

Why is this useful? We could reuse that tensor for other calculations and save precious memory by doing the work in place.

I don’t think that’s true, as this code snippet shows: in-place manipulations are allowed after the computation graph has been freed, so no gradients can be (wrongly) calculated anymore:

import torch
import torch.nn as nn

# setup
lin1 = nn.Linear(1, 1)
lin2 = nn.Linear(1, 1)
x = torch.randn(1, 1)

# disallowed, since out1 is needed to calculate the gradients of lin2
out1 = lin1(x)
out2 = lin2(out1)
out1[0] += 1
out2.backward()
# RuntimeError: one of the variables needed for gradient computation
# has been modified by an inplace operation

# manipulating afterwards works, since backward() has already freed the graph
out1 = lin1(x)
out2 = lin2(out1)
out2.backward()
out1[0] += 1  # works
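
For comparison, here is a sketch of the case where the graph is kept alive via retain_graph=True; since out1 then stays saved for the second backward call, I would expect the same error there:

# the retained graph still holds the saved out1, so the in-place
# change should be flagged at the next backward call
out1 = lin1(x)
out2 = lin2(out1)
out2.backward(retain_graph=True)
out1[0] += 1
out2.backward()  # expected to raise the same RuntimeError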

I see. What I want to do is something like this:

import torch
from torch.autograd import Function

class AFunc(Function):
    @staticmethod
    def forward(ctx, x):
        res = function(x)  # `function` is a placeholder for the actual forward computation
        ctx.save_for_backward(x, res)
        return res

    @staticmethod
    def backward(ctx, grad_output):
        x, res = ctx.saved_tensors
        res.inplace_op_(x)  # placeholder for an in-place op that reuses res's memory
        return res

x = torch.rand(1, 1, requires_grad=True)
res = AFunc.apply(x)
out = use_res_func(res)  # placeholder for downstream use of res
out.backward()

So I want to modify the tensor during the backward pass, while the computation graph is not (entirely?) freed yet.
Since I want to modify res and not x, it should be fine, because res is created inside the forward function, and by the time this backward is reached, no other node in the remaining graph will use res anyway.
How do I tell PyTorch that it is okay to modify res in the backward function?
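
To make this concrete, here is a minimal, runnable stand-in for the pseudocode above; the squaring op, the sum reduction, and the gradient formula are placeholders I picked just to have something executable:

import torch
from torch.autograd import Function

class AFunc(Function):
    @staticmethod
    def forward(ctx, x):
        res = x * x                      # stand-in for function(x)
        ctx.save_for_backward(x, res)
        return res

    @staticmethod
    def backward(ctx, grad_output):
        x, res = ctx.saved_tensors
        # stand-in for inplace_op_: overwrite res with d(x*x)/dx = 2*x,
        # reusing its memory instead of allocating a new gradient tensor
        res.copy_(2 * x)
        res.mul_(grad_output)
        return res

x = torch.rand(1, 1, requires_grad=True)
res = AFunc.apply(x)
out = res.sum()                          # stand-in for use_res_func(res)
out.backward()  # whether autograd accepts the in-place change to the saved res is exactly my question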