I have a model (PyTorch 1.7.0.dev20200709) that uses gradient checkpointing in its forward method. My training framework wraps evaluation/sample generation in `with torch.no_grad():`, which triggers the "None of the inputs have requires_grad=True. Gradients will be None" warning. This is mildly annoying on its own, but the worse problem is that, because Python's warning machinery deduplicates repeated warnings by default, it silences any subsequent, more meaningful occurrences of that same message.
What is the clean way to solve this? I could make my `checkpoint` calls contingent on `self.training`, since we also call `eval()` during evaluation, but it seems to me that a well-written module should be callable under `no_grad` without emitting warnings. Maybe there's a way to specifically detect that you're under `no_grad`, but if so, shouldn't `checkpoint` suppress the warning in that case?
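For reference, here is a minimal sketch of the workaround I'm considering: gating the `checkpoint` call on `torch.is_grad_enabled()` (and `self.training`), so checkpointing is skipped entirely when no gradients are being tracked. `Block` and its layer sizes are just placeholders for illustration, not my actual model:

```python
import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint

class Block(nn.Module):
    """Toy module illustrating conditional gradient checkpointing."""

    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(16, 16)

    def forward(self, x):
        # Only pay the checkpointing cost (and risk the warning) when
        # gradients are actually being tracked; under no_grad or eval(),
        # fall through to a plain forward pass.
        if self.training and torch.is_grad_enabled():
            return checkpoint(self._inner, x)
        return self._inner(x)

    def _inner(self, x):
        return torch.relu(self.linear(x))
```

This avoids the warning during `no_grad` evaluation, but it spreads framework-level concerns into every module, which is exactly what I'd like to avoid.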