CUDA Out of Memory error when accumulating loss inside a loop

My code runs into a CUDA OOM error when I accumulate a loss inside a loop. I understand that this is due to the computational graph growing with each iteration, and since the inputs are images, this is a problem. I have also added `del` statements to free memory manually, but that does not help: I still run into the CUDA OOM error within a few iterations of the loop. Is there any way to get this running, or do I need to change my approach of accumulating the loss inside a loop? I have attached the code below.

Any functions not shown are simple helpers for formatting or computing the cross-entropy; they are straightforward and do not seem to be the cause of the CUDA OOM error.

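    import torch
    import torch.nn.functional as F

    def get_dist_lp(self, batch, format_batch_func):
        batch = format_batch_func(batch)
        # Stack both segments of each pair along the batch dimension.
        obs = torch.cat((batch["obs_1"], batch["obs_2"]), dim=0)
        action = torch.cat((batch["action_1"], batch["action_2"]), dim=0)

        del batch
        obs = ...   # right-hand side elided in the original post
        dist = ...  # right-hand side elided in the original post

        del obs

        if isinstance(dist, torch.distributions.Distribution):
            lp = dist.log_prob(action)
        else:
            assert dist.shape == action.shape
            # For an independent Gaussian with unit variance, the log-prob reduces to (negative) MSE.
            lp = -torch.square(dist - action).sum(dim=-1)
        return lp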

    # For now this only works on the train data, where we get a tuple of (batch, perturbed_labels).
    def get_avg_likelihood(self, train_dataloader, format_batch_func):

        tot_cpl_ll = 0.
        tot_correct_preds = 0
        data_size = 0.
        mini_batch_cnt = 0
        for i, (batch, labels_perturbed) in enumerate(train_dataloader):
            labels_perturbed = ...  # right-hand side elided in the original post
            lp = self.get_dist_lp(batch, format_batch_func)

            # Compute the advantages.
            adv = self.alpha * lp
            segment_adv = adv.sum(dim=-1)

            adv1, adv2 = torch.chunk(segment_adv, 2, dim=0)
            # cpl_ll is of shape (num_dataset_perturbations,)
            cpl_ll, correct_preds = biased_bce_with_logits_perturbed(adv1, adv2, labels_perturbed, bias=self.contrastive_bias)
            tot_cpl_ll = tot_cpl_ll + cpl_ll
            tot_correct_preds += correct_preds
            data_size += labels_perturbed.shape[0]
            mini_batch_cnt += 1
            del batch, labels_perturbed

        avg_cpl_ll = tot_cpl_ll/data_size
        accuracy_all_ds = tot_correct_preds/data_size
        return avg_cpl_ll, accuracy_all_ds

    def CVaR(self, ll_tensor, cvar_alpha=0.95):
        sigma = ll_tensor.clone()
        cvar = torch.max(sigma - torch.mean(F.relu(sigma[:,None] - ll_tensor), dim=1)/(1-cvar_alpha))
        return cvar

    def train_step_perturbed(self, train_dataloader, format_batch_func, cvar_alpha=0.95):

        avg_ll_cpl, accuracy_all = self.get_avg_likelihood(train_dataloader, format_batch_func)
        avg_accuracy = accuracy_all.mean()
        cvar_ll = self.CVaR(avg_ll_cpl, cvar_alpha)
        loss = -cvar_ll
        del avg_ll_cpl, accuracy_all
        print(f"Loss: {loss.item():.4f} \n")
        return dict(cpl_loss=loss.item(), bc_loss=0., accuracy=avg_accuracy.item())
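
For reference, the `CVaR` method above appears to be the Rockafellar-Uryasev form of the lower-tail CVaR, evaluated with every sample value as a candidate threshold. The sketch below copies the same expression into a free function (`cvar_lower_tail` is a hypothetical name, not from the post) and checks it on a toy tensor: with `cvar_alpha=0.75` and 8 entries, (1 - alpha) * n = 2, and the result equals the mean of the 2 smallest entries, i.e. the worst-case likelihoods. This is also why the loss cannot be split into a sum of per-batch losses: it depends on the ranking of the likelihoods accumulated over the whole dataset.

```python
import torch
import torch.nn.functional as F


def cvar_lower_tail(ll_tensor: torch.Tensor, cvar_alpha: float = 0.95) -> torch.Tensor:
    """Same expression as the CVaR method in the post, as a free function.

    For each candidate threshold sigma_i = ll_i it evaluates
        sigma_i - mean_j(relu(sigma_i - ll_j)) / (1 - cvar_alpha)
    and returns the maximum over i.
    """
    sigma = ll_tensor.clone()
    # (n, 1) - (n,) broadcasts to an (n, n) matrix of sigma_i - ll_j
    shortfall = F.relu(sigma[:, None] - ll_tensor)
    return torch.max(sigma - torch.mean(shortfall, dim=1) / (1 - cvar_alpha))


ll = torch.tensor([-3.0, 0.5, -1.0, 2.0, 1.5, -0.5, 0.0, 1.0])
alpha = 0.75  # (1 - alpha) * n = 0.25 * 8 = 2 worst entries
cvar = cvar_lower_tail(ll, alpha)
worst_two_mean = torch.sort(ll).values[:2].mean()
print(cvar.item(), worst_two_mean.item())  # prints: -2.0 -2.0
```

When (1 - alpha) * n is not an integer, the maximum puts a fractional weight on the boundary entry, so the "mean of the worst k" reading is exact only in the integer case.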

Deleting the variables won't help, since the accumulated tensor still holds a reference to each iteration's output, so the computation graph won't be freed. You could compute the gradients for each loss separately and let Autograd accumulate them, or you would need to reduce the overall memory usage if you want to keep accumulating the losses together with their computation graphs.
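
A minimal sketch of that suggestion, assuming the total loss is a plain sum of per-batch losses (a toy `nn.Linear` model and random batches stand in for the real model and dataloader): calling `backward()` inside the loop lets Autograd sum the gradients into `.grad` and frees each batch's graph right away, and the resulting gradients match those from backpropagating through the accumulated loss.

```python
import torch

torch.manual_seed(0)
model = torch.nn.Linear(4, 1)                      # toy stand-in for the real model
batches = [torch.randn(8, 4) for _ in range(3)]    # toy stand-in for the dataloader

# Approach 1: accumulate the loss tensor, then call backward() once.
# Every batch's graph stays alive until backward() runs.
model.zero_grad()
total = 0.0
for x in batches:
    total = total + model(x).pow(2).mean()
total.backward()
grads_accumulated_loss = [p.grad.clone() for p in model.parameters()]

# Approach 2: call backward() per batch; Autograd sums into .grad
# and each batch's graph is freed right after its backward().
model.zero_grad()
for x in batches:
    loss = model(x).pow(2).mean()
    loss.backward()
grads_per_batch = [p.grad.clone() for p in model.parameters()]

for a, b in zip(grads_accumulated_loss, grads_per_batch):
    assert torch.allclose(a, b, atol=1e-6)
print("gradients match")
```

This equivalence relies on the gradient being linear in the per-batch losses; it does not carry over directly to a non-linear reduction such as a CVaR over the accumulated likelihoods.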

Thanks for getting back so quickly 🙂

The problem is that I compute the loss in the CVaR method only after accumulating the "likelihood" over the entire dataset, which requires a full pass through the dataset (i.e., the loop). So I can only call `backward()` after the CVaR method, which means accumulating gradients would not work, I guess?
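
Accumulating gradients can still work for a non-linear loss if the dataset is passed over twice, using the chain rule. The loss depends on the model only through `avg_cpl_ll = sum_b cpl_ll_b / data_size`, so its gradient is `sum_b J_bᵀ g / data_size`, where `g` is the gradient of the loss with respect to `avg_cpl_ll` and `J_b` is the Jacobian of batch b's `cpl_ll`. The sketch below is a swapped-in technique, not something proposed in the thread; `batch_ll`, the toy model and `NUM_PERTURBATIONS` are hypothetical stand-ins for `get_dist_lp` plus `biased_bce_with_logits_perturbed`. It computes the totals under `torch.no_grad()`, gets `g` from the CVaR on a detached leaf, then recomputes each batch with a graph and calls `backward(gradient=...)`, so only one batch's graph is alive at a time. It then checks the result against backpropagating through the full graph.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
NUM_PERTURBATIONS = 5  # hypothetical stand-in for num_dataset_perturbations

model = torch.nn.Linear(4, NUM_PERTURBATIONS)     # toy stand-in for the real model
batches = [torch.randn(8, 4) for _ in range(3)]   # toy stand-in for the dataloader


def batch_ll(x):
    # Hypothetical stand-in for get_dist_lp + biased_bce_with_logits_perturbed:
    # a per-perturbation log-likelihood summed over the batch, shape (NUM_PERTURBATIONS,)
    return F.logsigmoid(model(x)).sum(dim=0)


def cvar(ll_tensor, cvar_alpha=0.95):
    # Same expression as the CVaR method in the post
    sigma = ll_tensor.clone()
    return torch.max(sigma - torch.mean(F.relu(sigma[:, None] - ll_tensor), dim=1) / (1 - cvar_alpha))


# Pass 1: no graph is kept, so memory stays at roughly one batch.
with torch.no_grad():
    tot_ll = sum(batch_ll(x) for x in batches)
    data_size = sum(x.shape[0] for x in batches)

# Differentiate the CVaR loss w.r.t. the detached average likelihood only.
avg_ll = (tot_ll / data_size).requires_grad_(True)
loss = -cvar(avg_ll, cvar_alpha=0.6)
(g,) = torch.autograd.grad(loss, avg_ll)  # shape (NUM_PERTURBATIONS,)

# Pass 2: recompute each batch with a graph and backpropagate g / data_size.
model.zero_grad()
for x in batches:
    # Vector-Jacobian product; this batch's graph is freed afterwards.
    batch_ll(x).backward(gradient=g / data_size)
two_pass_grads = [p.grad.clone() for p in model.parameters()]

# Reference: keep every batch's graph and backpropagate once (what runs out of memory at scale).
model.zero_grad()
avg_ll_full = sum(batch_ll(x) for x in batches) / data_size
(-cvar(avg_ll_full, cvar_alpha=0.6)).backward()
full_graph_grads = [p.grad.clone() for p in model.parameters()]

for a, b in zip(two_pass_grads, full_graph_grads):
    assert torch.allclose(a, b, atol=1e-5)
print("two-pass gradients match the full-graph gradients")
```

The cost is roughly one extra forward pass over the dataset. Both passes must produce identical outputs for the same inputs, so dropout, random augmentation and batch-norm behaviour would need to be controlled (for example the same batches and fixed seeds) for the gradient to be exact.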

Also, if I do want to accumulate the gradients during the loop, I would need to compute a Jacobian at each iteration and sum them up, then compute a vector-Jacobian product afterwards to get the gradients of the CVaR loss with respect to the model parameters. Could this approach save some memory?
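
The accumulated-Jacobian idea is mathematically sound, since `sum_b J_bᵀ g = (sum_b J_b)ᵀ g`, but the summed Jacobian has shape (num_perturbations × num_parameters), i.e. num_perturbations times the parameter count, and building it typically takes one backward pass per output entry. `backward(gradient=g)` computes the same vector-Jacobian product without ever materializing `J`. The toy check below (a small `tanh` function of a parameter vector is a hypothetical stand-in for the model) compares the two using `torch.autograd.functional.jacobian`.

```python
import torch
from torch.autograd.functional import jacobian

torch.manual_seed(0)
theta = torch.randn(6)   # toy "model parameters"
x = torch.randn(4, 6)    # toy batch
g = torch.randn(4)       # upstream gradient dloss/d(output)


def batch_output(params):
    # Toy stand-in for a batch's per-perturbation likelihood vector, shape (4,)
    return torch.tanh(x @ params)


# Option 1: materialize the full Jacobian (4 x 6 here; P x num_params in general).
J = jacobian(batch_output, theta)
vjp_from_jacobian = J.T @ g

# Option 2: vector-Jacobian product via backward(gradient=g); J is never built.
theta_leaf = theta.clone().requires_grad_(True)
batch_output(theta_leaf).backward(gradient=g)
vjp_from_backward = theta_leaf.grad

assert torch.allclose(vjp_from_jacobian, vjp_from_backward, atol=1e-6)
print("J^T g matches backward(gradient=g)")
```

So storing the Jacobian trades memory (num_perturbations copies of the parameters) for avoiding a second forward pass, whereas the vector-Jacobian product needs only parameter-sized memory once `g` is known.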