Hi, I wonder what happens when calling backward() on a network that contains a non-differentiable module.
For example, suppose a module that is used as part of a larger network has a conditional statement or an argmax/argmin in its forward function. Is the backward pass through that module automatically disabled, or should I call detach() on the module’s output?
Usually, autograd takes care of this.
Internally, the backward of an autograd Function returns None for every input with respect to which the output cannot be differentiated.
To see it in action:
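```python
import torch

x = torch.randn(10, requires_grad=True)
y, l = x.max(0)  # y: the maximum value, l: its index (an integer tensor)

# dl/dx does not make sense, so autograd does not track the index at all;
# in current PyTorch versions, calling l.backward() raises an error.
print(l.requires_grad)  # False

y.backward()
print(x.grad)  # one at the maximum, zero elsewhere
```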
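The "returns None" rule is easiest to see in a custom autograd Function, where you write backward yourself. The sketch below is hypothetical: `Select` and its `use_a` flag are made up for illustration, and `a` and `b` are assumed to have the same shape. It picks one of two tensors based on a Python bool, returns a zero gradient for the branch that was not taken, and returns None for the non-differentiable flag.

```python
import torch


class Select(torch.autograd.Function):
    """Hypothetical example: return `a` or `b` depending on a Python bool."""

    @staticmethod
    def forward(ctx, a, b, use_a):
        ctx.use_a = use_a
        # Clone so the output is a new tensor rather than one of the inputs.
        return a.clone() if use_a else b.clone()

    @staticmethod
    def backward(ctx, grad_output):
        # backward returns one value per forward input: (a, b, use_a).
        # The branch that was not taken gets a zero gradient, and the
        # boolean flag gets None because it cannot be differentiated.
        zeros = torch.zeros_like(grad_output)
        if ctx.use_a:
            return grad_output, zeros, None
        return zeros, grad_output, None


a = torch.randn(3, requires_grad=True)
b = torch.randn(3, requires_grad=True)
out = Select.apply(a, b, True)
out.sum().backward()
print(a.grad)  # tensor([1., 1., 1.])
print(b.grad)  # tensor([0., 0., 0.])
```

Built-in operations such as max follow the same pattern, so a plain Python `if` or an argmax inside forward needs no special handling.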
Thanks a lot for the reply!
What is the currently recommended practice for cases like the original poster’s?
Let’s say I want to implement a module whose forward looks like this:
```python
def forward(self, x):
    # Return a stored output for a known input; otherwise run the MLP.
    if torch.allclose(x, self.known_input):
        return self.known_output
    else:
        return self.mlp(x)
```
Is there a simple way to keep that conditional computation inside the module, just yielding zero gradients when doing the backward pass through it?
Well, PyTorch gives you gradients wherever your computation has them and none where it doesn’t. So if what you plan to do makes mathematical sense, you can just write it that way. Of course, if the output depends only on fixed tensors (for example, when the known-input branch returns a stored constant), it will not have requires_grad set, and calling backward() on it raises an error, so you need to skip the backward. But then, there isn’t much learning going on in that case anyway.
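A minimal runnable sketch of the module from the question, assuming a hypothetical `CachedMLP` class with made-up layer sizes (4 → 8 → 2) and the known input/output stored as buffers. It shows both cases described above: the MLP branch produces gradients for the MLP weights, while the known-input branch returns a constant that does not require grad, so backward is skipped.

```python
import torch
import torch.nn as nn


class CachedMLP(nn.Module):
    """Hypothetical module: stored output for one known input, MLP otherwise."""

    def __init__(self, known_input, known_output):
        super().__init__()
        # Buffers are saved with the module's state but are not trained.
        self.register_buffer("known_input", known_input)
        self.register_buffer("known_output", known_output)
        self.mlp = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 2))

    def forward(self, x):
        if torch.allclose(x, self.known_input):
            return self.known_output
        return self.mlp(x)


model = CachedMLP(known_input=torch.zeros(4), known_output=torch.ones(2))

# MLP branch: the output depends on the MLP weights, so backward() works.
out = model(torch.ones(4))
out.sum().backward()
mlp_has_grad = model.mlp[0].weight.grad is not None
print(mlp_has_grad)  # True

# Known-input branch: the output is a constant buffer that does not require
# grad, so there is nothing to backpropagate and backward() is skipped.
model.zero_grad()
out = model(torch.zeros(4))
if out.requires_grad:
    out.sum().backward()
print(out.requires_grad)  # False
```

Using buffers (plain attributes would also work) means the stored tensors move with the module under `.to(device)`. If downstream code might modify the result in place, returning `self.known_output.clone()` avoids corrupting the stored value.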