`requires_grad` becomes False inside gradient checkpointing

I’m using gradient checkpointing (PyTorch 1.7.0.dev20200709) and I’m observing a behavior that I don’t understand.

Basically, I have a line in my forward that goes
`foo = bar * baz`
where `bar` has requires_grad=False and `baz` has requires_grad=True.
I run this code with the same inputs, first outside and then inside a checkpoint() call.
In the first case, `foo` ends up with requires_grad=True, which is what I would expect.
In the second, `foo` ends up with requires_grad=False.

Is this expected? Won’t this stop the model from learning? Have I discovered a bug, or is this checkpoint magic that will all turn out fine in the end?


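When you use checkpointing, the forward is run twice.
Once during the forward pass, without tracking gradients (so everything computed inside has requires_grad=False), and once more during the backward pass, with tracking enabled, to recompute the buffers that backward needs. The gradients come from that second run, so the model still learns as usual.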
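
If you want to see the two runs for yourself, here is a minimal sketch. It assumes a PyTorch version recent enough that checkpoint() accepts the use_reentrant argument (so not the 1.7 nightly from the question), and it passes use_reentrant=True to get the original implementation that behaves as described here. `block` and `seen` are names made up for the illustration.

```python
import torch
from torch.utils.checkpoint import checkpoint

seen = []  # records foo.requires_grad every time the body of block() runs


def block(bar, baz):
    # Hypothetical stand-in for the forward code in the question.
    foo = bar * baz
    seen.append(foo.requires_grad)
    return foo.sum()


bar = torch.ones(3)                      # requires_grad=False
baz = torch.ones(3, requires_grad=True)  # requires_grad=True

# 1) Plain call: autograd tracks the multiplication as usual.
direct = block(bar, baz)

# 2) Checkpointed call: the body runs once now with tracking disabled...
out = checkpoint(block, bar, baz, use_reentrant=True)

# 3) ...and once more during backward, with tracking enabled.
out.backward()

print(seen)      # one entry per run of block(): direct, checkpoint forward, recompute
print(baz.grad)  # gradient still reaches baz through the checkpointed call
```

The middle entry of `seen` is the requires_grad=False from the question, and `baz.grad` is still populated because the gradient is computed from the recomputation during backward. With use_reentrant=False the forward is handled differently, so the flag you observe inside may not match this.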

Checkpoint magic, got it :wink: