DataLoader collate_fn throws `RuntimeError: stack expects each tensor to be equal size` with a variable number of bounding boxes

I am using torchvision.models.detection.retinanet_resnet50_fpn, which expects the targets to be a list of dictionaries, each containing:

  1. boxes (FloatTensor[N, 4]): the ground-truth boxes in [x1, y1, x2, y2] format, with 0 <= x1 < x2 <= W and 0 <= y1 < y2 <= H.
  2. labels (Int64Tensor[N]): the class label for each ground-truth box

So normally each dict, which belongs to a single training sample, would contain a different number of boxes and labels.

In my custom Dataset class:

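import torch
from pathlib import Path
from PIL import Image
from torch.utils.data import Dataset

class COCODataset(Dataset):
    def __init__(self, root, annFile, transforms):
        self.root = Path(root)
        self.annFile = Path(annFile)
        self.transforms = transforms
        # self.imgs and self.targets are assumed to be built from annFile (not shown)

    def __len__(self):
        return len(self.imgs)

    def __getitem__(self, idx):
        img_path = self.root / self.imgs[idx]
        # Load the image and force three channels
        img = Image.open(img_path).convert("RGB")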
        target = self.targets[idx]

        target["boxes"] = torch.as_tensor(target["boxes"], dtype=torch.float32)
        target["labels"] = torch.as_tensor(target["labels"], dtype=torch.int64)

        if self.transforms is not None:
            img = self.transforms(img)

        return img, target

When trying to retrieve a sample directly from the custom dataset it works fine and returns x and y as:

(tensor([[[0.5977, 0.6082, 0.6157,  ..., 0.5798, 0.5562, 0.5738],
          [0.4960, 0.6187, 0.6154,  ..., 0.9062, 0.9111, 0.9098]]]),

 {'boxes': tensor([[250.8200, 168.2600, 320.9300, 233.1400],
          [285.5500, 370.5600, 297.6200, 389.7700]]),
  'labels': tensor([1, 1])})

However, when I wrap the dataset in a DataLoader, it throws RuntimeError: stack expects each tensor to be equal size, but got [10, 4] at entry 0 and [3, 4] at entry 1. Any idea how to deal with this?

In this official tutorial, I believe they also use a variable number of bounding boxes when building the custom dataset, so it must be possible.

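This issue was solved by passing a custom collate_fn to the DataLoader. The default collate function tries to torch.stack the "boxes" tensors of all samples in a batch, which fails when the samples have different numbers of boxes.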
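
The exact collate_fn is not reproduced in this post, so the following is only a minimal sketch of a pattern commonly used in detection training scripts; treat it as an assumption about what was meant. It regroups the batch into a tuple of images and a tuple of targets without stacking anything. The stand-in batch below uses strings and plain lists in place of real image and box tensors, purely so the sketch runs without a dataset.

```python
def collate_fn(batch):
    """Turn a list of (image, target) pairs into (images, targets) tuples.

    Nothing is stacked, so each target may hold a different number of boxes.
    """
    return tuple(zip(*batch))


# Hypothetical stand-in batch: the first sample has 2 boxes, the second has 1.
# In real use the images and the "boxes"/"labels" values would be tensors.
batch = [
    ("image_0", {"boxes": [[250.82, 168.26, 320.93, 233.14],
                           [285.55, 370.56, 297.62, 389.77]],
                 "labels": [1, 1]}),
    ("image_1", {"boxes": [[10.0, 20.0, 30.0, 40.0]],
                 "labels": [1]}),
]

images, targets = collate_fn(batch)
print(len(images), [len(t["boxes"]) for t in targets])
```

With the real dataset, the function would be passed as DataLoader(dataset, batch_size=2, collate_fn=collate_fn). Each batch is then a pair of tuples, which can be converted with list(images) and list(targets) before being handed to the detection model.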