I am re-implementing the Invertible Residual Networks architecture. Beyond the forward pass, this model also has an `inverse` function:

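```
import torch
import torch.nn as nn

class iResNetBlock(nn.Module):
    def __init__(self, input_size, hidden_size):
        super().__init__()
        # LinearContraction: custom contractive linear layer (definition not shown)
        self.bottleneck = nn.Sequential(
            LinearContraction(input_size, hidden_size),
            LinearContraction(hidden_size, input_size),
            nn.ReLU(),
        )

    def forward(self, x):
        return x + self.bottleneck(x)

    def inverse(self, y, max_iter=100, tol=1e-6):
        # fixed point iteration x <- y - g(x)
        x = y.clone()
        for _ in range(max_iter):
            x_next = y - self.bottleneck(x)
            converged = torch.norm(x_next - x) < tol
            x = x_next
            if converged:
                break
        return x
```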
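The loop in `inverse` converges because, when the bottleneck g satisfies Lip(g) < 1, the map x ↦ y − g(x) is a contraction whose unique fixed point is the preimage of y (Banach fixed-point theorem). A minimal numerical check, assuming a hypothetical `contractive_linear` helper as a stand-in for `LinearContraction` that simply rescales each weight matrix once to spectral norm 0.7:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

def contractive_linear(in_f, out_f, c=0.7):
    """Hypothetical stand-in for LinearContraction: an nn.Linear whose
    weight is rescaled once so that its spectral norm equals c < 1."""
    layer = nn.Linear(in_f, out_f)
    with torch.no_grad():
        layer.weight.mul_(c / torch.linalg.matrix_norm(layer.weight, ord=2))
    return layer

# g(x) = ReLU(W2 (W1 x + b1) + b2), so Lip(g) <= 0.7 * 0.7 = 0.49 < 1
g = nn.Sequential(contractive_linear(4, 8), contractive_linear(8, 4), nn.ReLU()).double()

x_true = torch.randn(16, 4, dtype=torch.float64)
with torch.no_grad():
    y = x_true + g(x_true)            # forward pass of the block
    x = y.clone()
    for it in range(100):             # fixed point iteration x <- y - g(x)
        x_next = y - g(x)
        converged = torch.norm(x_next - x) < 1e-10
        x = x_next
        if converged:
            break

# x now recovers x_true up to the tolerance, after a few dozen iterations
print(it, torch.allclose(x, x_true, atol=1e-8))
```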

I want to add a custom backward pass to the `inverse` function. Since it is a fixed-point iteration, one can use the implicit function theorem to avoid unrolling the loop and instead compute the gradient by solving a linear system. This is done, for example, in the Deep Equilibrium Models architecture.
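Sketching the implicit-function-theorem argument for this block, under the assumption that `inverse` returns the exact fixed point x* of x = y − g_θ(x), with J the Jacobian of the bottleneck at x* and v the incoming gradient (as a row vector):

```latex
\begin{aligned}
x^\star &= y - g_\theta(x^\star), \qquad J := \frac{\partial g_\theta}{\partial x}\Big|_{x^\star}, \qquad v := \frac{\partial L}{\partial x^\star} \\
(I + J)\,\mathrm{d}x^\star &= \mathrm{d}y - \frac{\partial g_\theta}{\partial \theta}\,\mathrm{d}\theta \\
\frac{\partial L}{\partial y} &= u, \qquad \frac{\partial L}{\partial \theta} = -\,u\,\frac{\partial g_\theta}{\partial \theta}, \qquad \text{where } u\,(I + J) = v \\
u &= v - u\,J \qquad \text{(a contraction in } u \text{, since } \lVert J \rVert \le \operatorname{Lip}(g_\theta) < 1\text{)}
\end{aligned}
```

So the backward pass is itself a fixed-point iteration, this time on vector-Jacobian products u J of the bottleneck, and only a single evaluation of g at x* has to be differentiated instead of the whole unrolled loop.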

How do I register my custom backward pass for this function? I had hoped there would be a decorator (as is possible in JAX, for example) or some other simple way of telling PyTorch to use a certain function as the backward of another function.
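The closest PyTorch counterpart to JAX's `jax.custom_vjp` that I am aware of is subclassing `torch.autograd.Function`: `forward` and `backward` are static methods, tensors needed later are stored with `ctx.save_for_backward`, and the function is invoked through `.apply`. A minimal sketch on a toy function (`Cube` is purely illustrative):

```python
import torch

class Cube(torch.autograd.Function):
    """y = x**3 with a hand-written backward pass."""

    @staticmethod
    def forward(ctx, x):
        ctx.save_for_backward(x)        # stash what backward will need
        return x ** 3

    @staticmethod
    def backward(ctx, grad_out):
        (x,) = ctx.saved_tensors
        return grad_out * 3 * x ** 2    # one gradient per forward input

x = torch.tensor([1.0, 2.0, -3.0], requires_grad=True)
Cube.apply(x).sum().backward()          # autograd calls Cube.backward
# x.grad now equals 3 * x**2, i.e. [3, 12, 27]
```

`backward` has to return one entry per argument of `forward`, with `None` for non-tensor arguments.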

```
class iResNetBlock(nn.Module):
    ...

    def inverse_backwards(self, grad_output):
        # How do I tell PyTorch to use this as the backward for .inverse?
        ...
```
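One possible sketch: wrap the fixed-point solve in a `torch.autograd.Function` whose `backward` solves u(I + J) = v for the incoming gradient v with a second fixed-point iteration, u ← v − uJ, computing the vector-Jacobian products uJ with `torch.autograd.grad`, rather than unrolling the forward loop. Assumptions: `FixedPointInverse`, `max_iter` and `tol` are hypothetical names, the bottleneck g is contractive, all of its parameters require gradients, and plain `nn.Linear` layers stand in for `LinearContraction`. The bottleneck parameters are passed to `apply` explicitly so that autograd routes gradients to them.

```python
import torch
import torch.nn as nn

class FixedPointInverse(torch.autograd.Function):
    """Solves x = y - g(x) by fixed-point iteration and backpropagates with
    the implicit function theorem instead of unrolling the loop.
    Assumes Lip(g) < 1 and that *params is exactly tuple(g.parameters())."""

    @staticmethod
    def forward(ctx, y, g, max_iter, tol, *params):
        # forward runs without building a graph, so memory does not grow
        # with the number of iterations
        x = y.clone()
        for _ in range(max_iter):
            x_next = y - g(x)
            converged = torch.norm(x_next - x) < tol
            x = x_next
            if converged:
                break
        ctx.g, ctx.max_iter, ctx.tol = g, max_iter, tol
        ctx.save_for_backward(x)
        return x

    @staticmethod
    def backward(ctx, v):
        (x,) = ctx.saved_tensors
        params = tuple(ctx.g.parameters())
        with torch.enable_grad():
            x = x.detach().requires_grad_()
            gx = ctx.g(x)                       # differentiate g once, at x*
            # solve u (I + J) = v via the fixed point u <- v - u J
            u = v.clone()
            for _ in range(ctx.max_iter):
                (uJ,) = torch.autograd.grad(gx, x, u, retain_graph=True)
                u_next = v - uJ
                converged = torch.norm(u_next - u) < ctx.tol
                u = u_next
                if converged:
                    break
            # dL/dtheta = -u dg/dtheta
            grad_params = torch.autograd.grad(gx, params, -u, allow_unused=True)
        # one entry per forward input: y, g, max_iter, tol, *params
        return (u, None, None, None, *grad_params)


class iResNetBlock(nn.Module):
    def __init__(self, input_size, hidden_size):
        super().__init__()
        # plain Linear layers as stand-ins for LinearContraction
        self.bottleneck = nn.Sequential(
            nn.Linear(input_size, hidden_size),
            nn.Linear(hidden_size, input_size),
            nn.ReLU(),
        )

    def forward(self, x):
        return x + self.bottleneck(x)

    def inverse(self, y, max_iter=100, tol=1e-6):
        return FixedPointInverse.apply(
            y, self.bottleneck, max_iter, tol, *self.bottleneck.parameters()
        )
```

With `inverse` routed through `FixedPointInverse.apply`, any loss built on its output, such as `r = loss(y, model.inverse(other_model(model(x))))`, should pick up the implicit gradient when `r.backward()` is called, while `forward` keeps using ordinary autograd.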

In particular, when I later define a loss such as `r = loss(y, model.inverse(other_model(model(x))))`, I want `r.backward()` to correctly use my custom gradient for the `inverse` call.

Ideally, the solution should be TorchScript-compatible. A colleague suggested adding an `inverse=False` switch to the forward pass, but (a) I am not sure whether this keyword argument would be passed on to the backward pass, and (b) it would force me to also implement a `backward` for the `forward` function, which I do not want to do. Another idea was to create an “Inverse” module, but it seems that parameter sharing between modules is not possible.
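Regarding the “Inverse” module idea: PyTorch does let two modules share parameters if the same submodule instance is registered in both, because they then hold the very same `Parameter` objects. A minimal sketch, with a hypothetical `Block` standing in for `iResNetBlock` (its `inverse` is a plain unrolled loop, not the custom backward) and a hypothetical `Inverse` wrapper:

```python
import torch
import torch.nn as nn

class Block(nn.Module):
    """Hypothetical minimal stand-in for iResNetBlock."""
    def __init__(self, n):
        super().__init__()
        self.bottleneck = nn.Linear(n, n)

    def forward(self, x):
        return x + self.bottleneck(x)

    def inverse(self, y, n_iter=50):
        x = y.clone()
        for _ in range(n_iter):          # plain fixed-point loop
            x = y - self.bottleneck(x)
        return x

class Inverse(nn.Module):
    """Registers an existing block as a submodule; nothing is copied."""
    def __init__(self, block):
        super().__init__()
        self.block = block

    def forward(self, y):
        return self.block.inverse(y)

block = Block(3)
inv = Inverse(block)
model = nn.ModuleList([block, inv])

# both modules expose the same Parameter objects (weight and bias)
shared = [p is q for p, q in zip(block.parameters(), inv.parameters())]
# model.parameters() de-duplicates them, so an optimizer sees each only once
n_params = len(list(model.parameters()))
```

Gradients flowing through either `block` or `inv` then accumulate into the same `.grad` tensors.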