I’m using gradient checkpointing (PyTorch 1.7.0.dev20200709) and I’m observing a behavior that I don’t understand.
Basically, I have a code snippet in my forward that goes `foo = bar * baz`, where `bar` has `requires_grad=False` and `baz` has `requires_grad=True`.
I call this code with the same inputs, first outside and then inside a `checkpoint()` call. In the first case, `foo` ends up with `requires_grad=True`, which is what I would expect. In the second, `foo` ends up with `requires_grad=False`.
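Here is a minimal, self-contained sketch of what I mean (`my_forward` and the random tensors are just placeholders for my real code):

```python
import torch
from torch.utils.checkpoint import checkpoint

def my_forward(bar, baz):
    foo = bar * baz
    # True when called directly, False when called through checkpoint(),
    # presumably because checkpoint() runs this forward pass without
    # grad tracking.
    print("foo.requires_grad =", foo.requires_grad)
    return foo

bar = torch.randn(4)                      # requires_grad=False
baz = torch.randn(4, requires_grad=True)  # requires_grad=True

out_direct = my_forward(bar, baz)            # prints: foo.requires_grad = True
out_ckpt = checkpoint(my_forward, bar, baz)  # prints: foo.requires_grad = False
```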
Is this expected? Won’t this prevent the model from learning? Have I discovered a bug, or is this `checkpoint` magic that will all turn out fine in the end?