Why doesn't torch.utils.checkpoint.checkpoint work with torch.autograd.grad? Is there a work-around?

Hi All,

I have a quick question about torch.utils.checkpoint.checkpoint and how it interacts with torch.autograd.grad. The docs state that torch.utils.checkpoint.checkpoint doesn’t work with torch.autograd.grad and only works with torch.autograd.backward().

Could I ask why that’s the case, and whether there’s any work-around? In my current use case, my loss function has a component that uses torch.autograd.grad, but I detach() those values, so I don’t need to backprop through them. However, when I use checkpointing I get the error mentioned in the docs:

RuntimeError: Checkpointing is not compatible with .grad(), please use .backward() if possible
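For reference, here is a minimal sketch that should reproduce the error. The `block` function, the tensor shapes and the helper `grad_through_checkpoint` are hypothetical stand-ins for the real model. It also runs the same call with `use_reentrant=False`, the non-reentrant variant exposed by newer PyTorch releases; my understanding from the current docs is that the `.grad()` restriction applies to the reentrant variant, but treat that as an assumption to check against your PyTorch version. The exact error text also differs between versions.

```python
import torch
from torch.utils.checkpoint import checkpoint

# Hypothetical stand-in for the checkpointed segment of the real model.
def block(x):
    return torch.tanh(x) * 2.0

def grad_through_checkpoint(x, use_reentrant):
    # Forward pass through the checkpointed segment.
    y = checkpoint(block, x, use_reentrant=use_reentrant)
    # torch.autograd.grad has to run backward through the checkpoint node.
    (dy_dx,) = torch.autograd.grad(y.sum(), x)
    return dy_dx

x = torch.randn(4, 3, requires_grad=True)

try:
    grad_through_checkpoint(x, use_reentrant=True)
except RuntimeError as err:
    # Reentrant checkpointing refuses to run under torch.autograd.grad.
    print("use_reentrant=True failed:", err)

dy_dx = grad_through_checkpoint(x, use_reentrant=False)
# Analytic check: d/dx [2 * tanh(x)] = 2 * (1 - tanh(x)^2)
print(torch.allclose(dy_dx, 2 * (1 - torch.tanh(x.detach()) ** 2)))
```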

Might it be possible to add the checkpointing after using torch.autograd.grad? In my case, the loss function contains torch.autograd.grad terms, but I use a trick that lets me detach those values and backprop through a different graph that gives the same derivatives as backpropagating through the original torch.autograd.grad graph. Could checkpointing be added and removed dynamically, the way module hooks are? Or is checkpointing with torch.autograd.grad completely out of the question, even with the ‘trick’ mentioned above?
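One way to read "adding the checkpointing after torch.autograd.grad" is to compute the derivative term in a separate, non-checkpointed forward pass, detach it, and only checkpoint the forward pass that .backward() later goes through, so torch.autograd.grad never touches a checkpoint node. Below is a minimal sketch under that assumption. The model, shapes and the loss form in `loss_with_grad_term` are hypothetical and are not my actual trick. The cost is an extra forward pass, and the first graph stays in memory until torch.autograd.grad frees it.

```python
import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint

torch.manual_seed(0)

# Hypothetical model standing in for the segment that would be checkpointed.
model = nn.Sequential(nn.Linear(3, 16), nn.Tanh(), nn.Linear(16, 1))

def loss_with_grad_term(x):
    # 1) Plain (non-checkpointed) forward pass, used only for the derivative term.
    x_req = x.detach().requires_grad_(True)
    y_plain = model(x_req)
    (dy_dx,) = torch.autograd.grad(y_plain.sum(), x_req)
    dy_dx = dy_dx.detach()  # treated as a constant from here on

    # 2) Checkpointed forward pass for the part of the loss that is backpropagated.
    y = checkpoint(model, x, use_reentrant=True)

    # Hypothetical loss: the detached derivative term only acts as a weight.
    return (dy_dx.pow(2).sum(dim=1, keepdim=True) * y).mean()

x = torch.randn(8, 3, requires_grad=True)
loss = loss_with_grad_term(x)
loss.backward()  # .backward() is the call reentrant checkpointing supports
print(model[0].weight.grad.shape)
```

Because the autograd.grad call only ever sees the plain graph, the reentrant checkpoint in step 2 is traversed by .backward() alone, which is the combination the docs say is supported.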

Any help is appreciated! Thank you in advance!
