I just had a quick question about torch.utils.checkpoint.checkpoint and how it relates to torch.autograd.grad. The docs state that torch.utils.checkpoint.checkpoint doesn't work with torch.autograd.grad and only supports torch.autograd.backward().
Could I ask why that's the case, and whether there's any work-around? In my current use case I have a loss function with a component that uses torch.autograd.grad, but I detach() those values and therefore don't need to backprop through them. However, when I use checkpointing I get the error stated in the docs:
RuntimeError: Checkpointing is not compatible with .grad(), please use .backward() if possible
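For context, here is a minimal sketch of the kind of combination that reproduces the error for me (the module, shapes, and names are just placeholders for illustration):

```python
import torch
from torch.utils.checkpoint import checkpoint

# Placeholder module and input just to illustrate the pattern.
lin = torch.nn.Linear(4, 4)
x = torch.randn(2, 4, requires_grad=True)

# Forward pass through the (default, reentrant) checkpoint wrapper.
y = checkpoint(lin, x)
loss = y.sum()

# loss.backward() works fine, but asking for gradients with .grad()
# raises the RuntimeError above, since the backward pass has to run
# through the checkpointed region.
grads = torch.autograd.grad(loss, x)
```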
Might it be possible to add the checkpointing after using torch.autograd.grad? In my case, the loss function contains torch.autograd.grad terms, but I use a trick that lets me detach those values and backprop through a different graph that gives the same derivatives as backpropping through the original torch.autograd.grad graph. Might it be possible to add and remove checkpointing modules like module hooks? Or is checkpointing with torch.autograd.grad completely out of the question, even with the 'trick' I mentioned above?
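To make the structure clearer, this is roughly what the loss looks like without checkpointing (a rough sketch only; the network, the exact form of the detached term, and the separate surrogate graph are simplified or omitted here):

```python
import torch

# Placeholder network and input; the real model is larger, hence the checkpointing.
net = torch.nn.Sequential(torch.nn.Linear(4, 8), torch.nn.Tanh(), torch.nn.Linear(8, 1))
x = torch.randn(16, 4, requires_grad=True)

y = net(x)  # plain forward; this is where I'd like to add checkpointing

# Derivative term used inside the loss; retain_graph keeps the graph
# alive for the later backward pass.
dydx = torch.autograd.grad(y.sum(), x, retain_graph=True)[0]

# The grad values are detached, so backward never goes through the
# torch.autograd.grad graph itself. A separate graph (omitted here)
# is constructed so that backward still yields the same derivatives.
grad_term = (dydx.detach() ** 2).mean()

loss = y.mean() + grad_term
loss.backward()
```

With y = checkpoint(net, x) instead of y = net(x), the torch.autograd.grad line is what triggers the RuntimeError.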
Any help is appreciated! Thank you in advance!