So I read the inline documentation about mark_dirty() here:
I don’t quite understand what extra checks are needed for inplace operators. Would be great if the devs can give some hints. Thanks!
If you are doing an in-place operation, and further operate on the original Tensor, the backward gradients might be wrong.
Let’s take a small example:
y = x^2, z = x^2.
In this case, the gradient is 2x.
So the input x is needed to compute the gradients in the backward pass.
If we do all the operations out-of-place, we can hold onto the value of x, and computing the correct gradients is not a problem.
However, if we do the second operation in-place via z = x.pow_(2), where x is a Variable, we cannot compute the backward pass of y = x^2 correctly, because the value of x that it needs has been overwritten.
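To make this concrete, here is a minimal sketch of the same situation. It assumes a recent PyTorch release where Variable has been merged into Tensor; the helper name grad_of_square and the intermediate copy a are made up for illustration (the copy is there because PyTorch rejects in-place ops on a leaf tensor that requires grad).

```python
import torch


def grad_of_square(inplace: bool) -> torch.Tensor:
    """Return dy/dx for y = a^2, where a is a copy of x.

    If `inplace` is True, `a` is overwritten before backward runs.
    """
    x = torch.tensor([1.0, 2.0, 3.0], requires_grad=True)
    a = x * 1.0       # non-leaf copy of x
    y = a.pow(2)      # autograd saves `a`, since dy/da = 2 * a
    if inplace:
        a.pow_(2)     # overwrites the saved value in place
    else:
        z = a.pow(2)  # out-of-place: `a` keeps its original value
    y.sum().backward()
    return x.grad


# Out-of-place: the gradient is 2 * x.
print(grad_of_square(inplace=False))

# In-place: autograd notices that `a` changed and raises instead of
# silently returning a wrong gradient.
try:
    grad_of_square(inplace=True)
except RuntimeError as err:
    print("backward failed:", err)
```

The out-of-place call returns 2x as expected, while the in-place call should fail in backward with a RuntimeError about a tensor being modified by an in-place operation.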
On all Variables, we have an internal version counter to track these things, and mark_dirty ensures that this version counter is correctly updated.
If the user does an operation where the backward cannot be correctly computed, then an error is thrown.
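As a rough sketch of where mark_dirty fits in, the custom Function below modifies its input in place and reports that to autograd. It assumes a recent PyTorch with the static-method Function API; AddOneInplace is a hypothetical op written for this example, and _version is a private attribute that is only read here to show the counter moving.

```python
import torch


class AddOneInplace(torch.autograd.Function):
    """Hypothetical custom op that adds 1 to its input in place."""

    @staticmethod
    def forward(ctx, inp):
        inp.add_(1)          # modify the input's storage directly
        ctx.mark_dirty(inp)  # tell autograd that `inp` was changed in place
        return inp

    @staticmethod
    def backward(ctx, grad_output):
        return grad_output   # d(inp + 1) / d(inp) = 1


x = torch.tensor([1.0, 2.0], requires_grad=True)
a = x.clone()                # non-leaf, so in-place modification is allowed
b = a * a                    # backward of `b` needs the original value of `a`

version_before = a._version
AddOneInplace.apply(a)       # overwrites `a`
version_after = a._version

# The saved `a` is now stale, so backward through `b` should raise
# rather than return a wrong gradient.
try:
    b.sum().backward()
    stale_detected = False
except RuntimeError:
    stale_detected = True

print(version_before, version_after, stale_detected)
```

The version of a changes across the in-place call, and that mismatch is what lets backward through b detect that the value it saved is no longer valid.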
Thanks a lot for the explanation, Soumith