Why does constructor assignment cause "Trying to backward..." error?

I am implementing a network and I have noticed a behaviour that I would like to understand.
The following is my code; I removed all of the unrelated stuff.

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

class ParameterisedNormal():
    def __init__(self, mu, rho):
        self.mu = mu
        self.rho = rho
#         self.sigma = self.rho
        self.sigma = torch.log(1 + torch.exp(self.rho))
        
    def sample(self):
        epsilon = torch.normal(torch.zeros(self.sigma.shape), torch.ones(self.sigma.shape))
        return self.mu + self.sigma * epsilon        
    
    def log_prob(self, x):
        return torch.sum(-torch.log(self.sigma) 
                         -torch.log(torch.sqrt(torch.tensor(2*torch.pi))) 
                         -(x - self.mu)**2 / (2*self.sigma**2))
            
# Linear layer ignoring bias for now
class BayesLinear(nn.Module):
    def __init__(self, num_in, num_out):
        super().__init__()
        self.num_in = num_in
        self.num_out = num_out
        # mu and rho are the actual trainable weights of this layer
        self.mu = nn.Parameter(torch.zeros(num_out, num_in))
        self.rho = nn.Parameter(torch.ones(num_out, num_in))
        # approximation of the posterior q
        self.posterior = ParameterisedNormal(self.mu, self.rho)
        self.posterior_sample = None
        
    def forward(self, x):
        self.sample()
        x = F.linear(x, self.posterior_sample)
        return x
    
    def sample(self):
        self.posterior_sample = self.posterior.sample()
        
class BayesNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.l1 = BayesLinear(784, 800) 
        self.l2 = BayesLinear(800, 10)
        
    def forward(self, x):
        x = self.l1(x)
        x = F.relu(x)
        x = self.l2(x)
        x = F.log_softmax(x, dim=1)
        return x
  
    def monte_carlo_elbo(self, input, target):
        output = self(input)
        nll_data = F.nll_loss(output, target)
        elbo = nll_data
        return elbo
        
net = BayesNet()
optimizer = optim.Adam(net.parameters())

Training this net with the usual training loop gives RuntimeError: Trying to backward through the graph a second time…
However, changing self.sigma = torch.log(1 + torch.exp(self.rho)) to self.sigma = self.rho makes the error go away. Why does this calculation involving torch.log and torch.exp cause autograd to backward through the graph a second time?
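For reference, a minimal sketch of the kind of training loop I mean (train_loader is hypothetical here; it is assumed to yield MNIST-style batches of images and labels):

# hypothetical loop; train_loader is assumed to yield (images, labels) batches
for images, labels in train_loader:
    images = images.view(images.shape[0], -1)  # flatten to (batch, 784)
    optimizer.zero_grad()
    loss = net.monte_carlo_elbo(images, labels)
    loss.backward()  # with the code above this fails on the second batch
    optimizer.step()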

Your self.posterior_sample carries an autograd history, which is what raises the error.
I'm not completely sure, but I would assume you want to train self.mu and self.rho in both custom modules, while the actual weight is created via self.posterior.sample().
In your current code you create self.sigma from the trainable parameter self.rho once, in __init__. Using self.sigma will therefore backpropagate gradients to self.rho as well, and autograd keeps the intermediate tensors needed for that calculation.
In the next iteration you call self.sample() again, which reuses the same self.sigma and thus adds the previous differentiable assignment (self.sigma = torch.log(1 + torch.exp(self.rho))) to the current computation graph again.
The following backward call then fails, because the computation graph for that assignment was already freed by the first backward.
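The same mechanism can be reproduced without the full model; here is a small standalone sketch (not from your code) that fails on the second backward:

import torch

rho = torch.ones(3, requires_grad=True)
sigma = torch.log(1 + torch.exp(rho))   # created once, as in your __init__

for _ in range(2):
    sample = sigma * torch.randn(3)     # reuses the stale graph behind sigma
    loss = sample.sum()
    loss.backward()                     # second iteration raises the RuntimeError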
You could .detach() sigma and use it as a constant, or you could recreate it in each iteration (see the sketch below), if that fits your use case.
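A possible way to implement the second option, assuming you do want gradients to flow back into rho, would be to recompute sigma from rho on every access, e.g. via a property (just a sketch; log_prob stays as it is):

class ParameterisedNormal():
    def __init__(self, mu, rho):
        self.mu = mu
        self.rho = rho

    @property
    def sigma(self):
        # recomputed on every access, so each forward pass builds a fresh graph
        return torch.log(1 + torch.exp(self.rho))

    def sample(self):
        epsilon = torch.randn_like(self.sigma)
        return self.mu + self.sigma * epsilon

Equivalently, F.softplus(self.rho) computes the same quantity in a numerically more stable way.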
