Gradient checkpointing conflict with `no_grad()`

I have a model whose forward method uses gradient checkpointing (PyTorch 1.7.0.dev20200709).
My training framework wraps evaluation and sample generation in `with torch.no_grad():`, which triggers the warning `None of the inputs have requires_grad=True. Gradients will be None`. This is slightly annoying, but the worse problem is that it silences any subsequent, more meaningful occurrences of that warning: under Python's default warning filters, a given warning is printed only once per location where it is issued.

What is the clean way to solve this? I could make my `checkpoint` calls contingent on `self.training`, since we also call `eval()`, but it seems to me that a well-written module should be callable under `no_grad` without causing warnings.
Maybe there’s a way to detect specifically that you’re under `no_grad`; but if so, shouldn’t `checkpoint` itself suppress the warning in that case?
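A minimal sketch of that detection, using a hypothetical `CheckpointedBlock` module: `torch.is_grad_enabled()` returns `False` inside a `torch.no_grad()` block, so the module can skip `checkpoint` only when autograd is off. Calls made with autograd enabled still go through `checkpoint` and can still warn if their inputs do not require gradients. `use_reentrant=True` is passed explicitly because recent PyTorch releases warn when it is omitted; drop it on versions that predate the argument, such as the 1.7 nightly mentioned above.

```python
import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint


class CheckpointedBlock(nn.Module):
    """Hypothetical module: checkpoints its inner layers only while
    autograd is recording (i.e. not inside torch.no_grad())."""

    def __init__(self, dim: int = 16):
        super().__init__()
        self.inner = nn.Sequential(
            nn.Linear(dim, dim), nn.ReLU(), nn.Linear(dim, dim)
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if torch.is_grad_enabled():
            # Autograd is on: recompute activations during backward to save
            # memory. If no input requires grad, checkpoint can still warn.
            return checkpoint(self.inner, x, use_reentrant=True)
        # Inside torch.no_grad(): nothing is saved for backward anyway,
        # so call the layers directly and never reach checkpoint's warning.
        return self.inner(x)


block = CheckpointedBlock()
x = torch.randn(4, 16, requires_grad=True)
block(x).sum().backward()          # training-style call, checkpointed

with torch.no_grad():
    y = block(torch.randn(4, 16))  # evaluation call, no checkpoint
```

The advantage over an `eval()`-based switch is that the decision follows the actual autograd state, so a mistake that leaves training inputs without gradients is still reported.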

Hi,

The way I would do it is to make the use of `checkpoint` conditional on at least one input requiring gradients. If none of them do, there is no point in using checkpointing.
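A sketch of that suggestion as a hypothetical helper, `maybe_checkpoint`, which checkpoints only when at least one tensor input has `requires_grad=True` (again passing `use_reentrant=True`, which older releases do not accept):

```python
import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint


def maybe_checkpoint(fn, *inputs):
    """Hypothetical helper: checkpoint `fn` only if at least one tensor
    input requires gradients; otherwise call it directly."""
    needs_grad = any(
        isinstance(t, torch.Tensor) and t.requires_grad for t in inputs
    )
    if needs_grad:
        return checkpoint(fn, *inputs, use_reentrant=True)
    # With the reentrant checkpoint, no input requiring grad means no
    # gradients would flow back through fn anyway, so skip checkpointing
    # (and its warning).
    return fn(*inputs)


layer = nn.Linear(8, 8)

# Training-style call: the hidden state requires grad, so it is checkpointed.
h = torch.randn(2, 8, requires_grad=True)
maybe_checkpoint(layer, h).sum().backward()

# Evaluation-style call: activations produced under no_grad do not require
# grad, so the layer is called directly.
with torch.no_grad():
    z = layer(torch.randn(2, 8))
    out = maybe_checkpoint(layer, z)
```

Note that a raw data batch normally has `requires_grad=False`, so this helper applied to a first layer would silently never checkpoint, even in training; that silent fallback is the concern raised in the next reply.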

That works, but it diminishes the value of the warning: if I make a mistake in some refactor and the inputs end up with `requires_grad=False` even in the training case, it will no longer warn me.

I don’t trust myself, so I think I’ll go with `self.training`. It may still trigger the warning incorrectly if my module is called in some way I’m not aware of (under `no_grad` without `eval()`), but I’d rather have too many warnings than too few.
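For comparison, a sketch of the `self.training` approach, with a hypothetical module name. `eval()` sets `self.training` to `False`, so evaluation under `no_grad` never reaches `checkpoint`, while every training-mode call does. As before, `use_reentrant=True` assumes a PyTorch version that accepts that argument.

```python
import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint


class TrainOnlyCheckpointBlock(nn.Module):
    """Hypothetical module: checkpoints only in training mode."""

    def __init__(self, dim: int = 16):
        super().__init__()
        self.inner = nn.Sequential(
            nn.Linear(dim, dim), nn.ReLU(), nn.Linear(dim, dim)
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if self.training:
            # Training mode always checkpoints, so a refactor that leaves the
            # inputs with requires_grad=False is still reported. Calling the
            # module under no_grad() without eval() can also warn here.
            return checkpoint(self.inner, x, use_reentrant=True)
        # eval() mode: call the layers directly.
        return self.inner(x)


block = TrainOnlyCheckpointBlock()

block.train()
x = torch.randn(4, 16, requires_grad=True)
block(x).sum().backward()

block.eval()
with torch.no_grad():
    y = block(torch.randn(4, 16))
```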
