I am running a reversible network and want to use `DataParallel` on it. However, my model has the following methods, all of which I need:

- `forward`
- `reverse`
- `backward`

Example of my network:

```
import torch.nn as nn


class reversenet(nn.Module):
    def __init__(self, model):
        super().__init__()
        # The wrapped reversible model is passed in explicitly; the original
        # snippet referenced an undefined global `model`.
        self.model = model

    def forward(self, X):
        Y, Yo = self.model(X)
        return Y, Yo

    def backward(self, Y, Yo, dY, get_optim):
        """ Compute gradients for the model parameters using the reversible property of the network to
        re-calculate activations. Gradients are updated in place.
        Arguments:
            Y {torch.Tensor} -- Final state from the forward pass
            Yo {torch.Tensor} -- Second last state from the forward pass
            dY {torch.Tensor} -- Derivative of Y
            get_optim {function} -- Function that returns an optimizer
        """
        # The actual gradient computation is not shown in this example.
        return Y, Yo

    def reverse(self, YN, Yo):
        """ Use the reversible property of the network to re-calculate the input to the forward pass.
        Arguments:
            YN {torch.Tensor} -- Final state from the forward pass
            Yo {torch.Tensor} -- Second last state from the forward pass
        Returns:
            torch.Tensor -- Recovery of initial state
            torch.Tensor -- Recovery of second state
        """
        Y = Yo
        Yo = YN
        return Y, Yo
```

But when I wrap the model in the standard `nn.DataParallel`, I lose access to the `reverse` method.
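A minimal sketch of the problem, assuming a hypothetical toy model `TinyReversible` in place of the real network. `nn.DataParallel` only overrides `forward`, and attribute lookup on the wrapper does not fall through to the wrapped model's methods. The model stays reachable through the wrapper's `.module` attribute, but calling `reverse` that way runs the unwrapped model on a single device without scattering the inputs.

```python
import torch
import torch.nn as nn


class TinyReversible(nn.Module):
    """Hypothetical stand-in for reversenet with forward and reverse methods."""

    def __init__(self):
        super().__init__()
        self.lin = nn.Linear(4, 4)

    def forward(self, X):
        # Return (final state, second-to-last state), mirroring the question.
        return self.lin(X), X

    def reverse(self, YN, Yo):
        # Same placeholder logic as the question: swap the two states.
        return Yo, YN


device = "cuda" if torch.cuda.is_available() else "cpu"
net = TinyReversible().to(device)
dp = nn.DataParallel(net)

# nn.Module.__getattr__ only searches parameters, buffers and submodules,
# so the wrapped model's extra methods are not visible on the wrapper.
print(hasattr(dp, "reverse"))         # False
print(hasattr(dp.module, "reverse"))  # True

Y, Yo = dp(torch.randn(8, 4, device=device))
# Works, but this call is NOT split across GPUs: it runs the
# unwrapped model on a single device.
X0, X1 = dp.module.reverse(Y, Yo)
```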

I could create a second, inverse model that reuses the parameters of the first one and whose `forward` method performs the original model's reverse operation, but that seems convoluted and not very elegant.
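For comparison, the inverse-model idea can be fairly compact. This is a sketch assuming a hypothetical toy model `TinyReversible` and a hypothetical wrapper `InverseNet`: the wrapper stores the same module object rather than a copy, so both `DataParallel` wrappers see identical `nn.Parameter` objects, and `InverseNet.forward` simply delegates to `reverse`.

```python
import torch
import torch.nn as nn


class TinyReversible(nn.Module):
    """Hypothetical stand-in for reversenet with forward and reverse methods."""

    def __init__(self):
        super().__init__()
        self.lin = nn.Linear(4, 4)

    def forward(self, X):
        return self.lin(X), X

    def reverse(self, YN, Yo):
        # Placeholder logic from the question: swap the two states.
        return Yo, YN


class InverseNet(nn.Module):
    """Hypothetical wrapper whose forward runs the wrapped model's reverse."""

    def __init__(self, net):
        super().__init__()
        # Same object, not a copy: no parameters are duplicated.
        self.net = net

    def forward(self, YN, Yo):
        return self.net.reverse(YN, Yo)


device = "cuda" if torch.cuda.is_available() else "cpu"
net = TinyReversible().to(device)
fwd = nn.DataParallel(net)
inv = nn.DataParallel(InverseNet(net))

# Both wrappers expose the very same parameter objects.
shared = all(p is q for p, q in zip(fwd.parameters(), inv.parameters()))

Y, Yo = fwd(torch.randn(8, 4, device=device))
X0, X1 = inv(Y, Yo)  # inputs are scattered like a normal forward call
```

Because parameters are shared by identity, an optimizer built from `net.parameters()` updates both wrappers at once.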

Does anyone have experience customizing `DataParallel` so that it can parallelize methods other than `forward`?
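One possible direction, sketched under assumptions rather than tested on multiple GPUs: subclass `nn.DataParallel` and repeat the steps its own `forward` performs (scatter, replicate, parallel apply, gather), but have each replica run a named method. `MethodDataParallel`, `call_method`, `_MethodCaller` and `TinyReversible` are hypothetical names introduced for illustration; `scatter`, `replicate`, `parallel_apply`, `gather`, `device_ids` and `output_device` are existing `DataParallel` members.

```python
import torch
import torch.nn as nn


class _MethodCaller(nn.Module):
    """Hypothetical adapter: calling it runs `method_name` on a replica."""

    def __init__(self, module, method_name):
        super().__init__()
        self.module = module
        self.method_name = method_name

    def forward(self, *args, **kwargs):
        return getattr(self.module, self.method_name)(*args, **kwargs)


class MethodDataParallel(nn.DataParallel):
    """Hypothetical DataParallel subclass that can scatter any method."""

    def call_method(self, method_name, *inputs, **kwargs):
        # No GPUs: DataParallel leaves device_ids empty, so call directly.
        if not self.device_ids:
            return getattr(self.module, method_name)(*inputs, **kwargs)
        # Same steps DataParallel.forward uses: scatter, replicate,
        # apply in parallel, gather.
        inputs, kwargs = self.scatter(inputs, kwargs, self.device_ids)
        if len(self.device_ids) == 1:
            return getattr(self.module, method_name)(*inputs[0], **kwargs[0])
        replicas = self.replicate(self.module, self.device_ids[: len(inputs)])
        callers = [_MethodCaller(r, method_name) for r in replicas]
        outputs = self.parallel_apply(callers, inputs, kwargs)
        return self.gather(outputs, self.output_device)

    def reverse(self, YN, Yo):
        return self.call_method("reverse", YN, Yo)


class TinyReversible(nn.Module):
    """Hypothetical stand-in for reversenet."""

    def __init__(self):
        super().__init__()
        self.lin = nn.Linear(4, 4)

    def forward(self, X):
        return self.lin(X), X

    def reverse(self, YN, Yo):
        # Placeholder logic from the question: swap the two states.
        return Yo, YN


device = "cuda" if torch.cuda.is_available() else "cpu"
dp = MethodDataParallel(TinyReversible().to(device))
Y, Yo = dp(torch.randn(8, 4, device=device))
X0, X1 = dp.reverse(Y, Yo)
```

The same `call_method` could in principle be pointed at `backward`, but replicas are recreated on every call, so how that interacts with the in-place gradient updates in the question's `backward` is not verified here.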