Torchvision MaskRCNN returning NaN losses in fp16?

If I convert the full model to fp16 and run it in training mode, it returns a dict of losses, and some of those losses are NaN.

Has anyone made MaskRCNN work in FP16?

I would recommend using automatic mixed-precision training via torch.cuda.amp, which you can get by installing the nightly binaries.
Simply converting your model to float16 can easily cause overflows/underflows and thus yield NaNs.
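To make the overflow/underflow point concrete, here is a small sketch (plain PyTorch on CPU, not MaskRCNN-specific) showing how values leave the float16 range, how an inf turns into a NaN, and why scaling a small value before casting keeps it representable. The specific numbers (70000, 1e-8, a scale of 65536) are arbitrary illustrations.

```python
import torch

# Largest finite float16 value.
FP16_MAX = torch.finfo(torch.float16).max

# Overflow: 70000 is beyond the float16 range, so the cast gives inf.
big = torch.tensor(70000.0).half()

# Underflow: 1e-8 is below half the smallest fp16 subnormal (~6e-8), so it rounds to 0.
tiny = torch.tensor(1e-8).half()

# Once an inf appears, ordinary arithmetic (inf - inf) produces NaN.
nan = big - big

# Loss scaling (what torch.cuda.amp.GradScaler automates for gradients):
# multiplying by a large factor before casting keeps the small value nonzero.
scaled = (torch.tensor(1e-8) * 65536).half()
```

This is roughly why a model cast wholesale with `.half()` can report NaN losses, while autocast keeps numerically sensitive ops in float32 and GradScaler handles the gradient underflow side.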

The thing is that MaskRCNN returns a dict of losses, not a single loss value.
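A dict of losses is not really an obstacle for AMP: you can reduce it to one scalar inside the autocast region and hand that scalar to the GradScaler. A minimal sketch, assuming the torch.cuda.amp API; `amp_train_step` and `ToyDetector` are hypothetical names, and `ToyDetector` only stands in for MaskRCNN's train-mode behaviour of returning a dict of losses.

```python
import torch
from torch import nn
from torch.cuda.amp import GradScaler, autocast


def amp_train_step(model, images, targets, optimizer, scaler, use_amp=True):
    """One mixed-precision training step for a model that returns a dict of losses."""
    model.train()
    optimizer.zero_grad()
    # Forward pass and loss reduction run under autocast.
    with autocast(enabled=use_amp):
        loss_dict = model(images, targets)
        # Collapse the dict into a single scalar to backpropagate.
        loss = sum(loss_dict.values())
    # Scale the loss so small fp16 gradients don't underflow to zero;
    # scaler.step() unscales them and skips the update if it finds inf/NaN.
    scaler.scale(loss).backward()
    scaler.step(optimizer)
    scaler.update()
    # Plain floats so each loss term can still be logged separately.
    return {name: value.item() for name, value in loss_dict.items()}


class ToyDetector(nn.Module):
    """Hypothetical stand-in for MaskRCNN: returns a dict of losses in train mode."""

    def __init__(self):
        super().__init__()
        self.head = nn.Linear(4, 2)

    def forward(self, images, targets):
        out = self.head(images)
        return {
            "loss_classifier": nn.functional.cross_entropy(out, targets),
            "loss_reg": out.pow(2).mean(),
        }
```

On a GPU you would move the model and data to CUDA, create `GradScaler()` and pass `use_amp=True`; with a real torchvision MaskRCNN, `images` is a list of tensors and `targets` a list of dicts. Newer PyTorch releases deprecate the `torch.cuda.amp` spellings in favor of `torch.amp`.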

In addition, it looks like MaskRCNN doesn't support torch.cuda.amp.autocast:

~/Documents/test/seg/models/archs/mask_rcnn.py in mixed_precision_one_batch(self, i, b)
    186         with autocast():
    187             self.model.train()
--> 188             loss_dict = self.model(images,targets)
    189             if not self.training:
    190                 self.model.eval()

~/anaconda3/envs/seg/lib/python3.7/site-packages/torch/nn/modules/module.py in __call__(self, *input, **kwargs)
    556             result = self._slow_forward(*input, **kwargs)
    557         else:
--> 558             result = self.forward(*input, **kwargs)
    559         for hook in self._forward_hooks.values():
    560             hook_result = hook(self, input, result)

~/anaconda3/envs/seg/lib/python3.7/site-packages/torchvision/models/detection/generalized_rcnn.py in forward(self, images, targets)
     68         if isinstance(features, torch.Tensor):
     69             features = OrderedDict([('0', features)])
---> 70         proposals, proposal_losses = self.rpn(images, features, targets)
     71         detections, detector_losses = self.roi_heads(features, proposals, images.image_sizes, targets)
     72         detections = self.transform.postprocess(detections, images.image_sizes, original_image_sizes)

~/anaconda3/envs/seg/lib/python3.7/site-packages/torch/nn/modules/module.py in __call__(self, *input, **kwargs)
    556             result = self._slow_forward(*input, **kwargs)
    557         else:
--> 558             result = self.forward(*input, **kwargs)
    559         for hook in self._forward_hooks.values():
    560             hook_result = hook(self, input, result)

~/anaconda3/envs/seg/lib/python3.7/site-packages/torchvision/models/detection/rpn.py in forward(self, images, features, targets)
    486         proposals = self.box_coder.decode(pred_bbox_deltas.detach(), anchors)
    487         proposals = proposals.view(num_images, -1, 4)
--> 488         boxes, scores = self.filter_proposals(proposals, objectness, images.image_sizes, num_anchors_per_level)
    489 
    490         losses = {}

~/anaconda3/envs/seg/lib/python3.7/site-packages/torchvision/models/detection/rpn.py in filter_proposals(self, proposals, objectness, image_shapes, num_anchors_per_level)
    408             boxes, scores, lvl = boxes[keep], scores[keep], lvl[keep]
    409             # non-maximum suppression, independently done per level
--> 410             keep = box_ops.batched_nms(boxes, scores, lvl, self.nms_thresh)
    411             # keep only topk scoring predictions
    412             keep = keep[:self.post_nms_top_n()]

~/anaconda3/envs/seg/lib/python3.7/site-packages/torchvision/ops/boxes.py in batched_nms(boxes, scores, idxs, iou_threshold)
     73     offsets = idxs.to(boxes) * (max_coordinate + 1)
     74     boxes_for_nms = boxes + offsets[:, None]
---> 75     keep = nms(boxes_for_nms, scores, iou_threshold)
     76     return keep
     77 

~/anaconda3/envs/seg/lib/python3.7/site-packages/torchvision/ops/boxes.py in nms(boxes, scores, iou_threshold)
     33         by NMS, sorted in decreasing order of scores
     34     """
---> 35     return torch.ops.torchvision.nms(boxes, scores, iou_threshold)
     36 
     37 

RuntimeError: Unrecognized tensor type ID: Autocast

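Since nms is a custom extension (it is called through torch.ops.torchvision.nms, as the traceback shows), autocast doesn't know how to handle it, so you might need to disable autocasting for that call as described here.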
CC @mcarilli to add more information in case I’m missing something.
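The usual pattern for an op that autocast doesn't handle is to leave the autocast region locally and cast the floating-point inputs to float32 yourself. A minimal sketch of that pattern; `call_in_fp32` is a hypothetical helper, not part of PyTorch or torchvision, and where exactly to apply it inside torchvision's RPN code is left open here.

```python
import torch
from torch.cuda.amp import autocast


def call_in_fp32(op, *args):
    """Hypothetical helper: run `op` with autocast disabled and float tensors cast to float32.

    Non-floating tensors (e.g. index or level tensors) and non-tensor
    arguments (e.g. an IoU threshold) are passed through unchanged.
    """
    with autocast(enabled=False):
        cast_args = [
            a.float() if torch.is_tensor(a) and a.is_floating_point() else a
            for a in args
        ]
        return op(*cast_args)
```

For the failing call in the traceback this would look like `keep = call_in_fp32(torchvision.ops.nms, boxes, scores, iou_threshold)`. Running NMS in float32 is also reasonable numerically, since box coordinates plus the per-level offsets added in batched_nms can get large.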

Let’s talk here https://github.com/pytorch/pytorch/issues/37735

Okay, let me know on GitHub if you find out what is causing the issue and how to solve it!

Did you have time to look at it?