How can I exclude a non-differentiable function from the backward pass?

Hello, I am working on a problem where I have a non-differentiable function in the middle of my neural network. I would like to apply this function in the forward pass only, without including it in the backward pass. Is there anything I can do in PyTorch to disconnect it from the computation graph in the backward pass while keeping it in the forward pass? For example, consider a network of the form CNN --> non-differentiable function --> CNN. I want to backpropagate from the outputs of the second CNN all the way to the inputs of the first CNN, but the non-differentiable function blocks the way and won't allow gradients to flow. So how can I remove it from the backward pass but still include it in the forward pass?
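For concreteness, here is a minimal sketch of what I mean. The two small conv layers and the NumPy round trip below are hypothetical stand-ins for the real CNNs and the real non-differentiable function:

```python
import torch
import torch.nn as nn

# Hypothetical stand-ins for the two CNNs described above.
cnn1 = nn.Conv2d(3, 8, kernel_size=3, padding=1)
cnn2 = nn.Conv2d(8, 8, kernel_size=3, padding=1)

def non_differentiable(u):
    # Round-tripping through NumPy leaves autograd entirely,
    # so the computation graph is cut at this point.
    return torch.from_numpy(u.detach().numpy().round())

x = torch.randn(1, 3, 16, 16, requires_grad=True)
u = cnn1(x)
v = non_differentiable(u)  # v no longer carries gradient history
y = cnn2(v)
y.sum().backward()         # gradients reach cnn2's weights, but stop at v
print(x.grad)              # None: nothing flowed back to the input
```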

It does not make sense to do this. Consider the following chain of functions, y = C2(f(C1(x))):

x --> u = C1(x) --> v = f(u) --> y = C2(v)

Then computing the gradients of C1 (and of the input x) via the chain rule relies on back-propagating the gradient of the output y all the way back to u. But if we cannot compute the gradient of the function f, this chain is broken.
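Written out with the chain rule, using the notation above:

dy/dx = dy/dv · dv/du · du/dx

The middle factor dv/du is the derivative of f, so if f is not differentiable, the product cannot be evaluated and no gradient reaches u or x.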