Understand mark_dirty()

So I read the inline documentation about mark_dirty() here:

I don’t quite understand what extra checks are needed for inplace operators. Would be great if the devs can give some hints. Thanks!

If you are doing an in-place operation, and further operate on the original Tensor, the backward gradients might be wrong.

Let’s take a small example:

y = x^2 and z = x^2.

In this case, the gradient of each output with respect to x is 2x.
So, the value of the input x is needed to compute the gradients in the backward pass.

If we do all the operations out-of-place, we can hold onto the value of x and it’s not a problem to compute correct gradients.
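As a minimal sketch of the out-of-place case, assuming a recent PyTorch release where Variable and Tensor are merged (the input values are made up for illustration): both squares read the same untouched x, so autograd can evaluate 2x for each of them.

```python
import torch

# Illustrative input; the values are arbitrary.
x = torch.tensor([1.0, 2.0, 3.0], requires_grad=True)

# Both operations are out-of-place, so x keeps its original value.
y = x.pow(2)
z = x.pow(2)

# Each square contributes 2x, so the accumulated gradient is 4x.
(y.sum() + z.sum()).backward()
print(x.grad)
```

Because x is never overwritten, the value saved for the backward pass is still valid when backward() runs.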

However, if we do the second operation in-place via: z = x.pow_(2), where x is a Variable, we cannot compute the backward pass of y = x^2 correctly, because the value of x that it needs has been overwritten.
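A rough sketch of the in-place case, again assuming a recent PyTorch release. One assumption to note: leaf tensors that require grad refuse in-place operations outright, so the hypothetical intermediate `a = x * 1.0` stands in for x here. In the versions I am aware of, autograd notices the overwrite and raises a RuntimeError instead of silently returning a wrong gradient.

```python
import torch

x = torch.tensor([1.0, 2.0, 3.0], requires_grad=True)
a = x * 1.0        # non-leaf stand-in for x, so in-place ops are permitted

y = a.pow(2)       # autograd saves `a` to compute 2 * a in backward
z = a.pow_(2)      # in-place: overwrites the value that y's backward needs

try:
    y.sum().backward()
    error_message = None
except RuntimeError as err:
    # Autograd detected that a saved tensor was modified in place.
    error_message = str(err)

print(error_message)
```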

On all Variables, we have an internal version counter to track these in-place modifications, and mark_dirty ensures that this version counter is updated correctly.
If the user does an operation whose backward cannot be computed correctly, an error is thrown.
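To connect this to custom functions, here is a hedged sketch of a `torch.autograd.Function` that modifies its input in place and reports that with `ctx.mark_dirty`. The class name `InplaceDouble` and the tensors are hypothetical, and `_version` is an internal attribute that is used here only to peek at the version counter; it may change between releases.

```python
import torch
from torch.autograd import Function


class InplaceDouble(Function):
    """Hypothetical example: doubles its input in place."""

    @staticmethod
    def forward(ctx, inp):
        inp.mul_(2)          # modify the input in place
        ctx.mark_dirty(inp)  # tell autograd that inp was modified
        return inp

    @staticmethod
    def backward(ctx, grad_out):
        # d(2 * inp) / d(inp) = 2, so nothing needs to be saved.
        return grad_out * 2


x = torch.tensor([1.0, 2.0], requires_grad=True)
a = x * 1.0                   # non-leaf, so in-place modification is allowed
y = a.pow(2)                  # saves the current value of `a` for backward

version_before = a._version   # internal counter, for illustration only
out = InplaceDouble.apply(a)
version_after = a._version

out.sum().backward()          # fine: this path does not need the old `a`
print(x.grad)

try:
    y.sum().backward()        # needs the old `a`, which is now gone
    stale_error = None
except RuntimeError as err:
    stale_error = str(err)
```

The backward through `out` succeeds, while the backward through `y` is rejected because the tensor it saved has a newer version than the one recorded at save time.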


Thanks a lot for the explanation, Soumith :slight_smile: