@ptrblck I’ve had this issue ever since I first implemented it, and it was not working as I expected even then. The code you see runs on torch 1.1.0 (the main BERTSUM implementation), but I also tested it on Hugging Face’s BART, which uses PyTorch > 1.4.0.
Does the PyTorch version affect checkpointing? I suspected this earlier and searched for whether gradient checkpointing was only added (or changed) in a certain PyTorch version, but couldn’t find anything useful.
Another observation: what if the returned value has a gradient of None? I’ve also run into the issue described in Checkpoint with no grad requiring inputs PROBLEM, which is solved when `requires_grad` is set to True for the input arguments. I should mention that in Hugging Face’s `BertModel` there is a `src` argument that does not need gradients at all, but to silence the warning I have to set its `requires_grad` to True. Is this also an issue?
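For reference, here is a minimal sketch of the workaround I mean (the `Block` module, tensor shapes, and the `src` tensor are made up for illustration; in the real model `src` would be something like token ids or a mask):

```python
import torch
from torch.utils.checkpoint import checkpoint

class Block(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(16, 16)

    def forward(self, x, src):
        # `src` stands in for an auxiliary input that never needs gradients
        return self.linear(x) + src

block = Block()
x = torch.randn(2, 16)    # activations from a previous layer
src = torch.randn(2, 16)  # auxiliary input, no grads needed

# Without this line, checkpoint() warns that none of the inputs
# require grad and returns None gradients for the segment; setting
# requires_grad on the activation silences the warning and lets
# gradients flow through the checkpointed block.
x.requires_grad_(True)

out = checkpoint(block, x, src)
out.sum().backward()
print(x.grad is not None)  # True
```

It feels wrong to have to flag `requires_grad=True` on tensors that semantically don’t need gradients just to keep the warning away, which is why I’m asking whether this is expected behavior.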