What(): Trying to backward through the graph a second time (or directly access saved tensors after they have already been freed)

From time to time during training of my model I was getting negative or very large KL values, so I use the cooper library to impose constraints on the KL terms of the components of my mixture-of-experts model. To make the usage pattern clearer, here is a self-contained toy version of the exact call sequence I use (the toy model, data and constraint are placeholders only, not my real code):
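
import torch
import torch.nn as nn
import cooper

# Placeholder CMP, only to illustrate the call pattern; my real class follows below
class ToyCMP(cooper.ConstrainedMinimizationProblem):
    def __init__(self):
        super().__init__(is_constrained=True)

    def closure(self, model, inputs):
        loss = model(inputs).mean()
        ineq_defect = loss - 1.0  # toy constraint: loss <= 1
        return cooper.CMPState(loss=loss, ineq_defect=ineq_defect)

model = nn.Linear(4, 1)
cmp = ToyCMP()
formulation = cooper.LagrangianFormulation(cmp)
primal_optimizer = cooper.optim.ExtraSGD(model.parameters(), lr=1e-2, momentum=0.7)
dual_optimizer = cooper.optim.partial_optimizer(cooper.optim.ExtraSGD, lr=1e-3, momentum=0.7)
coop = cooper.ConstrainedOptimizer(formulation=formulation,
                                   primal_optimizer=primal_optimizer,
                                   dual_optimizer=dual_optimizer)

inputs = torch.randn(8, 4)
coop.zero_grad()
lagrangian = formulation.composite_objective(cmp.closure, model, inputs)
formulation.custom_backward(lagrangian)
coop.step(cmp.closure, model, inputs)

And here is my actual constraint class: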

import torch
import torch.nn as nn

import cooper
class PenalizedKLConstraint(cooper.ConstrainedMinimizationProblem):
    """
    Class for KL-penalized constraints.
    """
    def __init__(self, module: nn.Module, l1_lambda: float = 0.0001, mean_constraint: float = 2.5, device: torch.types.Device = torch.device("cuda" if torch.cuda.is_available() else "cpu")):
        self.mean_constraint = mean_constraint
        self._module = module
        self.l1_lambda = l1_lambda
        self.device = device
        self._module.to(self.device)
        super(PenalizedKLConstraint, self).__init__(is_constrained=True)

    def compute_losses(
        self, 
        model: nn.Module,
        context: torch.Tensor, 
        another_context:torch.Tensor,
        precision: torch.Tensor,
        weights: torch.Tensor,
    ): 
        model.to(self.device)
        with torch.no_grad():
             samples = model.sample(context.to(self.device))
        losses = - torch.squeeze(self._module(torch.cat([context.to(self.device), samples], dim=-1), inTrain=False))
        kls = model.kls_other_chol_inv(context.to(self.device),  another_context.to(self.device), precision.to(self.device))
        # Adding L1 regularization to prevent gradient explosion
        l1_norm = torch.mul(self.l1_lambda, sum(p.abs().sum() for p in model.parameters())).to(self.device)

        loss = torch.mean(weights * (losses + kls)).to(self.device) + l1_norm + model._regularizer
        return loss, kls, l1_norm

    def closure(
        self,
        inputs: torch.Tensor,
        targets: torch.Tensor,
        p:torch.Tensor,
        w:torch.Tensor,
        model: nn.Module,
    ) -> cooper.CMPState:
        """
        Computes CMP state for the given model and inputs.
        """
        # Compute loss and regularization statistics
        loss, kls, l1_norm = self.compute_losses(model, inputs, targets, p, w)
        misc = {"l1_norm_value": l1_norm, "kl_value": kls}

        # Entries of kls >= 0 (equiv. -kls <= 0) and kls - constraint <= 0 (see the small check below the class)
        ineq_defect = torch.stack([-kls, kls - self.mean_constraint])

        # Store model output and other 'miscellaneous' objects in misc dict
        state = cooper.CMPState(loss=loss, ineq_defect=ineq_defect, misc=misc)

        return state
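
As a quick sanity check on how the inequality defect is built (a toy example only, not part of the training code; 2.5 is the class default mean_constraint):

import torch
kls = torch.tensor([0.5, 3.0])
print(torch.stack([-kls, kls - 2.5]))
# tensor([[-0.5000, -3.0000],
#         [-2.0000,  0.5000]])  -> only the second entry violates the kls <= mean_constraint bound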

Below is a snippet of my main class, which is very long; I am only posting the part where I initialize the model with the constraint, and the training step that uses it:

class MixtureEIM:

    @staticmethod
    def get_default_config() -> ConfigDict:
        c = ConfigDict(
            mean_constraint=10,
            checkpoint_dir=os.getcwd() + "/EIM_out_imgs/"
        )
        c.finalize_adding()
        return c

    def __init__(self, config: ConfigDict, context_dim: int, sample_dim: int, seed: int = 0):
        self.c = config
        # (most of __init__ is omitted; self._dre, self._model and device are defined there)
        self.checkpoint_dir = self.c.checkpoint_dir
        self._mean_constraint = self.c.mean_constraint

        self._cmp = PenalizedKLConstraint(self._dre, l1_lambda=0.0001, mean_constraint=self._mean_constraint, device=device)
        self._formulation = cooper.LagrangianFormulation(self._cmp)
        self._model.to(device)

    def _components_train_step(self, importance_weights, old_means, old_chol_precisions, save_model=False):

        for i in range(self._model.num_components):
            dataset = torch.utils.data.TensorDataset(self._train_contexts, importance_weights[:, i], old_means, old_chol_precisions)

            loader = torch.utils.data.DataLoader(dataset, shuffle=True, batch_size=self.c.components_batch_size)

            for batch_idx, (context_batch, iw_batch, old_means_batch, old_chol_precisions_batch) in enumerate(loader):

                iw_batch = iw_batch / torch.sum(iw_batch)

                primal_optimizer = cooper.optim.ExtraSGD(self._model.components[i].trainable_variables, lr=self.c.components_learning_rate, momentum=0.7)
                dual_optimizer = cooper.optim.partial_optimizer(cooper.optim.ExtraSGD, lr=1e-3, momentum=0.7)
                # Wrap the formulation and both optimizers inside a ConstrainedOptimizer
                coop = cooper.ConstrainedOptimizer(formulation=self._formulation,
                                                   primal_optimizer=primal_optimizer,
                                                   dual_optimizer=dual_optimizer,
                                                   )
                    
                coop.zero_grad()
                lagrangian = self._formulation.composite_objective(self._cmp.closure, context_batch, old_means_batch[:, i], old_chol_precisions_batch[:, i], iw_batch, self._model.components[i])
                self._formulation.custom_backward(lagrangian)
                coop.step(self._cmp.closure, context_batch, old_means_batch[:, i], old_chol_precisions_batch[:, i], iw_batch, self._model.components[i])
                loss = self._cmp.closure(context_batch, old_means_batch[:, i], old_chol_precisions_batch[:, i], iw_batch, self._model.components[i]).loss

When I ran the above code on a remote machine using a Slurm script, I got this error message:

  what():  Trying to backward through the graph a second time (or directly access saved tensors after they have already been freed). Saved intermediate values of the graph are freed when you call .backward() or autograd.grad(). Specify retain_graph=True if you need to backward through the graph a second time or if you need to access saved tensors after calling backward.
Exception raised from unpack at /tmp/coulombc/pytorch_build_2021-11-09_14-57-01/avx2/python3.8/pytorch/torch/csrc/autograd/saved_variable.cpp:122 (most recent call first):
frame #0: c10::Error::Error(c10::SourceLocation, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> >) + 0x55 (0x2ae6de115905 in /home/lib/python3.8/site-packages/torch/lib/libc10.so)
frame #1: c10::detail::torchCheckFail(char const*, char const*, unsigned int, char const*) + 0xd6 (0x2ae6de0f72a9 in /home/lib/python3.8/site-packages/torch/lib/libc10.so)
frame #2: torch::autograd::SavedVariable::unpack(std::shared_ptr<torch::autograd::Node>) const + 0x13d2 (0x2ae6bb4e6bb2 in /home/lib/python3.8/site-packages/torch/lib/libtorch_cpu.so)

/var/spool/slurmd/job32471479/slurm_script: line 23: 84798 Aborted                 (core dumped) python EIMn.py

I am not sure where the computational graph is being backpropagated through a second time, which is what causes this error.
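
For reference, this is the generic way to trigger exactly this error message (a minimal, cooper-free toy example, only to illustrate what the error means; it is not taken from my code):

import torch

x = torch.ones(3, requires_grad=True)
y = (x * 2).sum()

y.backward()   # frees the saved intermediate tensors of the graph
y.backward()   # raises: Trying to backward through the graph a second time ...

What I cannot see is where my training loop ends up doing the equivalent of that second backward call.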

I would appreciate it if someone could help. Thanks.
P.S. I cannot simply pass retain_graph=True, because the backward call is not defined inside PenalizedKLConstraint itself (cooper performs it through formulation.custom_backward).