Backprop through iterative process with in-place CUDA ops

Hi, this is a bit of an unusual scenario:

I have a CUDA kernel implementation of an autograd Function Func with a custom backward pass, which I use as follows during the forward pass:

x = ...  # x_1, from the data loader
for t in range(N):
    x = Func.apply(x)  # x_(t + 1) = Func(x_t)

Where N can be very large, e.g. 100, so to reduce GPU memory usage Func performs its operations on x in-place and calls mark_dirty on it. During the backward pass I would like to rebuild the computation graph, basically trading compute for memory. Would I need to use checkpointing to achieve this, or does calling mark_dirty on a Tensor make autograd recompute the input automatically? Said another way: how does mark_dirty recover the lost Tensor, does it keep a copy in memory or recompute it?
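For concreteness, this is roughly the checkpointed version of the loop I have in mind (just a sketch; data_loader, N and Func are the placeholders from above, and I'm not sure how torch.utils.checkpoint interacts with the in-place update inside Func, which is part of my question):

from torch.utils.checkpoint import checkpoint

x = ...  # x_1, from the data loader, as above
for t in range(N):
    # Re-run Func's forward during the backward pass instead of
    # storing intermediate results for all N steps.
    x = checkpoint(Func.apply, x)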

Thanks!

Hi,

mark_dirty does not recompute the Tensor for you in any way. It only lets autograd know what was modified in-place, for error-checking purposes.

Thanks. How do you suggest I let autograd know about the in-place operation happening in Func?

This is the right way to let autograd know about the in-place operation.
That’s actually the only thing mark_dirty does: let autograd know that an in-place op happened.

Sorry, I was not clear earlier. How would autograd reconstruct the original x for backprop, given that it was passed to Func during the forward pass and thus overwritten?

No worries!

How would autograd reconstruct the original x for backprop

It does not.
If you modify it in-place and then save it, you will get the modified version during the backward pass.

For example:

import torch


class MyFn(torch.autograd.Function):
    @staticmethod
    def forward(ctx, inp):
        print("fw in", inp)
        inp += 2  # modify the input in-place
        print("fw out", inp)
        ctx.mark_dirty(inp)  # tell autograd the input was modified in-place
        ctx.save_for_backward(inp)  # this saves the already-modified tensor
        return inp

    @staticmethod
    def backward(ctx, grad_output):
        inp, = ctx.saved_tensors
        print("bw inp value", inp)  # the modified value, not the original one
        return grad_output


a = torch.rand(2, requires_grad=True)

out = MyFn.apply(a.clone())  # clone: in-place ops are not allowed on a leaf tensor that requires grad

out.sum().backward()

Will print:

fw in tensor([0.3568, 0.4446], grad_fn=<CloneBackward0>)
fw out tensor([2.3568, 2.4446], grad_fn=<CloneBackward0>)
bw inp value tensor([2.3568, 2.4446], grad_fn=<MyFnBackward>)

As you can see, the backward pass gets the modified value.
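If your backward pass does need the original (pre-modification) value, the Function has to keep it around itself, e.g. by saving a copy before the in-place update. A minimal sketch of such a variant (MyFnKeepOriginal is just an illustrative name, and the extra clone of course re-introduces the memory cost you were trying to avoid, which is why recomputing via checkpointing is the usual way to trade compute for memory):

import torch


class MyFnKeepOriginal(torch.autograd.Function):
    @staticmethod
    def forward(ctx, inp):
        # Save a copy of the original value *before* modifying in-place.
        ctx.save_for_backward(inp.clone())
        inp += 2
        ctx.mark_dirty(inp)
        return inp

    @staticmethod
    def backward(ctx, grad_output):
        original, = ctx.saved_tensors
        print("bw original value", original)  # pre-modification value
        return grad_output


a = torch.rand(2, requires_grad=True)
out = MyFnKeepOriginal.apply(a.clone())
out.sum().backward()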