Customizing DataParallel

I am running a reversible network and want to use DataParallel on it. However, my model has the following methods, all of which I need:
forward
reverse
backward

Example of my network:

import torch.nn as nn


class reversenet(nn.Module):
    def __init__(self, args):
        super().__init__()
        # model: the underlying reversible network (its definition is not shown here)
        self.model = model

    def forward(self, X):
        Y, Yo = self.model(X)
        return Y, Yo

    def backward(self, Y, Yo, dY, get_optim):
        """ Compute gradients for the model parameters, using the reversible property of the
        network to recompute activations. Gradients are updated in place.

        Arguments:
            Y {torch.Tensor} -- Final state from the forward pass
            Yo {torch.Tensor} -- Second last state from the forward pass
            dY {torch.Tensor} -- Derivative of Y
            get_optim {function} -- Function that returns an optimizer
        """

        return Y, Yo

    def reverse(self, YN, Yo):
        """ Use the reversible property of the network to recompute the input to the forward pass.

        Arguments:
            YN {torch.Tensor} -- Final state from the forward pass
            Yo {torch.Tensor} -- Second last state from the forward pass

        Returns:
            torch.Tensor -- Recovery of initial state
            torch.Tensor -- Recovery of second state
        """

        Y = Yo
        Yo = YN

        return Y, Yo

But when I wrap this in the standard DataParallel, I lose access to the reverse routine.
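
To make the problem concrete, here is a minimal sketch of what happens. TinyReversible is a hypothetical stand-in for the network above (a single scalar parameter, not the real model), and the device selection is an assumption added so the snippet runs with or without GPUs.

```python
import torch
import torch.nn as nn


class TinyReversible(nn.Module):
    """Hypothetical stand-in for reversenet: forward plus one extra method."""

    def __init__(self):
        super().__init__()
        self.scale = nn.Parameter(torch.tensor(2.0))

    def forward(self, X):
        return X * self.scale

    def reverse(self, Y):
        # Exact inverse of forward for this toy model
        return Y / self.scale


device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
net = TinyReversible().to(device)
wrapped = nn.DataParallel(net)

# The wrapper does not expose custom methods of the wrapped module.
has_reverse = hasattr(wrapped, "reverse")

# The original module is still reachable through .module, but calling it
# this way bypasses DataParallel entirely and runs on a single device.
X = torch.ones(3, device=device)
recovered = wrapped.module.reverse(wrapped.module(X))
```

So wrapped.module.reverse(...) still works, but it is not split across GPUs, which is the part I actually want.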

I could create a second, inverse model that reuses the same parameters as the first and whose forward method is the reverse operation of the original model, but that seems rather convoluted and not all that pretty.
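
For reference, a rough sketch of that second-model workaround could look like the following. TinyReversible and ReverseView are hypothetical names used only for illustration, and the toy scalar model stands in for the real network.

```python
import torch
import torch.nn as nn


class TinyReversible(nn.Module):
    """Hypothetical stand-in for reversenet."""

    def __init__(self):
        super().__init__()
        self.scale = nn.Parameter(torch.tensor(2.0))

    def forward(self, X):
        return X * self.scale

    def reverse(self, Y):
        return Y / self.scale


class ReverseView(nn.Module):
    """Second model whose forward is the reverse of the wrapped network.

    It holds the same module object, so no parameters are copied.
    """

    def __init__(self, net):
        super().__init__()
        self.net = net

    def forward(self, Y):
        return self.net.reverse(Y)


device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
net = TinyReversible().to(device)

forward_dp = nn.DataParallel(net)               # parallel forward pass
reverse_dp = nn.DataParallel(ReverseView(net))  # parallel reverse pass

X = torch.ones(4, 3, device=device)
recovered = reverse_dp(forward_dp(X))

# Both wrappers see the very same Parameter object.
shares_params = next(reverse_dp.parameters()) is net.scale
```

This works because DataParallel only ever dispatches forward, so each extra method needs its own wrapper module, which is exactly the clutter I would like to avoid.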

So I was wondering whether anyone has any experience customizing DataParallel to accept more methods?