Backward fails when I apply a function to a torch.nn.Parameter

Hello,
Initially, I faced the following issue:

RuntimeError: Trying to backward through the graph a second time, but the buffers have already been freed. Specify retain_graph=True when calling backward the first time.

I figured out that the problem arises when I use a torch.nn.Parameter in my model and apply some transformation to it. It is easy to see with a toy example:

import torch
import torch.nn as nn

# Define our 'model'
class Simple_func(nn.Module):
    def __init__(self):
        super(Simple_func, self).__init__()
        self.a = nn.Parameter((torch.randn(1) + 5) ** 2)
#         self.b = torch.log(self.a)
    def forward(self, x):
        return x + self.a
#         return x + self.b

model = Simple_func()
print([p for p in model.parameters()])
# here I have the only parameter; its value is (torch.randn(1) + 5) ** 2

# And here is the 'training process'.

optim = torch.optim.Adamax(model.parameters())
data = torch.arange(10, dtype=torch.float32) - 5
for i, d_point in enumerate(data):
    loss = (model(d_point) - d_point)**2
    loss.backward()
    optim.step()
    optim.zero_grad()

The code above works fine.

But suppose that, for numerical stability, I want to optimise a Parameter that sits inside a logarithm. I change only Simple_func in the code above:

import torch
import torch.nn as nn

# Define our 'model'
class Simple_func(nn.Module):
    def __init__(self):
        super(Simple_func, self).__init__()
        self.a = nn.Parameter((torch.randn(1) + 5) ** 2)
        self.b = torch.log(self.a)  # <-------- UNCOMMENTED
    def forward(self, x):
#         return x + self.a # <-------- COMMENTED
        return x + self.b # <-------- UNCOMMENTED

model = Simple_func()
print([p for p in model.parameters()])
# here I have the only parameter; its value is (torch.randn(1) + 5) ** 2

# And here is the 'training process'.

optim = torch.optim.Adamax(model.parameters())
data = torch.arange(10, dtype=torch.float32) - 5
for i, d_point in enumerate(data):
    loss = (model(d_point) - d_point)**2
    loss.backward()
    optim.step()
    optim.zero_grad()

The code above no longer works. It raises this error on the second call to backward:

RuntimeError: Trying to backward through the graph a second time, but the buffers have already been freed. Specify retain_graph=True when calling backward the first time.
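The same error can be reproduced without any nn.Module. The sketch below is a minimal illustration, assuming an arbitrary starting value of 30.0; the names `a` and `b` are chosen only to mirror `self.a` and `self.b`. The log is computed once, like `self.b` in `__init__`, and then backward is called on it twice.

```python
import torch

# `a` plays the role of self.a (value chosen arbitrarily for illustration)
a = torch.tensor(30.0, requires_grad=True)
# `b` plays the role of self.b: computed once, outside any forward pass
b = torch.log(a)

# First backward succeeds and frees the tensors that log saved for its gradient
b.backward()
first_grad = a.grad.clone()  # d/da log(a) = 1 / a

# Second backward walks the same (already freed) graph and fails
message = ""
try:
    b.backward()
    second_failed = False
except RuntimeError as err:
    second_failed = True
    message = str(err)

print(first_grad.item(), second_failed)
```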

My questions are:

  1. What causes this error?
  2. What is the correct way to use an nn.Parameter inside other functions (such as torch.log)?

I understand that I could instead create a tensor with torch.tensor(..., requires_grad=True), but as far as I understand, I would then have to pass such tensors to torch.optim manually to optimise them.
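The difference described above can be checked directly. The following is a small sketch, with hypothetical class names `WithParameter` and `WithPlainTensor`. An nn.Parameter assigned as a module attribute is registered and shows up in `model.parameters()`. A plain tensor with `requires_grad=True` is not registered, so it has to be listed for the optimizer explicitly.

```python
import torch
import torch.nn as nn


class WithParameter(nn.Module):  # hypothetical example module
    def __init__(self):
        super().__init__()
        self.a = nn.Parameter(torch.ones(1))  # registered automatically


class WithPlainTensor(nn.Module):  # hypothetical example module
    def __init__(self):
        super().__init__()
        self.a = torch.ones(1, requires_grad=True)  # NOT registered as a parameter


n_param = len(list(WithParameter().parameters()))
n_plain = len(list(WithPlainTensor().parameters()))
print(n_param, n_plain)  # 1 0

# A plain tensor must be handed to the optimizer by hand
plain_model = WithPlainTensor()
optim = torch.optim.Adamax([plain_model.a])
```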

  1. The problem is that you compute self.b only once, in __init__. The first backward works, but it frees the buffers that the log operation saved. The next time around, the part of the computation graph connecting self.b to self.a is gone, so backward through it fails.
  2. The most obvious way to get this to work is to move the line self.b = torch.log(self.a) from __init__ to the top of forward, so that the log is recomputed from the current value of self.a on every call.
  3. Requests for built-in support of this pattern come up every once in a while, but it is not important enough (and the workaround is good enough) for it to be a priority.
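Point 2 above can be sketched as follows. This is the toy model from the question with `torch.log` moved into `forward`; the seed and the check on the parameter value are additions for illustration only. Each call now builds a fresh graph from the current `self.a`, so every `loss.backward()` has its own buffers to free.

```python
import torch
import torch.nn as nn


class Simple_func(nn.Module):
    def __init__(self):
        super().__init__()
        self.a = nn.Parameter((torch.randn(1) + 5) ** 2)

    def forward(self, x):
        # Recompute log(a) on every call, so each forward pass builds a new
        # graph from the current (possibly updated) value of self.a
        b = torch.log(self.a)
        return x + b


torch.manual_seed(0)  # seed added only to make the run repeatable
model = Simple_func()
optim = torch.optim.Adamax(model.parameters())
data = torch.arange(10, dtype=torch.float32) - 5

a_before = model.a.detach().clone()
for d_point in data:
    loss = (model(d_point) - d_point) ** 2
    loss.backward()  # no RuntimeError: each iteration has its own graph
    optim.step()
    optim.zero_grad()

print(a_before.item(), model.a.item())
```

Here the loss reduces to (log a)^2, whose gradient 2 log(a) / a is positive for a > 1, so the optimizer steadily decreases `self.a`. Recent PyTorch releases also provide `torch.nn.utils.parametrize.register_parametrization` for attaching a transformation to a parameter, which is an alternative to recomputing it by hand in `forward`.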

Best regards

Thomas


I got it, thank you!