Hello, I am working on a problem where I have a non-differentiable function in the middle of my neural network. I would like to use this function in the forward pass only, without including it in the backward pass. Is there anything in PyTorch that lets me disconnect it from the computation graph in the backward pass while keeping it in the forward pass? For example, suppose my network is CNN --> NON-DIFFERENTIABLE FUNCTION --> CNN, and I want to backpropagate from the outputs of the second CNN all the way to the inputs of the first CNN. That cannot happen, because the non-differentiable function sits in the way and won't let gradients flow. So how can I remove it from the backward pass but still include it in the forward pass?
It does not make sense to do this. Consider the following chain of functions, $y = C_2(f(C_1(x)))$:

$$x \;\to\; u = C_1(x) \;\to\; v = f(u) \;\to\; y = C_2(v)$$
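Writing the chain rule out for this composition shows where $f$ enters. Here $\theta_1$ is a symbol introduced only for this sketch, standing for the parameters of $C_1$:

$$\frac{\partial y}{\partial \theta_1} = \frac{\partial C_2}{\partial v}\,\frac{\partial f}{\partial u}\,\frac{\partial C_1}{\partial \theta_1}$$

The middle factor, $\partial f / \partial u$, is the one a non-differentiable $f$ cannot supply, and there is no way to form the product while skipping it.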
Then computing the gradients of $C_1$ using the chain rule relies on back-propagating the gradients of the output $y$ all the way back to $u$. But if we cannot compute the gradients of the function $f$, this chain is broken.
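A minimal PyTorch sketch of both situations, assuming two arbitrary `nn.Conv2d` layers for $C_1$ and $C_2$ and using `torch.sign` as a hypothetical stand-in for $f$. Case 1 does what the question asks literally: it keeps $f$ in the forward pass but cuts it out of the graph with `.detach()`, and $C_1$ then gets no gradient at all. Case 2 leaves $f$ in the graph, but autograd defines the derivative of `sign` as zero, so $C_1$ receives an all-zero gradient and still cannot learn.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Small conv layers standing in for C1 and C2 (channel sizes are arbitrary).
c1 = nn.Conv2d(1, 4, kernel_size=3, padding=1)
c2 = nn.Conv2d(4, 1, kernel_size=3, padding=1)
x = torch.randn(2, 1, 8, 8)

# Case 1: f is used in the forward pass but detached from the graph.
u = c1(x)                   # u = C1(x)
v = torch.sign(u).detach()  # v = f(u); the graph is cut between u and v
y = c2(v)                   # y = C2(v)
y.sum().backward()
c1_grad_detached = c1.weight.grad
print(c2.weight.grad is not None)  # True: C2 still gets a gradient
print(c1_grad_detached is None)    # True: nothing reaches C1

# Case 2: f stays in the graph, but its derivative is zero.
c1.zero_grad(set_to_none=True)
c2.zero_grad(set_to_none=True)
u = c1(x)
v = torch.sign(u)           # autograd's gradient for sign is all zeros
y = c2(v)
y.sum().backward()
c1_grad_sign = c1.weight.grad
print(torch.count_nonzero(c1_grad_sign).item())  # 0: C1 cannot learn
```

In both cases the first CNN is left untrained, which is the broken chain described above seen from the code side.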