Hi, this is a bit of an unusual scenario:
I have a CUDA kernel implementation of an autograd Function Func with a custom backward pass, which I use as follows during the forward pass:
```
x_1 = from data loader
for t = 1 : N
    x_(t+1) = Func(x_t)
```
N can be very large, e.g. 100, so to decrease GPU memory usage, Func performs its operations on x in-place and sets the mark_dirty flag on it. During the backward pass I would like to rebuild the computation graph, basically trading compute for memory. Would I need to use checkpointing to achieve this, or does setting the mark_dirty flag on a Tensor automatically recompute the input? Said another way: how does mark_dirty recover the lost Tensor, does it keep a copy in memory or recompute it?
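For reference, the checkpointing approach I have in mind would look roughly like the sketch below, where run_chunk, k, and loader are placeholder names of mine and Func is the Function above:

```python
import torch
from torch.utils.checkpoint import checkpoint

def run_chunk(x, k):
    # Run k steps; intermediates inside the chunk are freed after the
    # forward and recomputed during the backward pass.
    for _ in range(k):
        x = Func.apply(x.clone())  # clone so the chunk input saved by checkpoint is not dirtied
    return x

x = next(iter(loader))             # x_1 from the data loader (placeholder)
k = 10                             # steps per checkpointed chunk (placeholder)
for _ in range(N // k):
    x = checkpoint(run_chunk, x, k, use_reentrant=False)
```

Checkpointing whole chunks rather than single steps is what would actually save memory here, since checkpoint keeps its own inputs alive.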
mark_dirty does not recompute the Tensor for you in any way. It only lets autograd know what was modified in-place, for error-checking purposes.
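To illustrate the error checking (a toy example, not using a custom Function): autograd keeps a version counter per tensor and raises an error when a value it saved for the backward pass was later changed in-place:

```python
import torch

a = torch.rand(2, requires_grad=True).clone()
b = a ** 2   # autograd saves `a` to compute the backward of pow
a.add_(1)    # in-place op bumps a's version counter
# RuntimeError: one of the variables needed for gradient computation
# has been modified by an inplace operation
b.sum().backward()
```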
Thanks. How do you suggest I let autograd know about the in-place operation happening in Func?
Using mark_dirty is the right way to let autograd know about the in-place op.
That's actually the only thing mark_dirty does: let autograd know that an in-place op happened.
Sorry, I was not clear earlier. How would autograd reconstruct the original x for backprop, if it was sent to Func during the forward pass and thus was overwritten?
> How would autograd reconstruct original x for backprop
It does not.
If you modify it in-place and save it, you will get the modified version during the backward pass. For example:
```python
import torch

class MyFn(torch.autograd.Function):
    @staticmethod
    def forward(ctx, inp):
        print("fw in", inp)
        inp += 2                    # modify the input in-place
        ctx.mark_dirty(inp)         # declare the in-place modification
        ctx.save_for_backward(inp)
        print("fw out", inp)
        return inp

    @staticmethod
    def backward(ctx, grad_output):
        inp, = ctx.saved_tensors
        print("bw inp value", inp)
        return grad_output

a = torch.rand(2, requires_grad=True)
out = MyFn.apply(a.clone())
out.sum().backward()
```
```
fw in tensor([0.3568, 0.4446], grad_fn=<CloneBackward0>)
fw out tensor([2.3568, 2.4446], grad_fn=<CloneBackward0>)
bw inp value tensor([2.3568, 2.4446], grad_fn=<MyFnBackward>)
```
As you can see, the backward gets the modified value.
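If the backward actually needs the original input, one option (which trades some memory back) is to save a copy taken before the in-place update; a minimal sketch of the forward above:

```python
@staticmethod
def forward(ctx, inp):
    ctx.save_for_backward(inp.clone())  # keep the pre-update value
    inp += 2
    ctx.mark_dirty(inp)
    return inp
```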