Error: "Trying to backward through the graph a second time (or directly access saved tensors after they have already been freed)"

I am encountering this error while implementing a paper. This is the code that I am using:

        for epoch in range(num_epochs):
            for _, (x_batch, y_batch) in enumerate(data_loader):
                optimizer.zero_grad()

                num_samples = x_batch.shape[0]
                v_batch = self.G(x_batch)  # self.G, H, F, T are torch.nn.Modules
                z_batch = self.H(v_batch)
                pred_batch = self.F(v_batch)

                p = torch.matmul(z_batch, self.c) / self.tau
                p = torch.nn.functional.softmax(p, dim=1)

                q = self.lambda_ * pred_batch + (1 - self.lambda_) * p
                q_hat = torch.zeros_like(q)
                for i in range(x_batch.shape[0]):
                    for k in range(self.num_classes):
                        indices = (y_batch[i, :] == 1).nonzero()
                        if sum(indices) != 0:
                            q_hat[indices, k] = q[indices, k] / sum(q[indices, k])

                self.update_class_prototype(z_batch, pred_batch, y_batch)

                r = self.T(y_batch)
                r = torch.mean(r, dim=0).reshape(self.r.shape)  # average over all samples in the batch; not mentioned in the paper
                self.r = self.eta * self.r + (1 - self.eta) * r

                loss_class = torch.sum(-torch.mul(q_hat, torch.log(pred_batch))) / num_samples
                if not torch.isfinite(loss_class):
                    raise (RuntimeError(f"classification loss is {loss_class.item()}"))
                loss_ml = torch.sum(-torch.mul(pred_batch, torch.log(p))) / num_samples
                if not torch.isfinite(loss_ml):
                    raise (RuntimeError(f"mutual learning loss is {loss_ml.item()}"))
                y_tilde_batch = torch.nn.functional.normalize(y_batch, dim=1)
                loss_con = torch.nn.functional.kl_div(torch.log(torch.matmul(pred_batch, self.r) + torch.tensor(1e-10)),
                                                      y_tilde_batch) + \
                           torch.nn.functional.kl_div(torch.log(y_tilde_batch + torch.tensor(1e-10)),
                                                      torch.matmul(pred_batch, self.r))
                if not torch.isfinite(loss_con):
                    raise (RuntimeError(f"consistency loss is {loss_con.item()}"))
                loss = loss_class + self.alpha * loss_ml + self.beta * loss_con

                loss.backward()
                optimizer.step()

It is not clear to me where I am calling backward a second time. Even if I change loss.backward() to loss_class.backward(), which involves only a single loss, I still get this error. I went through the full error message (pasted below) and found that the error is triggered at self.r = self.eta * self.r + (1 - self.eta) * r.

C:\ProgramData\Anaconda3\envs\mlpll\python.exe "C:\Program Files\JetBrains\PyCharm 2021.3.2\plugins\python\helpers\pydev\pydevd.py" --cmd-line --multiprocess --qt-support=auto --client 127.0.0.1 --port 53364 --file "C:/Work/mlpll/main.py"
Connected to pydev debugger (build 221.5787.24)
C:\ProgramData\Anaconda3\envs\mlpll\lib\site-packages\torch\nn\functional.py:2919: UserWarning: reduction: 'mean' divides the total loss by both the batch size and the support size.'batchmean' divides only by the batch size, and aligns with the KL div math definition.'mean' will be changed to behave the same as 'batchmean' in the next major release.
  warnings.warn(
C:\ProgramData\Anaconda3\envs\mlpll\lib\site-packages\torch\autograd\__init__.py:200: UserWarning: Error detected in MulBackward0. Traceback of forward call that caused the error:
  File "C:\Program Files\JetBrains\PyCharm 2021.3.2\plugins\python\helpers\pydev\pydevd.py", line 2181, in <module>
    main()
  File "C:\Program Files\JetBrains\PyCharm 2021.3.2\plugins\python\helpers\pydev\pydevd.py", line 2172, in main
    globals = debugger.run(setup['file'], None, None, is_module)
  File "C:\Program Files\JetBrains\PyCharm 2021.3.2\plugins\python\helpers\pydev\pydevd.py", line 1484, in run
    return self._exec(is_module, entry_point_fn, module_name, file, globals, locals)
  File "C:\Program Files\JetBrains\PyCharm 2021.3.2\plugins\python\helpers\pydev\pydevd.py", line 1491, in _exec
    pydev_imports.execfile(file, globals, locals)  # execute the script
  File "C:\Program Files\JetBrains\PyCharm 2021.3.2\plugins\python\helpers\pydev\_pydev_imps\_pydev_execfile.py", line 18, in execfile
    exec(compile(contents+"\n", file, 'exec'), glob, loc)
  File "C:/Work/mlpll/main.py", line 18, in <module>
    model.train_model(data_loader=train_data_loader, num_epochs=1000, optimizer=optimizer)
  File "C:\Work\mlpll\model_management.py", line 122, in train_model
    self.r = self.eta * self.r + (1 - self.eta) * r
 (Triggered internally at C:\cb\pytorch_1000000000000\work\torch\csrc\autograd\python_anomaly_mode.cpp:119.)
  Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
Traceback (most recent call last):
  File "C:\Program Files\JetBrains\PyCharm 2021.3.2\plugins\python\helpers\pydev\pydevd.py", line 1491, in _exec
    pydev_imports.execfile(file, globals, locals)  # execute the script
  File "C:\Program Files\JetBrains\PyCharm 2021.3.2\plugins\python\helpers\pydev\_pydev_imps\_pydev_execfile.py", line 18, in execfile
    exec(compile(contents+"\n", file, 'exec'), glob, loc)
  File "C:/Work/mlpll/main.py", line 18, in <module>
    model.train_model(data_loader=train_data_loader, num_epochs=1000, optimizer=optimizer)
  File "C:\Work\mlpll\model_management.py", line 140, in train_model
    loss.backward()
  File "C:\ProgramData\Anaconda3\envs\mlpll\lib\site-packages\torch\_tensor.py", line 487, in backward
    torch.autograd.backward(
  File "C:\ProgramData\Anaconda3\envs\mlpll\lib\site-packages\torch\autograd\__init__.py", line 200, in backward
    Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
RuntimeError: Trying to backward through the graph a second time (or directly access saved tensors after they have already been freed). Saved intermediate values of the graph are freed when you call .backward() or autograd.grad(). Specify retain_graph=True if you need to backward through the graph a second time or if you need to access saved tensors after calling backward.

Process finished with exit code 1
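
For what it's worth, a stripped-down loop that only keeps a running tensor across iterations, the way self.r is kept, reproduces the same error for me (toy stand-ins below, not my actual modules):

    import torch

    # Toy stand-ins, not the modules from my model: T_toy plays the role of
    # self.T, r_running the role of self.r, and eta corresponds to self.eta.
    T_toy = torch.nn.Linear(4, 4)
    r_running = torch.zeros(4)
    eta = 0.9

    for step in range(2):
        r_new = T_toy(torch.ones(1, 4)).mean(dim=0)      # like r = self.T(y_batch) averaged over the batch
        r_running = eta * r_running + (1 - eta) * r_new  # same pattern as the self.r update
        loss = r_running.sum()
        loss.backward()  # succeeds on step 0, raises the same RuntimeError on step 1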

I am not sure if retain_graph=True is a solution here, since I am not calling the .backward() function more than once. How do I handle this error? Thanks in advance for your help.

Most likely not: retain_graph=True is rarely the right fix unless you explicitly want to keep the computation graph alive and can explain why you need it.

In that line you are folding self.r back into itself, together with the computation graph that is still attached to it from the previous iteration.
Check whether self.r = ... + self.r.detach() helps.
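
For example, on a toy running buffer (r_running standing in for self.r and T_toy for self.T; this is just a sketch, not the code from your model), detaching the previous value keeps the old, already-freed graph out of the next backward:

    import torch

    # Sketch of the detach idea on a toy buffer; r_running stands in for
    # self.r and T_toy for self.T.
    T_toy = torch.nn.Linear(4, 4)
    r_running = torch.zeros(4)
    eta = 0.9

    for step in range(2):
        r_new = T_toy(torch.ones(1, 4)).mean(dim=0)
        # Detach the previous running value so that only the current
        # iteration's graph is attached to anything the loss depends on.
        r_running = eta * r_running.detach() + (1 - eta) * r_new
        loss = r_running.sum()
        loss.backward()  # no longer walks into the freed graph from the previous step

The important part is the principle: whatever is carried over from one iteration to the next must not still require grad through the old, already-backwarded graph.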

Thanks for the suggestion. However, detaching self.r in that update (self.r = self.eta * self.r.detach() + (1 - self.eta) * r) did not help. Instead, I used a temporary variable, which resolved the issue. Thanks again for your help.
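
For reference, one way to write such a temporary-variable update so that no graph is carried into the next iteration is to compute it under torch.no_grad(), e.g. (toy names, not my exact code):

    import torch

    # One possible shape of the temporary-variable update (not the exact code
    # from my trainer); eta, r_running and r_new stand in for self.eta,
    # self.r and r.
    eta = 0.9
    r_running = torch.zeros(4)
    r_new = torch.nn.Linear(4, 4)(torch.ones(1, 4)).mean(dim=0)  # like r = self.T(y_batch) averaged over the batch

    with torch.no_grad():          # the update is computed outside the graph
        r_tmp = eta * r_running + (1 - eta) * r_new
    r_running = r_tmp              # the stored value has no history attached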