Checkpoint breaks grads

I’m training a text classification model and using gradient checkpointing to save some memory. The relevant part of the forward pass looks like this:

...
embeddings = self.embedding_layer(x)
encoded = checkpoint.checkpoint(self.encoder, embeddings)
clf = self.linear(encoded)
...

But when I examine the gradients of my model, I see that all of them (except for the last layer) are None:

for name, param in model.named_parameters():
    print(name, param.grad)

What is the problem and how can I solve it?
The most interesting part is that without the checkpoint everything works fine.
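
To make the structure clearer, here is a self-contained sketch of how the checkpointed forward pass is wired up. The layer types and sizes here are placeholders and not my real model; only the structure matches the snippet above:

import torch
import torch.nn as nn
import torch.utils.checkpoint as checkpoint

class Classifier(nn.Module):
    def __init__(self, vocab_size=100, hidden=32, num_classes=2):
        super().__init__()
        # placeholder layers and sizes; the real model is different
        self.embedding_layer = nn.Embedding(vocab_size, hidden)
        self.encoder = nn.Sequential(nn.Linear(hidden, hidden), nn.ReLU())
        self.linear = nn.Linear(hidden, num_classes)

    def forward(self, x):
        embeddings = self.embedding_layer(x)
        # the encoder runs under torch.utils.checkpoint to save activation memory
        encoded = checkpoint.checkpoint(self.encoder, embeddings)
        return self.linear(encoded)

model = Classifier()
x = torch.randint(0, 100, (4, 10))    # dummy batch of token ids
loss = model(x).sum()                 # dummy loss, just to call backward()
loss.backward()

for name, param in model.named_parameters():
    print(name, param.grad is None)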

@ptrblck you’re my last hope :frowning:

I am kind of new to gradient checkpointing, so my questions may be silly:

  • Did you check the gradients of the parameters after running loss.backward() (see the snippet below)?
  • Can you please share the full code (if possible), so that I can run it and try to see what the issue is?
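
For the first point, I mean checking the gradients only after the backward pass has actually run, e.g. (x, target, and criterion are placeholders for whatever you use):

output = model(x)                 # x: your input batch (placeholder name)
loss = criterion(output, target)  # criterion / target: placeholders
loss.backward()                   # param.grad is only populated after this call
for name, param in model.named_parameters():
    print(name, param.grad is None)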