Setting requires_grad_(True) for MT-Loss weights causes CUDA out of memory error

Hello,

I have a working model that trains without any problems. I am using PyTorch 1.5.1 on a GPU with CUDA 10.2. Then I wanted to add a multi-task loss, which is causing problems. I can train the model for ~3500 (+/- 400) iterations, and then I get a CUDA out of memory error. Before this, the cached memory is around 8 GB, and I use a 2080 Ti with 11 GB of VRAM. I cannot see a memory leak until then; the cached memory seems constant. At ~3500 iterations the memory seems to spike in the forward pass of the network.

The error only appears when I call requires_grad_(True) on my multi-task loss weights (see code below). I have searched the internet for answers, but most of them boil down to people not using .item() when storing tensors in their statistics, or keeping references to them in some other way, and I don't think I do this.
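The usual culprit from those threads can be checked directly. A loss tensor stored in a statistics list keeps its `grad_fn` (and therefore its autograd graph) reachable, so tensors saved for backward can stay alive longer than intended, while `.item()` returns a plain Python float. A minimal CPU sketch; `record`, `leaky` and `safe` are illustrative names, not from the original code:

```python
import torch

def record(stats, name, value):
    # Store a plain Python float, not the tensor itself: .item() copies the
    # scalar to the host and drops any link to the autograd graph.
    stats.append({"name": name, "val": value.item()})

w = torch.tensor(1.0, requires_grad=True)
x = torch.randn(1000)
loss = (x * w).pow(2).mean()

leaky = [{"name": "loss", "val": loss}]  # keeps loss.grad_fn (and its graph) reachable
safe = []
record(safe, "loss", loss)

print(type(leaky[0]["val"]).__name__, leaky[0]["val"].grad_fn is not None)  # Tensor True
print(type(safe[0]["val"]).__name__)  # float
```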

iter: 3850, MT (d_boxes-s: 1.0000, d_cls-adjusted: 1.0389, d_cls-s: 1.1078, d_reg-adjusted: -1.4250, d_reg-s: -3.2818), acc (bg: 1.00, fg: 0.84, iou: 0.78), loss (bbox_2d: 0.5207, bbox_3d: 1.0714, cls: 0.1337, depth_cls: 1.4683, depth_reg: 0.0163), misc (ry: 1.17, z: 1.33), dt: 0.81, eta: 44.1h
Traceback (most recent call last):
  File "scripts/train.py", line 388, in <module>
    main(sys.argv[1:])
  File "scripts/train.py", line 220, in main
    cls, prob, bbox_2d, bbox_3d, feat_size, x_d_cls, x_d_reg = rpn_net(images.cuda())
  File "/disk/no_backup/---/D4LCN/env/lib/python3.6/site-packages/torch/nn/modules/module.py", line 550, in __call__
    result = self.forward(*input, **kwargs)
  File "/disk/no_backup/---/D4LCN/env/lib/python3.6/site-packages/torch/nn/parallel/data_parallel.py", line 153, in forward
    return self.module(*inputs[0], **kwargs[0])
  File "/disk/no_backup/---/D4LCN/env/lib/python3.6/site-packages/torch/nn/modules/module.py", line 550, in __call__
    result = self.forward(*input, **kwargs)
  File "/disk/vanishing_data/---/D4LDCoutput/output/learn_depth_config/Adaptive_block2_resnet_dilate_depth50_batch2_dropoutearly0_5_lr0_005_onecycle_iter200000_2020_07_29_05_01_16/resnet_dilate_depth.py", line 168, in forward
    x_d_cls = F.interpolate(x_d_cls, size=inp_shape, mode="bilinear", align_corners=True)
  File "/disk/no_backup/---/D4LCN/env/lib/python3.6/site-packages/torch/nn/functional.py", line 3013, in interpolate
    scale_factor_list[0], scale_factor_list[1])
RuntimeError: CUDA out of memory. Tried to allocate 144.00 MiB (GPU 0; 10.76 GiB total capacity; 9.54 GiB already allocated; 51.00 MiB free; 9.72 GiB reserved in total by PyTorch)
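Since PyTorch 1.4, the "cached" memory is reported as "reserved" memory (as in the error message above). Reserved memory is the caching allocator's pool and can look flat while the memory actually allocated inside it grows, so logging allocated and peak-allocated memory as well may show whether the spike is a slow leak or a single unusually large iteration. A small sketch; `cuda_memory_report` is a hypothetical helper, and the logging interval in the usage comment is arbitrary:

```python
import torch

def cuda_memory_report(device=None):
    """Return current and peak CUDA memory in MiB, or an empty dict without a GPU."""
    if not torch.cuda.is_available():
        return {}
    mib = 1024 ** 2
    return {
        "allocated": torch.cuda.memory_allocated(device) / mib,
        "reserved": torch.cuda.memory_reserved(device) / mib,
        "max_allocated": torch.cuda.max_memory_allocated(device) / mib,
    }

# Hypothetical usage inside the training loop: log every 50 iterations and
# reset the peak counter so each window's maximum is visible.
# if iteration % 50 == 0:
#     print(iteration, cuda_memory_report())
#     torch.cuda.reset_peak_memory_stats()
```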

I initialize the weights for the loss like this:

        task_weights = {}
        update_weights = conf.task_weight_policy == 'update'
        # default setting is that all weights start at 1.0
        task_weights['d_cls'] = torch.tensor(conf.task_weight_init_d_cls).float()
        task_weights['d_cls'] = task_weights['d_cls'].cuda()
        task_weights['d_cls'] = task_weights['d_cls'].requires_grad_(update_weights)
        # when I set update_weights to False, the model does not crash.
        # update_weights only affects the task weights
        task_weights['d_reg'] = ...

I do this for every task (3 in total, “d_cls”, “d_reg”, “d_boxes”). I want to use these weights to change the importance of other losses.
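The order used above (move to the GPU first, then call requires_grad_) produces leaf tensors, which is what torch.optim needs. An equivalent, more compact way is to pass `device` and `requires_grad` to `torch.tensor` directly. The sketch below assumes a plain dict of initial values; `make_task_weights` is a hypothetical helper, and it runs on CPU so it can be checked anywhere:

```python
import torch

def make_task_weights(init_values, update_weights, device="cpu"):
    """Create one scalar task weight per task as a leaf tensor.

    Passing device and requires_grad to torch.tensor directly yields a leaf
    tensor, which torch.optim accepts as a parameter.
    """
    return {
        name: torch.tensor(float(value), device=device, requires_grad=update_weights)
        for name, value in init_values.items()
    }

# Hypothetical initial values; the post reads them from conf.task_weight_init_*.
weights = make_task_weights({"d_cls": 1.0, "d_reg": 1.0, "d_boxes": 1.0}, update_weights=True)
print(all(w.is_leaf and w.requires_grad for w in weights.values()))  # True

# Converting after requires_grad is set (e.g. .cuda() or .double()) creates a
# non-leaf copy, which torch.optim rejects ("can't optimize a non-leaf Tensor").
non_leaf = torch.tensor(1.0, requires_grad=True).double()
print(non_leaf.is_leaf)  # False
```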

I add them to the optimizer like this:

 objective_params = list(network.parameters()) + list(task_weights.values())
 optimizer = torch.optim.SGD(objective_params, lr=lr, momentum=mo, weight_decay=wd)
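With a single parameter list, the task weights share the network's learning rate, momentum and `weight_decay=wd`, so weight decay also pulls every task weight toward 0. Whether that is intended is a design choice; if it is not, a separate parameter group can give the task weights their own settings. A sketch with a stand-in network and placeholder hyperparameters (the real values come from `lr`, `mo` and `wd` in the post):

```python
import torch
import torch.nn as nn

net = nn.Linear(4, 2)  # stand-in for the real network
task_weights = {name: torch.tensor(1.0, requires_grad=True) for name in ("d_cls", "d_reg", "d_boxes")}

optimizer = torch.optim.SGD(
    [
        # Network parameters use the defaults given below.
        {"params": net.parameters()},
        # Task weights get their own group without weight decay, so the
        # decay term does not pull every task weight toward 0.
        {"params": list(task_weights.values()), "weight_decay": 0.0},
    ],
    lr=0.005, momentum=0.9, weight_decay=1e-4,  # placeholder hyperparameters
)
print(len(optimizer.param_groups))  # 2
```

A per-group `"lr"` key can be added the same way if the task weights should learn faster or slower than the network.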

The loss is calculated in a class that inherits from nn.Module. Its forward function receives all the relevant tensors, including task_weights:

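    # Arguments after feat_size are optional keywords: a non-default argument
    # may not follow a default one, and the call in main passes them by name.
    def forward(self, cls, prob, bbox_2d, bbox_3d, imobjs, feat_size, bbox_vertices=None, corners_3d=None,
                depth_cls=None, depth_reg=None, depth_label=None, task_weights=None):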

First, it calculates the losses for the 3D bounding boxes, 2D bounding boxes, etc. The loss is initialized like this:

        loss = torch.tensor(0).type(torch.cuda.FloatTensor)

and then it is summed up like this:

         loss += loss_cls
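A device-agnostic variant of this accumulation starts from `torch.zeros((), device=...)` instead of hard-coding `torch.cuda.FloatTensor`, which also makes the loss code easy to test on CPU. A minimal sketch; `sum_losses` is a hypothetical helper:

```python
import torch

def sum_losses(terms, device="cpu"):
    """Sum scalar loss terms into one tensor on `device`."""
    total = torch.zeros((), device=device)
    for term in terms:
        total = total + term  # out-of-place add; `total += term` would also work here
    return total

a = torch.tensor(0.5, requires_grad=True)
b = torch.tensor(1.5, requires_grad=True)
loss = sum_losses([a * 2, b])
loss.backward()
print(loss.item(), a.grad.item(), b.grad.item())  # 2.5 2.0 1.0
```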

Then I calculate the losses for the depth map regression and classification. They are not added to the loss tensor yet.

The multi-task loss is implemented like this:

        label_reg = depth_label['d_reg'].cuda()
        label_cls = depth_label['d_cls'].cuda()
        reg_loss = self.masked_mse(depth_reg, label_reg)
        cls_loss = self.cross_entropy2d(depth_cls, label_cls)

        stats.append({'name': "depth_reg", 'val': reg_loss.item(), 'format': '{:0.4f}', 'group': 'loss'})
        stats.append({'name': "depth_cls", 'val': cls_loss.item(), 'format': '{:0.4f}', 'group': 'loss'})

        mt_loss = torch.tensor(0).type(torch.cuda.FloatTensor)

        for task in task_weights:
            s = task_weights[task]
            r = s * 0.5
            stats.append({'name': task + "-s", 'val': s.item(), 'format': '{:0.4f}', 'group': 'MT'})
            if task in ["d_cls"]:
                w = torch.exp(-s)
                mt_loss += cls_loss * w + r
                stats.append({'name': task + "-adjusted", 'val': cls_loss.item() * w.item() + r.item(), 'format': '{:0.4f}', 'group': 'MT'})
            elif task in ["d_reg"]:
                .....
            elif task in ["d_boxes"]:
                # use the loss that was summed up from the 2D BB loss, 3D BB loss, etc.
                w = 0.5 * torch.exp(-s)
                mt_loss += loss * w + r
                stats.append({'name': task + "-adjusted", 'val': loss.item() * w.item() + r.item(), 'format': '{:0.4f}', 'group': 'MT'})
        # overwrite loss with mt_loss, which includes the weighted losses
        loss = mt_loss
        return loss, stats
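Read as formulas, the d_cls and d_boxes branches above give each task weight s its own term (L is the corresponding loss; the d_reg branch is elided, so it is left out). Setting the derivative with respect to s to zero shows where s would settle if the loss were held fixed:

```latex
\begin{aligned}
f_{\text{cls}}(s) &= L_{\text{cls}}\,e^{-s} + \tfrac{1}{2}s,
& \frac{\partial f_{\text{cls}}}{\partial s} &= -L_{\text{cls}}\,e^{-s} + \tfrac{1}{2} = 0
\;\Rightarrow\; s^{\ast} = \ln\left(2L_{\text{cls}}\right),\\
f_{\text{boxes}}(s) &= \tfrac{1}{2}L_{\text{boxes}}\,e^{-s} + \tfrac{1}{2}s,
& \frac{\partial f_{\text{boxes}}}{\partial s} &= -\tfrac{1}{2}L_{\text{boxes}}\,e^{-s} + \tfrac{1}{2} = 0
\;\Rightarrow\; s^{\ast} = \ln L_{\text{boxes}}.
\end{aligned}
```

Both second derivatives are positive, so these are minima. Under that fixed-loss assumption, a negative s is expected for a task whose loss stays small (below 0.5 for d_cls, below 1 for d_boxes).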

The main training loop then calls:

det_loss, det_stats = criterion_det(cls, prob, bbox_2d, bbox_3d, imobjs, feat_size, depth_cls = x_d_cls, depth_reg = x_d_reg, depth_label = depths, task_weights = task_weights)
total_loss = det_loss
total_loss.backward()
optimizer.step()
optimizer.zero_grad()
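To confirm that the task weights actually receive gradients and are updated by the same optimizer step, the d_cls and d_boxes branches can be reproduced on CPU with a toy network and dummy losses. This is a simplified sketch, not the original criterion; `multitask_loss` and the dummy losses are assumptions, and d_reg is omitted because its form is elided above:

```python
import torch
import torch.nn as nn

def multitask_loss(cls_loss, boxes_loss, s):
    """Weighted sum mirroring the d_cls and d_boxes branches of the post.

    s maps task name -> scalar tensor (the learnable task weight).
    """
    total = cls_loss * torch.exp(-s["d_cls"]) + 0.5 * s["d_cls"]
    total = total + boxes_loss * 0.5 * torch.exp(-s["d_boxes"]) + 0.5 * s["d_boxes"]
    return total

torch.manual_seed(0)
net = nn.Linear(3, 1)  # stand-in for the detection network
s = {k: torch.tensor(0.0, requires_grad=True) for k in ("d_cls", "d_boxes")}
opt = torch.optim.SGD(list(net.parameters()) + list(s.values()), lr=0.1)

x = torch.randn(8, 3)
cls_loss = net(x).pow(2).mean()           # dummy task losses
boxes_loss = (net(x) - 1.0).abs().mean()

loss = multitask_loss(cls_loss, boxes_loss, s)
opt.zero_grad()
loss.backward()
opt.step()
print({k: round(v.item(), 4) for k, v in s.items()})  # both weights have moved away from 0
```

At s = 0 the gradient for d_cls is 0.5 - L_cls, so one plain SGD step moves it to 0.1 * (L_cls - 0.5); the same reasoning gives 0.05 * (L_boxes - 1) for d_boxes.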

Afterwards the task_weights are only used when saving them:

if (iteration + 1) % conf.snapshot_iter == 0 and iteration > start_iter:
    save_checkpoint(optimizer, scheduler, rpn_net, paths.weights, (iteration + 1))
    save_task_weights(paths.weights, (iteration + 1), task_weights)

def save_task_weights(weights_dir, iteration, task_weights):
    state = {"task_weights": task_weights}
    task_weights_path = os.path.join(weights_dir, 'task_weights_{}_pkl'.format(iteration))
    torch.save(state, task_weights_path)
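A possible variant stores detached CPU copies, so the file holds plain values independent of the GPU, and restores them as leaf tensors when resuming. This is a sketch under that assumption; `load_task_weights` is a hypothetical helper, and the file name pattern is kept from the post:

```python
import os
import tempfile
import torch

def save_task_weights(weights_dir, iteration, task_weights):
    # Store detached CPU copies so the checkpoint holds plain values.
    state = {"task_weights": {k: v.detach().cpu() for k, v in task_weights.items()}}
    path = os.path.join(weights_dir, "task_weights_{}_pkl".format(iteration))
    torch.save(state, path)
    return path

def load_task_weights(path, device="cpu", requires_grad=True):
    # Rebuild leaf tensors on the target device so they can go back into the optimizer.
    state = torch.load(path, map_location="cpu")
    return {k: v.to(device).requires_grad_(requires_grad) for k, v in state["task_weights"].items()}

with tempfile.TemporaryDirectory() as d:
    path = save_task_weights(d, 100, {"d_cls": torch.tensor(1.5, requires_grad=True)})
    print(load_task_weights(path))  # {'d_cls': tensor(1.5000, requires_grad=True)}
```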

To summarize: the model trains fine if the task weights are not updated by backpropagation and stay constant. The CUDA error only appears when I set requires_grad_(True). I have included every part of the code that I think is relevant to my problem, but I can post more if you need it. I hope someone here can spot my error.

I am thankful for any suggestions.

Enabling gradient computation for additional parameters might force autograd to store additional intermediate tensors, if they are needed to compute those gradients, which could cause the out-of-memory error.
You could either try to reduce the batch size or use e.g. torch.utils.checkpoint to trade compute for memory.
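As a rough sketch of the checkpointing suggestion (a toy network; the layer sizes are arbitrary assumptions): the activations inside a checkpointed sub-module are not stored during the forward pass but recomputed during backward, which lowers peak memory at the cost of an extra forward pass through that block. The `use_reentrant` argument only exists in newer PyTorch releases, so the sketch passes it only when available. With the older reentrant implementation, at least one input to the checkpointed block must require grad, which is why the first layer runs outside it here.

```python
import inspect
import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint

# use_reentrant only exists in newer PyTorch releases; older ones (e.g. 1.5) lack it.
_CKPT_KWARGS = (
    {"use_reentrant": False}
    if "use_reentrant" in inspect.signature(checkpoint).parameters
    else {}
)

class Net(nn.Module):
    """Toy network; only `body` is checkpointed when use_checkpoint is True."""

    def __init__(self, use_checkpoint):
        super().__init__()
        self.use_checkpoint = use_checkpoint
        self.stem = nn.Conv2d(3, 8, 3, padding=1)
        self.body = nn.Sequential(
            nn.Conv2d(8, 8, 3, padding=1), nn.ReLU(),
            nn.Conv2d(8, 8, 3, padding=1), nn.ReLU(),
        )
        self.head = nn.Conv2d(8, 1, 1)

    def forward(self, x):
        x = self.stem(x)  # output requires grad, as reentrant checkpointing needs
        if self.use_checkpoint:
            # Activations inside `body` are recomputed during backward instead of stored.
            x = checkpoint(self.body, x, **_CKPT_KWARGS)
        else:
            x = self.body(x)
        return self.head(x)
```

Gradients are the same with and without checkpointing; only peak memory and compute time differ.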