Bypass module during backpropagation

Hi everyone,

I am currently working on a setup where the output of my network is first modified by my own module before the loss is computed and backpropagated, as follows:

```python
import torch.nn as nn


class NeuralNet(nn.Module):
    def __init__(self):
        super(NeuralNet, self).__init__()

        # Layer, Own_module and Loss are my own classes (not shown)
        self.Layer1 = Layer()
        self.Layer2 = Layer()

        self.Mod = Own_module()

        self.LossLayer = Loss()

    def forward(self, x):
        y1 = self.Layer1(x)
        y2 = self.Layer2(y1)

        self.output = self.Mod(y2)

    def backward(self):
        loss = self.LossLayer(self.output)
        loss.backward()
```

This works. However, Own_module() is an extensive simulation with lots of multiplications and loops (the operations are completely fixed). Because the computational graph for the whole simulation has to be saved, I run into memory issues. I tried solving this using:


However, this still results in memory issues, and when I use detach(), backpropagation is blocked completely.

Is there a way to avoid saving the computational graph during Own_module() and essentially bypass it during backpropagation?

Thanks a lot!

Hi Luuk!

Yes, but …

You won’t want to bypass backpropagation through Own_module entirely.

Suppose Own_module were simply return -input, that is, it just flipped
the sign of (the gradient of) the loss function. If you bypassed it during
backpropagation, you would be training your model to maximize the loss
function: your predictions would become increasingly poor, and your
parameters would likely run off to infinity.

Therefore, you need to backpropagate through Own_module with at least
some crude approximation to Own_module's own gradient, one that gets
the sign of that gradient right most of the time.
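A toy version of the sign-flip scenario makes this concrete. In the sketch below, NegateBypass is a hypothetical stand-in for an Own_module whose forward() computes -input but whose backward() "bypasses" it by passing the incoming gradient through unchanged. A single parameter w is fit to a target of 1 with plain gradient descent: with the true gradient the loss shrinks every step, while with the bypass it grows every step.

```python
import torch


class NegateBypass(torch.autograd.Function):
    """Hypothetical toy Own_module: forward() computes -input, but backward()
    "bypasses" it by passing the incoming gradient through unchanged."""

    @staticmethod
    def forward(ctx, x):
        return -x

    @staticmethod
    def backward(ctx, grad_output):
        # The true gradient of -x would be -grad_output.
        return grad_output


def train(use_bypass, steps=5, lr=0.1):
    w = torch.zeros(1, requires_grad=True)  # the single "model parameter"
    target = torch.ones(1)
    losses = []
    for _ in range(steps):
        out = NegateBypass.apply(w) if use_bypass else -w
        loss = ((out - target) ** 2).sum()
        losses.append(loss.item())
        loss.backward()
        with torch.no_grad():
            w -= lr * w.grad  # plain gradient-descent step
        w.grad.zero_()
    return losses


print("true gradient:", train(use_bypass=False))  # loss shrinks every step
print("bypassed:     ", train(use_bypass=True))   # loss grows every step
```

With the bypass, each step moves w in exactly the wrong direction, so the distance to the optimum keeps growing.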

I would suggest that you package Own_module not as a Module, but
as a torch.autograd.Function. Your custom Function's forward()
method would implement your “extensive simulation” without saving
the computational graph (using, e.g., .detach() or protected by a
with torch.no_grad(): block), and the computation graph would get
“glued back together” because you would also implement your Function's
.backward() method. The more accurately .backward() approximates
the actual gradient of .forward(), the better your model will likely train.
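A minimal sketch of this approach, under stated assumptions: simulate() is a hypothetical stand-in for the fixed simulation (STEPS small elementwise updates), and backward() uses a deliberately crude, hypothetical surrogate that evaluates each step's derivative at the input rather than at every intermediate value. Only the input tensor is saved for backward(), so memory no longer grows with the number of simulation steps.

```python
import torch

STEPS = 100  # number of steps in the hypothetical simulation


def simulate(x):
    """Hypothetical stand-in for Own_module's fixed, loop-heavy simulation."""
    y = x.clone()
    for _ in range(STEPS):
        y = y + 0.01 * torch.tanh(y)
    return y


class OwnModuleFunction(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x):
        # Run the simulation without recording a graph of its STEPS steps.
        with torch.no_grad():
            y = simulate(x)
        # Save only the input; no intermediate simulation tensors are kept.
        ctx.save_for_backward(x)
        return y

    @staticmethod
    def backward(ctx, grad_output):
        (x,) = ctx.saved_tensors
        # Crude surrogate gradient: the exact derivative of each step is
        # 1 + 0.01 * (1 - tanh(y_k)**2); here it is evaluated at the input x
        # for every step instead of at each intermediate y_k.
        per_step = 1 + 0.01 * (1 - torch.tanh(x) ** 2)
        return grad_output * per_step ** STEPS


x = torch.linspace(-2.0, 2.0, 5, requires_grad=True)
y = OwnModuleFunction.apply(x)
y.sum().backward()  # gradient flows back to x through the surrogate
print(x.grad)
```

In NeuralNet.forward() you would then call OwnModuleFunction.apply(y2) in place of self.Mod(y2). Whether a surrogate this crude is good enough depends on the real simulation; for this toy it has the correct sign everywhere and is exact at x = 0.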

As an aside, your NeuralNet Module doesn’t need to (and, to avoid
confusion, shouldn’t) implement a backward() method. In a more
standard design, the loss computation in your backward() method
would be moved out of NeuralNet and into your training loop, after
NeuralNet has been applied to your input data.
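A sketch of that more standard layout, with hypothetical stand-ins: two nn.Linear layers replace Layer(), nn.Identity() is a placeholder where Own_module (or a custom Function) would go, and nn.MSELoss with a made-up regression target replaces Loss(). forward() just returns the output; the loss, the backward() call, and the optimizer step all live in the training loop.

```python
import torch
import torch.nn as nn


class NeuralNet(nn.Module):
    def __init__(self):
        super().__init__()
        # Hypothetical stand-ins for the original Layer() instances.
        self.layer1 = nn.Linear(4, 8)
        self.layer2 = nn.Linear(8, 1)
        # Placeholder for Own_module (e.g. a custom autograd Function).
        self.mod = nn.Identity()

    def forward(self, x):
        y2 = self.layer2(torch.relu(self.layer1(x)))
        return self.mod(y2)  # return the output instead of storing it


torch.manual_seed(0)
model = NeuralNet()
loss_fn = nn.MSELoss()  # stand-in for the original Loss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.05)

x = torch.randn(64, 4)
target = x.sum(dim=1, keepdim=True)  # made-up regression target

losses = []
for step in range(100):
    optimizer.zero_grad()
    output = model(x)               # forward pass only
    loss = loss_fn(output, target)  # loss computed outside the model
    loss.backward()                 # autograd handles backpropagation
    optimizer.step()
    losses.append(loss.item())
```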


K. Frank
