CUDA OOM error when no annotations

Hello,

I am using a network for object detection (RetinaNet), and after some iterations (not always the same number) I end up with an OOM error.

Here is the Traceback:

Traceback (most recent call last):
  File "train_uda.py", line 427, in <module>
    main()
  File "train_uda.py", line 170, in main
    train_loss = train(dataloader_train, dataloader_uda, retinanet, optimizer, writer, epoch_num, train_hist, start_uda)
  File "train_uda.py", line 250, in train
    uda_loss = uda(model, dataloader_uda, uda_iter, iter_u, writer, (iter_num + len(dataloader_train) * epoch))
  File "train_uda.py", line 400, in uda
    classification_loss, regression_loss = model([unlabeled['augment'].cuda().float(), ground_truth])
  File "/home/rvandeghen/miniconda3/envs/tfe/lib/python3.6/site-packages/torch/nn/modules/module.py", line 532, in __call__
    result = self.forward(*input, **kwargs)
  File "/home/rvandeghen/miniconda3/envs/tfe/lib/python3.6/site-packages/torch/nn/parallel/data_parallel.py", line 150, in forward
    return self.module(*inputs[0], **kwargs[0])
  File "/home/rvandeghen/miniconda3/envs/tfe/lib/python3.6/site-packages/torch/nn/modules/module.py", line 532, in __call__
    result = self.forward(*input, **kwargs)
  File "/home/rvandeghen/TFE/Codes/retinanet/retinanet/model.py", line 268, in forward
    return self.focalLoss(classification, regression, anchors, annotations)
  File "/home/rvandeghen/miniconda3/envs/tfe/lib/python3.6/site-packages/torch/nn/modules/module.py", line 532, in __call__
    result = self.forward(*input, **kwargs)
  File "/home/rvandeghen/TFE/Codes/retinanet/retinanet/losses.py", line 91, in forward
    IoU = calc_iou(anchors[0, :, :], bbox_annotation[:, :4]) # num_anchors x num_annotations
  File "/home/rvandeghen/TFE/Codes/retinanet/retinanet/losses.py", line 16, in calc_iou
    ua = torch.unsqueeze((a[:, 2] - a[:, 0]) * (a[:, 3] - a[:, 1]), dim=1) + area - iw * ih
RuntimeError: CUDA out of memory. Tried to allocate 1.76 GiB (GPU 0; 10.76 GiB total capacity; 8.95 GiB already allocated; 1017.56 MiB free; 8.98 GiB reserved in total by PyTorch)

And the code where it happens:

def calc_iou(a, b):
    area = (b[:, 2] - b[:, 0]) * (b[:, 3] - b[:, 1])

    iw = torch.min(torch.unsqueeze(a[:, 2], dim=1), b[:, 2]) - torch.max(torch.unsqueeze(a[:, 0], 1), b[:, 0])
    ih = torch.min(torch.unsqueeze(a[:, 3], dim=1), b[:, 3]) - torch.max(torch.unsqueeze(a[:, 1], 1), b[:, 1])

    iw = torch.clamp(iw, min=0)
    ih = torch.clamp(ih, min=0)

    ua = torch.unsqueeze((a[:, 2] - a[:, 0]) * (a[:, 3] - a[:, 1]), dim=1) + area - iw * ih
    ua = torch.clamp(ua, min=1e-8)

    intersection = iw * ih

    IoU = intersection / ua

    return IoU
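
For scale, every intermediate in calc_iou (iw, ih, iw * ih, ua) is a num_anchors x num_annotations float32 matrix, so a single line can ask for gigabytes. A rough back-of-envelope with made-up counts (the real numbers depend on my images):

import torch

num_anchors = 100_000      # hypothetical: a typical RetinaNet anchor count for large inputs
num_annotations = 4_500    # hypothetical: a densely annotated image
bytes_per_matrix = num_anchors * num_annotations * 4  # float32
print(f"{bytes_per_matrix / 1024**3:.2f} GiB per intermediate")  # ~1.68 GiB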


I have tried many things to solve it, but nothing works.

Thank you!

PS: this only happens when there are some images without annotations in my dataset.

Here is some code that illustrates my training loop:

for iter1, data1 in enumerate(dataloader1):  # batch size of dataloader1 is small because no DataParallel
    optimizer.zero_grad()
    class_loss1, reg_loss1 = model(input1, output1)  # model is retinanet
    total1 = class_loss1 + reg_loss1
    total1.backward()
    writer.add_scalars('Loss', {'Loss1': total1.item()}, iter1)
    del total1
    ...
    acc_loss = 0
    for iter2, data2 in enumerate(dataloader2):  # len(dl2) >> len(dl1) -> batch size = 1 and we accumulate the gradient manually
        with torch.no_grad():
            scores, classification, predicted = model(input2)  # inference
        class_loss2, reg_loss2 = model(input2, output2)  # model is retinanet
        total2 = class_loss2 + reg_loss2
        total2.backward()
        acc_loss += total2.detach()
        del total2

    writer.add_scalars('Loss', {'Loss2': acc_loss}, iter1)
    optimizer.step()

So there are two training procedures per iteration. Since dataloader2 cannot fit into a single batch, I backpropagate each batch manually and only perform the optimizer step after the full pass over dataloader2.

If I remove the loop over dataloader2, everything works fine.

@ptrblck can you check this please?

Could you try to wrap the dataloader2 loop into a separate function?
Since Python uses function scoping, all references will be deleted once you exit the function.
Would it be possible to move the dataloader2 loop outside of the dataloader1 loop?
This would allow you to write both loops into functions, which might reduce unnecessary storage of tensors.
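
A minimal sketch of the function-scoping idea (toy model and hypothetical names, not your actual training code):

import torch
import torch.nn as nn

def inner_step(model, batch):
    # all locals (the output, the loss, and the autograd graph they
    # keep alive) are released when this function returns
    loss = model(batch).sum()
    loss.backward()
    return loss.item()  # hand back a plain float, not a tensor

model = nn.Linear(4, 2)
for _ in range(3):
    print(inner_step(model, torch.randn(8, 4)))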

Actually, this is what I already do; I only simplified the code above for readability. The real structure looks like this:

for epoch in epochs:
    train1(dataloader1, dataloader2, model, optimizer, writer, epoch)
    eval(...)
    metric(...)

def train1(dataloader1, dataloader2, model, optimizer, writer, epoch):
    for iter1, data1 in enumerate(dataloader1):  # batch size of dataloader1 is small because no DataParallel
        optimizer.zero_grad()
        class_loss1, reg_loss1 = model(input1, output1)  # model is retinanet
        total1 = class_loss1 + reg_loss1
        total1.backward()
        writer.add_scalars('Loss', {'Loss1': total1.item()}, iter1)
        del total1
        if do_train2:
            train2(dataloader2, model, writer, iter1)
        optimizer.step()

def train2(dataloader2, model, writer, iter1):
    acc_loss = 0
    for iter2, data2 in enumerate(dataloader2):  # len(dl2) >> len(dl1) -> batch size = 1 and we accumulate the gradient manually
        with torch.no_grad():
            scores, classification, predicted = model(input2)  # inference
        class_loss2, reg_loss2 = model(input2, output2)  # model is retinanet
        total2 = class_loss2 + reg_loss2
        total2.backward()
        acc_loss += total2.detach()
        del total2

    writer.add_scalars('Loss', {'Loss2': acc_loss}, iter1)

Thanks for the update.
So back to the first question. You mentioned that you only run out of memory if no annotations are found for the current image? I assume this is a property of some images?
If that's the case, let's take a look at your FocalLoss implementation and search for differences between the code paths for images with and without annotations.

Here is my FocalLoss. I assume the problem appears only when there are no annotations, but I cannot be 100% sure: dataloader1 has no images without annotations, while dataloader2 does. I already tried del regression at the end of the branch that handles missing annotations, but it did not help.

import numpy as np
import torch
import torch.nn as nn

import sys

def calc_iou(a, b):
    area = (b[:, 2] - b[:, 0]) * (b[:, 3] - b[:, 1])

    iw = torch.min(torch.unsqueeze(a[:, 2], dim=1), b[:, 2]) - torch.max(torch.unsqueeze(a[:, 0], 1), b[:, 0])
    ih = torch.min(torch.unsqueeze(a[:, 3], dim=1), b[:, 3]) - torch.max(torch.unsqueeze(a[:, 1], 1), b[:, 1])

    iw = torch.clamp(iw, min=0)
    ih = torch.clamp(ih, min=0)

    ua = torch.unsqueeze((a[:, 2] - a[:, 0]) * (a[:, 3] - a[:, 1]), dim=1) + area - iw * ih

    ua = torch.clamp(ua, min=1e-8)

    intersection = iw * ih

    IoU = intersection / ua

    return IoU

class FocalLoss(nn.Module):
    #def __init__(self):

    def forward(self, classifications, regressions, anchors, annotations):
        alpha = 0.25
        gamma = 2.0
        batch_size = classifications.shape[0]
        classification_losses = []
        regression_losses = []

        anchor = anchors[0, :, :]

        anchor_widths  = anchor[:, 2] - anchor[:, 0]
        anchor_heights = anchor[:, 3] - anchor[:, 1]
        anchor_ctr_x   = anchor[:, 0] + 0.5 * anchor_widths
        anchor_ctr_y   = anchor[:, 1] + 0.5 * anchor_heights

        for j in range(batch_size):

            classification = classifications[j, :, :]
            regression = regressions[j, :, :]

            bbox_annotation = annotations[j, :, :]
            bbox_annotation = bbox_annotation[bbox_annotation[:, 4] != -1]

            # classification in [1e-4, 1-1e-4]
            classification = torch.clamp(classification, 1e-4, 1.0 - 1e-4)

            # if no annotation
            if bbox_annotation.shape[0] == 0:
                if torch.cuda.is_available():

                    alpha_factor = torch.ones(classification.shape).cuda() * alpha

                    alpha_factor = 1. - alpha_factor
                    focal_weight = classification
                    focal_weight = alpha_factor * torch.pow(focal_weight, gamma)

                    bce = -(torch.log(1.0 - classification))

                    cls_loss = focal_weight * bce

                    regression_losses.append(torch.tensor(0).float().cuda())
                    classification_losses.append(cls_loss.sum())

                else:
                    alpha_factor = torch.ones(classification.shape) * alpha

                    alpha_factor = 1. - alpha_factor
                    focal_weight = classification
                    focal_weight = alpha_factor * torch.pow(focal_weight, gamma)

                    bce = -(torch.log(1.0 - classification))

                    cls_loss = focal_weight * bce

                    regression_losses.append(torch.tensor(0).float())
                    classification_losses.append(cls_loss.sum())

                continue

            IoU = calc_iou(anchors[0, :, :], bbox_annotation[:, :4]) # num_anchors x num_annotations

            IoU_max, IoU_argmax = torch.max(IoU, dim=1) # num_anchors x 1

            #import pdb
            #pdb.set_trace()

            # compute the loss for classification
            targets = torch.ones(classification.shape) * -1
            
            if torch.cuda.is_available():
                targets = targets.cuda()

            # targets < 0.4 set to 0
            targets[torch.lt(IoU_max, 0.4), :] = 0

            # print(targets.shape[0] - (targets==0).sum())

            # indices = IoU > 0.5
            positive_indices = torch.ge(IoU_max, 0.5)

            num_positive_anchors = positive_indices.sum()

            assigned_annotations = bbox_annotation[IoU_argmax, :]

            targets[positive_indices, :] = 0
            targets[positive_indices, assigned_annotations[positive_indices, 4].long()] = 1

            if torch.cuda.is_available():
                alpha_factor = torch.ones(targets.shape).cuda() * alpha
            else:
                alpha_factor = torch.ones(targets.shape) * alpha

            alpha_factor = torch.where(torch.eq(targets, 1.), alpha_factor, 1. - alpha_factor)
            focal_weight = torch.where(torch.eq(targets, 1.), 1. - classification, classification)
            focal_weight = alpha_factor * torch.pow(focal_weight, gamma)

            bce = -(targets * torch.log(classification) + (1.0 - targets) * torch.log(1.0 - classification))

            # cls_loss = focal_weight * torch.pow(bce, gamma)
            cls_loss = focal_weight * bce

            if torch.cuda.is_available():
                cls_loss = torch.where(torch.ne(targets, -1.0), cls_loss, torch.zeros(cls_loss.shape).cuda())
            else:
                cls_loss = torch.where(torch.ne(targets, -1.0), cls_loss, torch.zeros(cls_loss.shape))

            classification_losses.append(cls_loss.sum()/torch.clamp(num_positive_anchors.float(), min=1.0))

            # compute the loss for regression

            if positive_indices.sum() > 0:
                assigned_annotations = assigned_annotations[positive_indices, :]

                anchor_widths_pi = anchor_widths[positive_indices]
                anchor_heights_pi = anchor_heights[positive_indices]
                anchor_ctr_x_pi = anchor_ctr_x[positive_indices]
                anchor_ctr_y_pi = anchor_ctr_y[positive_indices]

                gt_widths  = assigned_annotations[:, 2] - assigned_annotations[:, 0]
                gt_heights = assigned_annotations[:, 3] - assigned_annotations[:, 1]
                gt_ctr_x   = assigned_annotations[:, 0] + 0.5 * gt_widths
                gt_ctr_y   = assigned_annotations[:, 1] + 0.5 * gt_heights

                # clip widths to 1
                gt_widths  = torch.clamp(gt_widths, min=1)
                gt_heights = torch.clamp(gt_heights, min=1)

                targets_dx = (gt_ctr_x - anchor_ctr_x_pi) / anchor_widths_pi
                targets_dy = (gt_ctr_y - anchor_ctr_y_pi) / anchor_heights_pi
                targets_dw = torch.log(gt_widths / anchor_widths_pi)
                targets_dh = torch.log(gt_heights / anchor_heights_pi)
                targets = torch.stack((targets_dx, targets_dy, targets_dw, targets_dh))
                targets = targets.t()


                if torch.cuda.is_available():
                    targets = targets/torch.Tensor([[0.1, 0.1, 0.2, 0.2]]).cuda()
                else:
                    targets = targets/torch.Tensor([[0.1, 0.1, 0.2, 0.2]])

                negative_indices = 1 + (~positive_indices)

                regression_diff = torch.abs(targets - regression[positive_indices, :])

                regression_loss = torch.where(
                    torch.le(regression_diff, 1.0 / 9.0),
                    0.5 * 9.0 * torch.pow(regression_diff, 2),
                    regression_diff - 0.5 / 9.0
                )
                regression_losses.append(regression_loss.mean())
            else:
                if torch.cuda.is_available():
                    regression_losses.append(torch.tensor(0).float().cuda())
                else:
                    regression_losses.append(torch.tensor(0).float())

        return torch.stack(classification_losses).mean(dim=0, keepdim=True), torch.stack(regression_losses).mean(dim=0, keepdim=True)

PS: @ptrblck this happens when there are no annotations in the ground truth, not when the model finds no annotations, as you mentioned.

PS2: in the implementation, the original author uses torch.nn.DataParallel even though I only have one GPU. Does that make a difference?

Could you add some fake annotations for dataloader2, just to potentially exclude the lack of target annotations as a root cause?

If you are passing a single GPU id (or only one GPU is found), nn.DataParallel will just use the single GPU as seen in these lines of code.
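
A minimal sketch of that single-GPU fallback (device availability assumed):

import torch
import torch.nn as nn

model = nn.Linear(4, 2)
if torch.cuda.is_available():
    model = model.cuda()
    # with a device_ids list of length 1, DataParallel.forward() skips
    # scatter/replicate/gather and simply calls self.module(*inputs[0], **kwargs[0]),
    # which matches the data_parallel.py line in your traceback
    model = nn.DataParallel(model, device_ids=[0])
    out = model(torch.randn(8, 4).cuda())
    print(out.shape)  # torch.Size([8, 2])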

Even if it is not the proper way to do it, I will check that output2 is not empty before computing the loss:
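
Something like this hypothetical guard, in the notation of my training sketch above (output2 stands for the annotation tensor of the current sample):

if output2.shape[0] > 0:  # skip samples that carry no ground-truth boxes
    class_loss2, reg_loss2 = model(input2, output2)
    total2 = class_loss2 + reg_loss2
    total2.backward()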

That way I am sure there is always at least one annotation, but I lose the fact that the focal loss should still compute the log-loss where every anchor is considered background.

I guess that accumulating the gradients manually in train2 should not be a problem, right?

Gradient accumulation shouldn’t be a problem, as no new .grad attributes will be created and the already created ones will be used.
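
As a minimal sketch (toy model, not your RetinaNet): repeated backward() calls sum into the existing .grad buffers, and a single step() then consumes the accumulated gradients.

import torch
import torch.nn as nn

model = nn.Linear(2, 1)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

optimizer.zero_grad()
for _ in range(4):
    loss = model(torch.randn(8, 2)).sum()
    loss.backward()            # accumulates into the existing .grad tensors
print(model.weight.grad)       # sum of the four per-batch gradients
optimizer.step()               # one update from the accumulated gradients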

Did the check for empty annotations work, i.e. are you still running out of memory?

I started an experiment yesterday; it is still running and has never gotten this far before.
If the training finishes, I will mark this as the solution.

I guess I will never know why it runs OOM when there are no annotations.

Thank you for the help!