Hello,
I have a working model that trains without any problems. I am using PyTorch 1.5.1 on a GPU with CUDA 10.2. I then added learnable multi-task loss weights, and that is causing problems: I can train the model for ~3500 (+/- 400) iterations, and then I get a CUDA out of memory error. Before this, the cached memory is around 8 GB, and I am using a 2080 Ti with 11 GB of VRAM. I cannot see a memory leak up to that point; the cached memory seems constant. At ~3500 iterations it seems to spike in the forward pass of the network. The error only appears when I set `requires_grad_()` to `True` for my multi-task loss weights (see code below). I have searched the internet for answers, but most of them boil down to people not using `.item()` when storing tensors in their statistics, or saving them somehow, and I don't think I am doing this.
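For context, the pattern those answers usually warn about is accumulating a loss *tensor* across iterations instead of its `.item()` value: the tensor keeps its autograd history attached, while `.item()` yields a plain Python float. A minimal CPU sketch with a made-up toy model, only to illustrate the difference:

```python
import torch

torch.manual_seed(0)
model = torch.nn.Linear(4, 1)  # toy model, for illustration only
x = torch.randn(8, 4)

running_tensor = 0.0  # becomes a tensor that chains every iteration's graph
running_float = 0.0   # stays a plain Python float

for _ in range(3):
    loss = model(x).pow(2).mean()
    loss.backward()
    model.zero_grad()
    running_tensor = running_tensor + loss        # autograd history accumulates across iterations
    running_float = running_float + loss.item()   # only the numeric value is kept

print(running_tensor.requires_grad)  # True: still attached to autograd
print(type(running_float).__name__)  # float
```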
```
iter: 3850, MT (d_boxes-s: 1.0000, d_cls-adjusted: 1.0389, d_cls-s: 1.1078, d_reg-adjusted: -1.4250, d_reg-s: -3.2818), acc (bg: 1.00, fg: 0.84, iou: 0.78), loss (bbox_2d: 0.5207, bbox_3d: 1.0714, cls: 0.1337, depth_cls: 1.4683, depth_reg: 0.0163), misc (ry: 1.17, z: 1.33), dt: 0.81, eta: 44.1h
Traceback (most recent call last):
  File "scripts/train.py", line 388, in <module>
    main(sys.argv[1:])
  File "scripts/train.py", line 220, in main
    cls, prob, bbox_2d, bbox_3d, feat_size, x_d_cls, x_d_reg = rpn_net(images.cuda())
  File "/disk/no_backup/---/D4LCN/env/lib/python3.6/site-packages/torch/nn/modules/module.py", line 550, in __call__
    result = self.forward(*input, **kwargs)
  File "/disk/no_backup/---/D4LCN/env/lib/python3.6/site-packages/torch/nn/parallel/data_parallel.py", line 153, in forward
    return self.module(*inputs[0], **kwargs[0])
  File "/disk/no_backup/---/D4LCN/env/lib/python3.6/site-packages/torch/nn/modules/module.py", line 550, in __call__
    result = self.forward(*input, **kwargs)
  File "/disk/vanishing_data/---/D4LDCoutput/output/learn_depth_config/Adaptive_block2_resnet_dilate_depth50_batch2_dropoutearly0_5_lr0_005_onecycle_iter200000_2020_07_29_05_01_16/resnet_dilate_depth.py", line 168, in forward
    x_d_cls = F.interpolate(x_d_cls, size=inp_shape, mode="bilinear", align_corners=True)
  File "/disk/no_backup/---/D4LCN/env/lib/python3.6/site-packages/torch/nn/functional.py", line 3013, in interpolate
    scale_factor_list[0], scale_factor_list[1])
RuntimeError: CUDA out of memory. Tried to allocate 144.00 MiB (GPU 0; 10.76 GiB total capacity; 9.54 GiB already allocated; 51.00 MiB free; 9.72 GiB reserved in total by PyTorch)
```
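To see where the spike comes from, it may help to log allocated and peak memory per iteration instead of only the cached total. A small helper sketch using `torch.cuda.memory_allocated`, `torch.cuda.memory_reserved` and `torch.cuda.max_memory_allocated` (these exist in PyTorch 1.4 and later, so they should be available in 1.5.1); the helper names `mib` and `log_cuda_memory` are hypothetical:

```python
import torch

def mib(num_bytes):
    """Convert a byte count to MiB."""
    return num_bytes / 2**20

def log_cuda_memory(tag, device=0):
    """Print allocated / reserved / peak memory for one GPU; returns None without CUDA."""
    if not torch.cuda.is_available():
        return None
    allocated = mib(torch.cuda.memory_allocated(device))
    reserved = mib(torch.cuda.memory_reserved(device))
    peak = mib(torch.cuda.max_memory_allocated(device))
    print(f"[{tag}] allocated {allocated:.0f} MiB, reserved {reserved:.0f} MiB, peak {peak:.0f} MiB")
    return allocated, reserved, peak
```

One way to use it would be calling `log_cuda_memory("before forward")` just before `rpn_net(images.cuda())` and resetting the peak counter with `torch.cuda.reset_peak_memory_stats()` at the start of each iteration, so that a growing per-iteration peak (as opposed to a constant one) becomes visible well before iteration ~3500.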
I initialize the weights for the loss like this:
```python
task_weights = {}
update_weights = conf.task_weight_policy == 'update'
# default setting is that all weights start at 1.0
task_weights['d_cls'] = torch.tensor(conf.task_weight_init_d_cls).float()
task_weights['d_cls'] = task_weights['d_cls'].cuda()
task_weights['d_cls'] = task_weights['d_cls'].requires_grad_(update_weights)
# when I set update_weights to False the model does not crash.
# update_weights only affects the task weights
task_weights['d_reg'] = ...
```
I do this for every task (3 in total: “d_cls”, “d_reg”, “d_boxes”). I want to use these weights to scale the importance of the individual losses.
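The same per-task initialization can be written as a loop that creates each tensor directly on the target device with `requires_grad` set, which keeps it a leaf tensor the optimizer can update. A minimal sketch; `make_task_weights` and its `init_values` dict are hypothetical stand-ins for the `conf.task_weight_init_*` values:

```python
import torch

def make_task_weights(init_values, update_weights, device=None):
    """Create one scalar weight per task as a leaf tensor.

    init_values: dict mapping task name -> initial float (hypothetical stand-in for conf values).
    """
    if device is None:
        device = 'cuda' if torch.cuda.is_available() else 'cpu'
    # Creating the tensor on the device with requires_grad already set keeps it a leaf.
    return {
        name: torch.tensor(float(value), device=device, requires_grad=update_weights)
        for name, value in init_values.items()
    }

task_weights = make_task_weights({'d_cls': 1.0, 'd_reg': 1.0, 'd_boxes': 1.0}, update_weights=True)
print({k: (v.item(), v.is_leaf, v.requires_grad) for k, v in task_weights.items()})
```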
I add them to the optimizer like this:
```python
objective_params = list(network.parameters()) + list(task_weights.values())
optimizer = torch.optim.SGD(objective_params, lr=lr, momentum=mo, weight_decay=wd)
```
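With a single parameter list, `weight_decay=wd` also applies to the task weights and pulls them toward 0. If that is not intended, per-parameter option groups in `torch.optim` allow a different setting for them. A sketch with a placeholder network and example hyperparameter values only:

```python
import torch

network = torch.nn.Linear(4, 2)  # placeholder for the real network
task_weights = {name: torch.tensor(1.0, requires_grad=True) for name in ('d_cls', 'd_reg', 'd_boxes')}
lr, mo, wd = 0.005, 0.9, 1e-4  # example values only

optimizer = torch.optim.SGD(
    [
        {'params': list(network.parameters())},                        # uses the defaults below
        {'params': list(task_weights.values()), 'weight_decay': 0.0},  # no decay on task weights
    ],
    lr=lr, momentum=mo, weight_decay=wd,
)
print([g['weight_decay'] for g in optimizer.param_groups])  # [0.0001, 0.0]
```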
The loss is calculated in a class that inherits from `nn.Module`. Its `forward` function receives all the relevant tensors, including `task_weights`:
```python
def forward(self, cls, prob, bbox_2d, bbox_3d, imobjs, feat_size, bbox_vertices=None, corners_3d=None,
            depth_cls=None, depth_reg=None, depth_label=None, task_weights=None):
```
It first calculates the losses for the 3D bounding boxes, 2D bounding boxes, etc. The loss is initialized like this:
```python
loss = torch.tensor(0).type(torch.cuda.FloatTensor)
```
and is then summed up like this:
```python
loss += loss_cls
```
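For reference, the same zero-initialized scalar accumulator can be written device-agnostically. A small sketch in which `pred`, `loss_cls` and `loss_bbox` are made-up stand-ins for the real partial losses:

```python
import torch

device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Same as torch.tensor(0).type(torch.cuda.FloatTensor) on a GPU: a 0-dim float32 zero.
loss = torch.zeros((), device=device)

# Hypothetical partial losses standing in for loss_cls, loss_bbox_2d, ...
pred = torch.randn(5, device=device, requires_grad=True)
loss_cls = pred.pow(2).mean()
loss_bbox = pred.abs().sum()

loss = loss + loss_cls  # out-of-place add; "loss += loss_cls" also works here
loss = loss + loss_bbox
loss.backward()
```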
Then I calculate the losses for the depth map regression and classification. They are not added to the `loss` tensor yet.
The multi-task loss is implemented like this:
```python
label_reg = depth_label['d_reg'].cuda()
label_cls = depth_label['d_cls'].cuda()

reg_loss = self.masked_mse(depth_reg, label_reg)
cls_loss = self.cross_entropy2d(depth_cls, label_cls)

stats.append({'name': "depth_reg", 'val': reg_loss.item(), 'format': '{:0.4f}', 'group': 'loss'})
stats.append({'name': "depth_cls", 'val': cls_loss.item(), 'format': '{:0.4f}', 'group': 'loss'})

mt_loss = torch.tensor(0).type(torch.cuda.FloatTensor)
for task in task_weights:
    s = task_weights[task]
    r = s * 0.5
    stats.append({'name': task + "-s", 'val': s.item(), 'format': '{:0.4f}', 'group': 'MT'})
    if task in ["d_cls"]:
        w = torch.exp(-s)
        mt_loss += cls_loss * w + r
        stats.append({'name': task + "-adjusted", 'val': cls_loss.item() * w.item() + r.item(), 'format': '{:0.4f}', 'group': 'MT'})
    elif task in ["d_reg"]:
        ...  # omitted in this post
    elif task in ["d_boxes"]:
        w = 0.5 * torch.exp(-s)
        # use the loss which was summed up from the 2D BB loss, 3D BB loss etc.
        mt_loss += loss * w + r
        stats.append({'name': task + "-adjusted", 'val': loss.item() * w.item() + r.item(), 'format': '{:0.4f}', 'group': 'MT'})

# overwrite loss with mt_loss, which includes the weighted losses
loss = mt_loss
return loss, stats
```
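The weighting in this loop has the form total = Σ_t c_t · exp(-s_t) · L_t + 0.5 · s_t, with c_t = 1 for `d_cls` and c_t = 0.5 for `d_boxes`. A self-contained sketch of the same form as a separate module, where `nn.ParameterDict` registers the weights so `.parameters()`, `.to()` and `state_dict()` cover them. The class name `UncertaintyWeightedLoss` is hypothetical, and the `d_reg` scale of 1.0 in the usage is an assumption because that branch is omitted above:

```python
import torch
import torch.nn as nn

class UncertaintyWeightedLoss(nn.Module):
    """Sketch: total = sum_t scales[t] * exp(-s_t) * L_t + 0.5 * s_t, one learnable s_t per task."""

    def __init__(self, init_values, scales):
        super().__init__()
        # Registered parameters: picked up by .parameters(), .to(device) and state_dict().
        self.s = nn.ParameterDict({k: nn.Parameter(torch.tensor(float(v))) for k, v in init_values.items()})
        self.scales = scales

    def forward(self, losses):
        total = 0.0
        stats = []
        for task, s in self.s.items():
            term = self.scales[task] * torch.exp(-s) * losses[task] + 0.5 * s
            total = total + term
            # .item() keeps the statistics free of autograd history.
            stats.append({'name': task + '-adjusted', 'val': term.item()})
        return total, stats

criterion = UncertaintyWeightedLoss({'d_cls': 0.0, 'd_reg': 0.0, 'd_boxes': 0.0},
                                    scales={'d_cls': 1.0, 'd_reg': 1.0, 'd_boxes': 0.5})
losses = {'d_cls': torch.tensor(2.0), 'd_reg': torch.tensor(3.0), 'd_boxes': torch.tensor(4.0)}
total, stats = criterion(losses)
total.backward()
print(total.item())  # 7.0 = 2.0 + 3.0 + 0.5 * 4.0 when every s_t is 0
```

With the weights registered on a module, they could be passed to the optimizer via `criterion.parameters()` instead of a separate dict.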
The main function then calls:
```python
det_loss, det_stats = criterion_det(cls, prob, bbox_2d, bbox_3d, imobjs, feat_size, depth_cls=x_d_cls, depth_reg=x_d_reg, depth_label=depths, task_weights=task_weights)
total_loss = det_loss
total_loss.backward()
optimizer.step()
optimizer.zero_grad()
```
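To check that a task weight actually receives a gradient and changes after `optimizer.step()`, a minimal CPU sketch of the same backward/step/zero_grad order with one weight `s` in the `d_cls` form; the network, data, learning rate and momentum are placeholders:

```python
import torch

torch.manual_seed(0)
net = torch.nn.Linear(3, 1)                # placeholder network
s = torch.tensor(1.0, requires_grad=True)  # one task weight, starting at 1.0
optimizer = torch.optim.SGD(list(net.parameters()) + [s], lr=0.1, momentum=0.9)

x, y = torch.randn(16, 3), torch.randn(16, 1)
for it in range(3):
    task_loss = torch.nn.functional.mse_loss(net(x), y)
    total_loss = torch.exp(-s) * task_loss + 0.5 * s  # same form as the d_cls term
    total_loss.backward()
    grad_s = s.grad.item()  # read before zero_grad clears it
    optimizer.step()
    optimizer.zero_grad()
    print(it, round(s.item(), 4), round(grad_s, 4), round(total_loss.item(), 4))
```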
Afterwards the task_weights are only used when saving them:
```python
if (iteration + 1) % conf.snapshot_iter == 0 and iteration > start_iter:
    save_checkpoint(optimizer, scheduler, rpn_net, paths.weights, (iteration + 1))
    save_task_weights(paths.weights, (iteration + 1), task_weights)

def save_task_weights(weights_dir, iteration, task_weights):
    state = {"task_weights": task_weights}
    task_weights_path = os.path.join(weights_dir, 'task_weights_{}_pkl'.format(iteration))
    torch.save(state, task_weights_path)
```
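Independently of the OOM, storing detached CPU copies means the saved file holds plain values and can be loaded on a machine without a GPU (CUDA tensors would otherwise need `map_location`). A sketch of such a variant; `save_task_weights_detached` is a hypothetical name and the file name pattern is copied from the original:

```python
import os
import tempfile
import torch

def save_task_weights_detached(weights_dir, iteration, task_weights):
    """Hypothetical variant: save detached CPU copies instead of the live (possibly CUDA) tensors."""
    state = {"task_weights": {k: v.detach().cpu() for k, v in task_weights.items()}}
    path = os.path.join(weights_dir, 'task_weights_{}_pkl'.format(iteration))
    torch.save(state, path)
    return path

task_weights = {'d_cls': torch.tensor(1.04, requires_grad=True)}
with tempfile.TemporaryDirectory() as d:
    path = save_task_weights_detached(d, 3850, task_weights)
    loaded = torch.load(path)["task_weights"]
    print(loaded['d_cls'].requires_grad, round(loaded['d_cls'].item(), 2))  # False 1.04
```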
To summarize: the model works fine if `task_weights` is not updated by backpropagation and stays constant. The CUDA error only appears when I set `requires_grad_(True)`. I included every part of the code that I think is relevant to my problem, but I can post more if you need it. I hope someone here can spot my error.
I am thankful for any suggestions.