Performing backward on a network with a non-differentiable module

Hi, I'm wondering what happens when calling backward() on a network that contains a module that is not differentiable.

For example, suppose a module that is used as part of a larger network has a conditional statement or an argmax/argmin in its forward function. Is the backward pass through that module automatically disabled, or should I call detach() on the output of that module?

Thanks a lot in advance.

Usually this is taken care of for you.
Internally, the backward of an autograd Function returns None for every input with respect to which the result cannot be differentiated.
To see it in action:

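```python
import torch
from torch.autograd import Variable  # pre-1.0 API
x = Variable(torch.randn(10), requires_grad=True)
y, l = x.max(0)
l.backward()
print(x.grad) # zero because dl/dx does not make sense
y.backward()
print(x.grad) # one at the maximum
```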
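The mechanism described above can be sketched with a custom `torch.autograd.Function`. `MaskedIdentity` is a hypothetical name written for illustration, not code from PyTorch's internals, and it uses the current (1.0+) Tensor API rather than `Variable`. Its `backward` returns one value per forward input: a real gradient for `x` and `None` for the boolean mask, because the output cannot be differentiated with respect to it.

```python
import torch


class MaskedIdentity(torch.autograd.Function):
    """Hypothetical example: y = x where mask is True, 0 elsewhere."""

    @staticmethod
    def forward(ctx, x, mask):
        # Keep the mask for the backward pass; it is an input, not a parameter.
        ctx.save_for_backward(mask)
        return x * mask

    @staticmethod
    def backward(ctx, grad_output):
        (mask,) = ctx.saved_tensors
        # One entry per forward input: x gets a real gradient,
        # the boolean mask gets None because it is not differentiable.
        return grad_output * mask, None


x = torch.randn(5, requires_grad=True)
mask = x > 0                       # comparison result: bool, never requires grad
y = MaskedIdentity.apply(x, mask)
y.sum().backward()
print(x.grad)                      # 1.0 where x > 0, 0.0 elsewhere
```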

Best regards



With PyTorch 1.0, the example provided by Thomas no longer works. It throws:

RuntimeError: element 0 of tensors does not require grad and does not have a grad_fn

Yes, Variable is gone, and we now actively keep people from expecting integer-valued functions (such as the indices returned by max) to have gradients.
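A sketch of how the original example behaves with the current Tensor API, assuming PyTorch 1.0 or later: the max values stay differentiable, the integer indices do not require grad, and calling backward on the indices raises a RuntimeError instead of quietly producing zero gradients. No detach() is needed on the indices.

```python
import torch

x = torch.randn(10, requires_grad=True)
values, indices = x.max(0)    # 0-dim value tensor and 0-dim int64 index tensor

print(values.requires_grad)   # True: the max value is differentiable w.r.t. x
print(indices.requires_grad)  # False: integer outputs never carry gradients

# Calling backward on the indices fails instead of silently giving zeros.
try:
    indices.backward()
except RuntimeError as err:
    print("RuntimeError:", err)

# Backpropagating through the value routes the gradient to the argmax.
values.backward()
print(x.grad)  # 1.0 at the position of the maximum, 0.0 elsewhere
```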

Best regards


Thanks a lot for the reply!
What is the currently recommended practice for cases similar to the original poster's?
Let's say I want to implement a module with a forward like this:

```python
def forward(self, x):
    # Return the stored output for the one known input; otherwise run the MLP.
    if torch.allclose(x, self.known_input):
        return self.known_output
    return self.mlp(x)
```

Is there a simple way to keep that conditional computation inside the module and just get zero gradients from the backward pass through it?

Well, PyTorch gives you gradients wherever your computation has them and none where it doesn't. So if what you plan to do makes mathematical sense, you can just implement it directly. Of course, if your output depends only on fixed values, it will not require grad, in which case you need to skip the backward call. But then there isn't much learning going on in that case anyway.
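A minimal sketch of the module from the question, assuming the stored input and output are registered as buffers; the class name `CachedMLP`, the helper `train_step`, and the layer sizes are hypothetical. When the constant branch is taken, the loss has no `grad_fn`, so the step checks `requires_grad` before calling `backward()`. In that case the MLP parameters simply receive no gradient (`.grad` stays `None`).

```python
import torch
import torch.nn as nn


class CachedMLP(nn.Module):
    """Hypothetical module: returns a stored output for one known input,
    otherwise runs a small MLP."""

    def __init__(self, known_input, known_output):
        super().__init__()
        # Buffers move with the module (.to(), state_dict) but are not trained.
        self.register_buffer("known_input", known_input)
        self.register_buffer("known_output", known_output)
        self.mlp = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 2))

    def forward(self, x):
        # The Python `if` is ordinary control flow; autograd only records
        # the operations of the branch that actually runs.
        if torch.allclose(x, self.known_input):
            return self.known_output  # constant: requires_grad is False
        return self.mlp(x)


def train_step(model, x):
    model.zero_grad()
    loss = model(x).sum()
    if loss.requires_grad:  # skip backward when only constants were involved
        loss.backward()
    return loss


model = CachedMLP(torch.zeros(4), torch.ones(2))
known = train_step(model, torch.zeros(4))                         # constant branch
learned = train_step(model, torch.tensor([1.0, 2.0, 3.0, 4.0]))   # MLP branch
print(known.requires_grad, learned.requires_grad)                 # False True
```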

Best regards