Have to set "retain_graph = True" in .backward() to fit my model

Hello. I’m building an auto-regressive HMM and trying to fit it with autograd. I have only one loss, and I call optimizer.zero_grad() before loss.backward(), but I still get this error:

RuntimeError: Trying to backward through the graph a second time (or directly access saved tensors after they have already been freed). Saved intermediate values of the graph are freed when you call .backward() or autograd.grad(). Specify retain_graph=True if you need to backward through the graph a second time or if you need to access saved tensors after calling backward.

The code only runs with retain_graph=True, and that takes up a lot of RAM. Since I only have one loss, I don’t think I should need it.

```python
import numpy as np
import torch

class ARHMM:

    def __init__(self, N, lags):

        self._N = torch.tensor(N)
        self._lags = torch.from_numpy(lags)
        self._max_lag = self._lags.max()

    def initialize(self):

        self._initial_matrix_raw = torch.rand((1, self._N), requires_grad= True)
        self._transition_matrix_raw = torch.rand((self._N, self._N), requires_grad= True)
        self._emission_parameter = torch.randn((self._N, len(self._lags) + 1), requires_grad = True)
        self._emission_sigma_raw = torch.randn((1, self._N), requires_grad = True)

        self._log_alpha = torch.zeros((self._T, self._N))

    def update_matrix(self):

        self._initial_matrix = self._initial_matrix_raw.clone().abs() / self._initial_matrix_raw.clone().abs().sum(dim = 1, keepdim = True)
        self._transition_matrix = self._transition_matrix_raw.clone().abs() / self._transition_matrix_raw.clone().abs().sum(dim = 1, keepdim = True)
        self._emission_sigma = self._emission_sigma_raw.clone().abs()

    def emission_logp(self, state, t):

        mu = self._emission_parameter[state, 0].clone()

        mu += torch.dot(self._y[t - self._lags], self._emission_parameter[state, 1:].to(torch.float64))

        x = self._y[t]

        sigma = self._emission_sigma[0, state].clone()

        logp = torch.tensor(-0.5) * (torch.tensor(2) * torch.tensor(np.pi) * sigma.square()).log() - (x - mu).square() / (torch.tensor(2) * sigma.square())

        return logp

    def log_sum_exp(self, acc):

        temp_max = acc.clone().max()

        return (acc.clone() - temp_max).exp().sum().log() + temp_max

    def calc_log_alpha(self):

        for i in torch.arange(self._N):

            self._log_alpha[self._max_lag, i] = self._initial_matrix[0, i].log() + self.emission_logp(i, self._max_lag)

        for t in torch.arange(self._max_lag + 1, self._T):

            for j in torch.arange(self._N):

                acc = self._log_alpha[t - 1] + self._transition_matrix[:, j].log() + self.emission_logp(j, t)

                self._log_alpha[t, j] = self.log_sum_exp(acc)

        self._negative_logp = torch.tensor(-1) * self.log_sum_exp(self._log_alpha[-1,:])

        print(self._negative_logp)
        print(self._emission_parameter)


    def fit(self, Y, iter = 500, lr = 0.001):

        self._y = torch.from_numpy(Y)
        self._T = self._y.shape[0]

        self.initialize()

        optimizer = torch.optim.Adam
        optimizer = optimizer([self._initial_matrix_raw,
                               self._transition_matrix_raw,
                               self._emission_parameter,
                               self._emission_sigma_raw],
                              lr=lr)

        for i in range(iter):

            self.update_matrix()
            self.calc_log_alpha()

            optimizer.zero_grad()
            self._negative_logp.backward()
            optimizer.step()

model = ARHMM(3, np.array([1,2,3]))
model.fit(np.array(obs), iter = 20)  # obs: 1-D array of observations (not shown)
model._emission_parameter
```

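Calling zero_grad won’t change anything here: it only resets the .grad attributes of the parameters (deleting them, i.e. setting them to None, in newer PyTorch versions and filling them with zeros in older ones) and doesn’t touch the computation graph.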
The error is raised by the backward call itself. backward() frees the intermediate forward activations saved in the graph, and since these activations are needed for the gradient computation, a second backward pass through the same graph fails.
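To make both points concrete, here is a minimal sketch with a made-up scalar parameter `w` (not part of the model above). It assumes PyTorch 1.7 or newer, where `zero_grad` accepts `set_to_none`: after the first `backward()` the saved tensors are gone, and `zero_grad` only clears `w.grad`, so a second `backward()` on the same `loss` still raises the RuntimeError.

```python
import torch

w = torch.tensor(1.5, requires_grad=True)
optimizer = torch.optim.SGD([w], lr=0.1)

loss = w.exp() * w          # exp and mul save tensors for the backward pass
loss.backward()             # first backward works, then frees those saved tensors

optimizer.zero_grad(set_to_none=True)  # only resets w.grad; the graph stays freed
grad_after_zero = w.grad               # None

second_backward_failed = False
try:
    loss.backward()         # second backward through the same (freed) graph
except RuntimeError as err:
    second_backward_failed = True
    print(err)
```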
In the common use case you would call backward only once per iteration, after the forward pass and the loss computation. However, your training code seems to call backward multiple times after a single forward pass.
In that case, you would need to keep the intermediates alive with retain_graph=True and then call backward with retain_graph=False (the default) in the last step, which frees the computation graph and its intermediates.
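A small sketch of that pattern, using two made-up losses that share the intermediate `hidden`: the first call keeps the graph alive, the last one uses the default and frees it, so any further call on the same graph fails.

```python
import torch

w = torch.tensor(2.0, requires_grad=True)
hidden = w.exp()            # intermediate shared by both losses
loss_a = hidden * 3
loss_b = hidden.square()

loss_a.backward(retain_graph=True)  # keep the saved tensors for the next backward
loss_b.backward()                   # default retain_graph=False frees the graph afterwards
grad_after_two = w.grad.clone()     # gradients of both losses are accumulated
print(grad_after_two)

graph_freed = False
try:
    loss_a.backward()               # a third call fails: exp's saved output is gone
except RuntimeError:
    graph_freed = True
```

Note that the gradients of both calls accumulate in `w.grad`, which is usually what you want when several losses share one forward pass.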

@ptrblck Thanks for your feedback. But I’m actually calling backward() only once after a single forward pass. The forward pass of my model is calc_log_alpha(). How should I modify the code so that it works with retain_graph=False?

I guess calc_log_alpha is reusing tensors from previous iterations without detaching them from the old computation graph, which might cause the issue.
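A stripped-down reproduction of that guess, assuming the culprit is a buffer like `self._log_alpha`, which is allocated once in `initialize()` and then written in place on every `calc_log_alpha()` call. The names `forward`, `train` and `buf` are hypothetical. In-place writes record autograd history on the buffer itself, so on the second iteration the buffer still points into the first iteration's freed graph; allocating a fresh tensor per forward pass avoids that.

```python
import torch

w = torch.tensor(0.5, requires_grad=True)

def forward(buf):
    # in-place writes chain autograd history onto buf itself
    buf[0] = w.exp()
    buf[1] = buf[0] + w.square()
    return buf[1]

def train(reuse_buffer, iters=3):
    shared = torch.zeros(2)  # allocated once, like self._log_alpha in initialize()
    for _ in range(iters):
        # a fresh tensor carries no history from the previous iteration
        buf = shared if reuse_buffer else torch.zeros(2)
        loss = forward(buf)
        loss.backward()
    return iters

print(train(reuse_buffer=False))  # runs all iterations

reuse_failed = False
try:
    train(reuse_buffer=True)      # fails on the second iteration
except RuntimeError as err:
    reuse_failed = True
    print("reused buffer:", err)
```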
You would need to check that all inputs (not parameters) are either newly created or detached from the previous computation graph in calc_log_alpha.
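One way to apply this to the model above, sketched as standalone functions rather than the original class: `log_alpha` is built anew inside every forward call (no stored tensor, no in-place writes), so one plain `backward()` per iteration is enough. This is an illustrative rewrite, not the original poster's code: it vectorizes over states, swaps the custom `log_sum_exp` for `torch.logsumexp`, uses random data in place of `obs`, and the free functions `emission_logp` and `negative_log_likelihood` are hypothetical.

```python
import math
import torch

def emission_logp(y, t, lags, emission_parameter, sigma):
    # Gaussian AR emission log-density for all N states at time t, shape (N,)
    past = y[t - lags]  # (L,) lagged observations
    mu = emission_parameter[:, 0] + emission_parameter[:, 1:] @ past
    return -0.5 * torch.log(2 * math.pi * sigma**2) - (y[t] - mu) ** 2 / (2 * sigma**2)

def negative_log_likelihood(y, lags, init_raw, trans_raw, emission_parameter, sigma_raw):
    max_lag = int(lags.max())
    init = init_raw.abs() / init_raw.abs().sum(dim=1, keepdim=True)     # (1, N)
    trans = trans_raw.abs() / trans_raw.abs().sum(dim=1, keepdim=True)  # (N, N)
    sigma = sigma_raw.abs()[0]                                          # (N,)

    # log_alpha is rebuilt on every call, so nothing from the previous
    # iteration's (already freed) graph is reused.
    log_alpha = init[0].log() + emission_logp(y, max_lag, lags, emission_parameter, sigma)
    for t in range(max_lag + 1, y.shape[0]):
        # [i, j] = log_alpha[i] + log A[i, j]; log-sum-exp over the previous state i
        log_alpha = (torch.logsumexp(log_alpha.unsqueeze(1) + trans.log(), dim=0)
                     + emission_logp(y, t, lags, emission_parameter, sigma))
    return -torch.logsumexp(log_alpha, dim=0)

torch.manual_seed(0)
N, lags = 3, torch.tensor([1, 2, 3])
y = torch.randn(100)  # placeholder for the real observations
params = [
    torch.rand(1, N, requires_grad=True),               # initial_matrix_raw
    torch.rand(N, N, requires_grad=True),               # transition_matrix_raw
    torch.randn(N, len(lags) + 1, requires_grad=True),  # emission_parameter
    torch.randn(1, N, requires_grad=True),              # emission_sigma_raw
]
optimizer = torch.optim.Adam(params, lr=1e-2)

losses = []
for _ in range(20):
    optimizer.zero_grad()
    loss = negative_log_likelihood(y, lags, *params)
    loss.backward()  # a single plain backward, no retain_graph
    optimizer.step()
    losses.append(loss.item())
```

Inside the original class, the smaller change in the same spirit would be to move the `torch.zeros((self._T, self._N))` allocation from `initialize()` to the start of `calc_log_alpha()`, so each forward pass writes into a fresh tensor.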