Strange Appearance of CUDA out of memory error

I want to add a metric called Expected Calibration Error (ECE) to my code to measure how well-calibrated my model is (in a continual learning setup). I am using a library called torchmetrics that provides commonly used deep learning metrics implemented in PyTorch. My repo contains several methods, and strangely, when I add the ECE module to the train loop of one of them, I get an error. This did not happen when I added the same module to another method of the code. This is my code:

First, this is the repository that I am using: GitHub - clovaai/rainbow-memory: Official pytorch implementation of Rainbow Memory (CVPR 2021),
and I am facing problems when using the ewc method (methods/regularization.py).

I imported the package first:

from torchmetrics.classification import MulticlassCalibrationError

then in the main.py, inside the loop over epochs, added:

MCCE = MulticlassCalibrationError(num_classes=method.classifier_size, n_bins=method.classifier_size, norm="l1")

and

task_acc, eval_dict = method.train(
                cur_iter=cur_iter,
                n_epoch=args.n_epoch,
                batch_size=args.batchsize,
                n_worker=args.n_worker,
                writer=writer,
                MCCE=MCCE
            )
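
Just for context, the way I intend to use the metric is the usual stateful torchmetrics pattern: every call on a batch accumulates state, compute() aggregates over everything seen since the last reset, and reset() clears the state for the next epoch. Here is a minimal, self-contained sketch with made-up sizes (num_classes=10, n_bins=15, batch size 8 are placeholders, not code from the repo):

import torch
from torchmetrics.classification import MulticlassCalibrationError

num_classes = 10  # placeholder, stands in for method.classifier_size
metric = MulticlassCalibrationError(num_classes=num_classes, n_bins=15, norm="l1")

for _ in range(3):  # pretend these are training batches
    logits = torch.randn(8, num_classes)          # raw model outputs
    target = torch.randint(0, num_classes, (8,))
    batch_ece = metric(logits, target)            # updates internal state, returns the batch value

epoch_ece = metric.compute()                      # ECE over all batches seen since the last reset
metric.reset()                                    # clear state before the next epoch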

inside regularization.py, in the train() function, I added:

            train_loss, train_acc, train_ECE = self._train(
                train_loader=train_loader,
                optimizer=self.optimizer,
                epoch=epoch,
                total_epochs=n_epoch,
                MCCE=MCCE
            )

            MCCE.reset()

and finally, _train() is:

    def _train(self, train_loader, optimizer, epoch, total_epochs, MCCE):
        total_loss, correct, num_data = 0.0, 0.0, 0.0
        self.model.train()
        for i, data in enumerate(train_loader):
            x = data["image"]
            y = data["label"]
            x = x.to(self.device)
            y = y.to(self.device)

            optimizer.zero_grad()

            do_cutmix = self.cutmix and np.random.rand(1) < 0.5
            if do_cutmix:
                x, labels_a, labels_b, lam = cutmix_data(x=x, y=y, alpha=1.0)
                logit = self.model(x)
                loss = lam * self.criterion(logit, labels_a) + (
                    1 - lam
                ) * self.criterion(logit, labels_b)
            else:
                logit = self.model(x)
                loss = self.criterion(logit, y)

            reg_loss = self.regularization_loss()

            loss += reg_loss
            loss.backward(retain_graph=True)
            optimizer.step()

            ece = MCCE(logit, y)
            _, preds = logit.topk(self.topk, 1, True, True)
            total_loss += loss.item()
            correct += torch.sum(preds == y.unsqueeze(1)).item()
            num_data += y.size(0)

        n_batches = len(train_loader)
        train_ECE = MCCE.compute().item()
        return total_loss / n_batches, correct / num_data, train_ECE

It is inside this function that I get the error and where I noticed the strange behavior. This is the error:

Traceback (most recent call last):
  File "/visinf/home/shamidi/Projects/rainbow-memory/main.py", line 222, in <module>
    main()
  File "/visinf/home/shamidi/Projects/rainbow-memory/main.py", line 147, in main
    task_acc, eval_dict = method.train(
  File "/visinf/home/shamidi/Projects/rainbow-memory/methods/regularization.py", line 103, in train
    train_loss, train_acc = self._train(
  File "/visinf/home/shamidi/Projects/rainbow-memory/methods/regularization.py", line 181, in _train
    logit = self.model(x)
  File "/visinf/home/shamidi/anaconda3_new/envs/first_env/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1130, in _call_impl
    return forward_call(*input, **kwargs)
  File "/visinf/home/shamidi/Projects/rainbow-memory/models/cifar.py", line 212, in forward
    out = self.group4(out)
  File "/visinf/home/shamidi/anaconda3_new/envs/first_env/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1130, in _call_impl
    return forward_call(*input, **kwargs)
  File "/visinf/home/shamidi/Projects/rainbow-memory/models/cifar.py", line 117, in forward
    return self.blocks(x)
  File "/visinf/home/shamidi/anaconda3_new/envs/first_env/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1130, in _call_impl
    return forward_call(*input, **kwargs)
  File "/visinf/home/shamidi/anaconda3_new/envs/first_env/lib/python3.9/site-packages/torch/nn/modules/container.py", line 139, in forward
    input = module(input)
  File "/visinf/home/shamidi/anaconda3_new/envs/first_env/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1130, in _call_impl
    return forward_call(*input, **kwargs)
  File "/visinf/home/shamidi/Projects/rainbow-memory/models/cifar.py", line 34, in forward
    _out = self.conv1(x)
  File "/visinf/home/shamidi/anaconda3_new/envs/first_env/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1130, in _call_impl
    return forward_call(*input, **kwargs)
  File "/visinf/home/shamidi/Projects/rainbow-memory/models/layers.py", line 56, in forward
    return self.block.forward(input)
  File "/visinf/home/shamidi/anaconda3_new/envs/first_env/lib/python3.9/site-packages/torch/nn/modules/container.py", line 139, in forward
    input = module(input)
  File "/visinf/home/shamidi/anaconda3_new/envs/first_env/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1130, in _call_impl
    return forward_call(*input, **kwargs)
  File "/visinf/home/shamidi/anaconda3_new/envs/first_env/lib/python3.9/site-packages/torch/nn/modules/batchnorm.py", line 168, in forward
    return F.batch_norm(
  File "/visinf/home/shamidi/anaconda3_new/envs/first_env/lib/python3.9/site-packages/torch/nn/functional.py", line 2438, in batch_norm
    return torch.batch_norm(
RuntimeError: CUDA out of memory. Tried to allocate 2.00 MiB (GPU 0; 11.91 GiB total capacity; 11.17 GiB already allocated; 2.31 MiB free; 11.18 GiB reserved in total by PyTorch) If reserved memory is >> allocated memory try setting max_split_size_mb to avoid fragmentation.  See documentation for Memory Management and PYTORCH_CUDA_ALLOC_CONF

When I remove the line ece = MCCE(logit, y), the error disappears. I have no idea how this line and the memory error are related! I would appreciate it if anyone has a hint or has experienced something similar.

I don’t see how the ece computation itself would influence the memory usage, but you are increasing the memory usage in each iteration by keeping the computation graph alive:

loss += reg_loss
loss.backward(retain_graph=True)

Using retain_graph=True prevents PyTorch from freeing the intermediate forward activations that were needed to compute the gradients, and by accumulating the loss you are also extending the computation graph, which keeps the activations from previous iterations alive.
Could you explain why you are accumulating the losses in-place?
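
If nothing in the method actually needs to backpropagate through the same graph a second time, a rough (untested) adaptation of your loop along these lines would let PyTorch free the activations after every iteration; the cutmix branch is left out for brevity, and detaching the logits before the metric update is just an extra precaution:

        for i, data in enumerate(train_loader):
            x = data["image"].to(self.device)
            y = data["label"].to(self.device)

            optimizer.zero_grad()
            logit = self.model(x)
            loss = self.criterion(logit, y) + self.regularization_loss()

            loss.backward()        # no retain_graph=True, so the graph is freed here
            optimizer.step()

            MCCE.update(logit.detach(), y)   # update the metric state without touching the graph

            total_loss += loss.item()        # .item() returns a Python float, nothing autograd-related is kept

MCCE.update() only accumulates the internal state; if you also want the per-batch value, calling MCCE(logit.detach(), y) works as well.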