Problem with my input data structure using Torchvision RetinaNet? Minimal example included

An error occurs while computing the loss in the classification head. As far as I can tell, my input data has the structure required by torchvision's RetinaNet, but I still get this error, and it prevents me from training.

I’ve reproduced the problem in a minimal example below. The error you should get is “index 1 is out of bounds for dimension 1 with size 1”, raised in the compute_loss(args) method of the classification head.
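The message refers to a column index. The classification head appears to build a one-hot target with one column per class and write a 1 in the column given by each matched box's label. Below is a simplified imitation of that step with plain tensors, assuming that one-hot layout. `build_one_hot_target` is a hypothetical helper, not torchvision code.

```python
import torch


def build_one_hot_target(num_anchors, num_classes, fg_anchor_idx, fg_labels):
    """Hypothetical imitation of the loss target: one row per anchor, one
    column per class, and 1.0 at (anchor, label) for each foreground anchor."""
    target = torch.zeros(num_anchors, num_classes)
    target[fg_anchor_idx, fg_labels] = 1.0  # label value is used as a column index
    return target


fg_anchor_idx = torch.tensor([3, 7])  # anchors matched to ground-truth boxes
fg_labels = torch.tensor([1, 1])      # every box is labelled 1, as in the question

try:
    # Only column 0 exists, so writing to column 1 fails.
    build_one_hot_target(10, num_classes=1, fg_anchor_idx=fg_anchor_idx, fg_labels=fg_labels)
except IndexError as err:
    print("IndexError:", err)

# With two class columns, label 1 is a valid column index.
target = build_one_hot_target(10, num_classes=2, fg_anchor_idx=fg_anchor_idx, fg_labels=fg_labels)
```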

Thank you in advance for any help.

Minimal code:

# Adapted from an example in the PyTorch code
import numpy as np
import torch
import torchvision
from torchvision.models.detection.anchor_utils import AnchorGenerator
from torchvision.models.detection.backbone_utils import resnet_fpn_backbone
from torchvision.models.detection import RetinaNet

# Backbone: ResNet-18 with an FPN on top (256 output channels per level)
backbone = resnet_fpn_backbone('resnet18', pretrained=False, trainable_layers=4)
backbone.out_channels = 256

# Anchor generator: one anchor size per FPN level, three aspect ratios each
anchor_sizes = ((32,), (64,), (128,), (256,), (512,))
aspect_ratios = ((0.5, 1.0, 2.0),) * len(anchor_sizes)
anchor_generator = AnchorGenerator(sizes=anchor_sizes, aspect_ratios=aspect_ratios)

# Model
model = RetinaNet(backbone, num_classes=1, anchor_generator=anchor_generator)

def __getitem__():
    # Stand-in for Dataset.__getitem__: a random image with 20 identical boxes
    img = torch.rand(3, 256, 256)
    bboxes = torch.FloatTensor([[15, 15, 20, 20]] * 20)  # (x1, y1, x2, y2)
    labels = torch.LongTensor(np.ones(len(bboxes), dtype=int))
    targets = {'boxes': bboxes, 'labels': labels}
    return img, targets

def retinanet_collate_fn(batch_size=2):
    # RetinaNet takes a list of [3, H, W] image tensors and a list of target dicts
    img_batch = []
    targets_batch = []
    for _ in range(batch_size):
        img, targets = __getitem__()
        img_batch.append(img)
        targets_batch.append(targets)
    return img_batch, targets_batch

img_batch, targets_batch = retinanet_collate_fn(batch_size=2)

# The model is in training mode by default; this call raises the IndexError
outputs = model(img_batch, targets_batch)
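For comparison, the same batch is often built with torch.utils.data.Dataset and a DataLoader whose collate_fn keeps images and targets as lists, similar to the tuple(zip(*batch)) collate used in the torchvision detection reference scripts. This is a minimal sketch that only builds the batch. `RandomBoxDataset` and `detection_collate` are hypothetical names, and it does not construct or fix the model.

```python
import torch
from torch.utils.data import DataLoader, Dataset


class RandomBoxDataset(Dataset):
    """Hypothetical dataset: random 3x256x256 images, each with the same
    boxes in (x1, y1, x2, y2) pixel coordinates, all labelled 1."""

    def __init__(self, length=4, num_boxes=20):
        self.length = length
        self.num_boxes = num_boxes

    def __len__(self):
        return self.length

    def __getitem__(self, idx):
        img = torch.rand(3, 256, 256)
        boxes = torch.tensor([[15.0, 15.0, 20.0, 20.0]] * self.num_boxes)
        labels = torch.ones(self.num_boxes, dtype=torch.int64)
        return img, {"boxes": boxes, "labels": labels}


def detection_collate(batch):
    # Keep lists instead of stacking: images may differ in size and each
    # image can have a different number of boxes.
    images, targets = zip(*batch)
    return list(images), list(targets)


loader = DataLoader(RandomBoxDataset(), batch_size=2, collate_fn=detection_collate)
img_batch, targets_batch = next(iter(loader))
```

The resulting img_batch and targets_batch have the same shape as the ones produced by retinanet_collate_fn, so they can be passed to model(img_batch, targets_batch) in the same way.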

I think the error comes from your targets containing the label 1, which means the model needs two classes:

labels = torch.LongTensor(np.ones(len(bboxes), dtype=int))

while the model is initialized with num_classes=1, so the only valid label is 0. Change it to num_classes=2 and it should work.
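A mismatch like this can be caught before training with a clearer message. The RetinaNet docstring describes num_classes as including the background, so a label of 1 needs num_classes of at least 2. The following is a hedged sketch of such a check. `check_detection_targets` is a hypothetical helper, not part of torchvision. The box and label rules it enforces follow the target format described in the torchvision detection documentation: float boxes [N, 4] as (x1, y1, x2, y2) and int64 labels.

```python
import torch


def check_detection_targets(targets, num_classes):
    """Hypothetical pre-training check for a list of detection target dicts."""
    for i, t in enumerate(targets):
        boxes, labels = t["boxes"], t["labels"]
        if boxes.ndim != 2 or boxes.shape[1] != 4 or not boxes.is_floating_point():
            raise ValueError(f"target {i}: boxes must be a float tensor of shape [N, 4]")
        if not ((boxes[:, 2] > boxes[:, 0]) & (boxes[:, 3] > boxes[:, 1])).all():
            raise ValueError(f"target {i}: boxes must satisfy x1 < x2 and y1 < y2")
        if labels.dtype != torch.int64 or labels.shape[0] != boxes.shape[0]:
            raise ValueError(f"target {i}: labels must be int64 with one entry per box")
        if labels.numel() == 0:
            continue
        if int(labels.min()) < 0:
            raise ValueError(f"target {i}: labels must be non-negative")
        max_label = int(labels.max())
        if max_label >= num_classes:
            # Labels index the class columns directly, so they must be < num_classes.
            raise ValueError(
                f"target {i}: label {max_label} needs num_classes >= {max_label + 1}, "
                f"got num_classes={num_classes}"
            )


targets = [{"boxes": torch.tensor([[15.0, 15.0, 20.0, 20.0]]), "labels": torch.tensor([1])}]
check_detection_targets(targets, num_classes=2)  # passes
# check_detection_targets(targets, num_classes=1) would raise ValueError
```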

You are absolutely right. I assumed the background would have been class 0. That makes sense. Thank you.