I am currently working on a setup where the output of my network is first modified by my own module before the loss is computed and backpropagated, as follows:
    import torch.nn as nn

    class NeuralNet(nn.Module):
        def __init__(self):
            super().__init__()
            self.Layer1 = Layer()        # Layer, Own_module, and Loss are defined elsewhere
            self.Layer2 = Layer()
            self.Mod = Own_module()      # the extensive simulation module
            self.LossLayer = Loss()

        def forward(self, x):
            y1 = self.Layer1(x)
            y2 = self.Layer2(y1)
            self.output = self.Mod(y2)
            loss = self.LossLayer(self.output)
            return loss
This works; however, Own_module() is an extensive simulation with lots of multiplications and loops (the operations are completely fixed). Because of all the computational graphs that need to be saved during this simulation, memory issues occur. I have tried to work around this, but I still run into memory issues, and when I use .detach() the backpropagation gets blocked completely.
Is there a way to avoid saving the computational graphs during Own_module() and essentially bypass it during backpropagation?
Thanks a lot!
Yes, but …
You won’t want to bypass backpropagation through your Own_module entirely. What if Own_module simply did return -input, that is, it just flipped the sign of (the gradient of) the loss function? Then you would be training your model to maximize the loss function, your predictions would become increasingly poor, and your parameters would likely run off to infinity.
Therefore, you need to backpropagate through at least some crude approximation to Own_module’s own gradient, at a minimum getting the sign of that gradient correct most of the time.
I would suggest that you package Own_module not as a Module, but as a torch.autograd.Function. Your custom .forward() method would implement your “extensive simulation” without saving the computational graph (using, e.g., .detach(), or protected by a with torch.no_grad(): block), and the computation graph would get “glued back together” because you would also implement your own custom .backward() method. The more accurately your .backward() approximates the actual gradient of your .forward(), the better your model will likely train.
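For concreteness, here is a minimal sketch of that pattern. The torch.autograd.Function scaffolding is the real API; run_simulation and approximate_gradient are hypothetical stand-ins for your simulation code and for whatever cheap gradient estimate you can compute:

    import torch

    class OwnModuleFunction(torch.autograd.Function):
        @staticmethod
        def forward(ctx, input):
            # Run the simulation without tracking operations, so no
            # intermediate computational graph is stored.
            with torch.no_grad():
                output = run_simulation(input)   # hypothetical wrapper around your simulation
            ctx.save_for_backward(input)         # keep only what backward() needs
            return output

        @staticmethod
        def backward(ctx, grad_output):
            input, = ctx.saved_tensors
            # Hypothetical cheap estimate of d(output)/d(input), shown here as an
            # elementwise factor (i.e., assuming a roughly diagonal Jacobian).
            return grad_output * approximate_gradient(input)

In your Module’s forward() you would then call OwnModuleFunction.apply(y2) in place of self.Mod(y2).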
As an aside, your Module doesn’t need to (and, to avoid confusion, shouldn’t) implement a backward() method – in a more standard design, the optimization code in your backward() method would be moved outside of NeuralNet into your training loop, after the application of NeuralNet to your input data.
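A rough sketch of that arrangement, assuming the NeuralNet shown above returns its loss; torch.optim.SGD is used purely as an example, and data_loader is a placeholder for however you iterate over your training data:

    import torch

    model = NeuralNet()
    optimizer = torch.optim.SGD(model.parameters(), lr=0.01)

    for x in data_loader:          # placeholder for your training data
        optimizer.zero_grad()      # clear gradients from the previous step
        loss = model(x)            # forward pass; here the loss is computed inside forward()
        loss.backward()            # backpropagate through the whole graph
        optimizer.step()           # the parameter update lives in the loop, not in NeuralNet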