Training Mask R-CNN yields 'RuntimeError: stack expects each tensor to be equal size'

I am trying to train the maskrcnn_resnet50_fpn model on a custom dataset with only two classes (including background), following the ‘TorchVision Instance Segmentation Finetuning’ tutorial (Google Colab). I get the error RuntimeError: stack expects each tensor to be equal size, but got [11, 1024, 1024] at entry 0 and [47, 1024, 1024] at entry 1. This seems to say that the mask tensors of two different images must have the same size, i.e. that all images must contain the same number of instances, which makes no sense.
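
Indeed, the message is easy to reproduce in isolation, since torch.stack only accepts tensors of identical shape (a standalone check, not from the tutorial):

import torch

a = torch.zeros(11, 1024, 1024)
b = torch.zeros(47, 1024, 1024)
# raises RuntimeError: stack expects each tensor to be equal size,
# but got [11, 1024, 1024] at entry 0 and [47, 1024, 1024] at entry 1
torch.stack([a, b])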

I’ve written the following minimal example, where all 50 images and their corresponding masks contain only zeros (i.e. are completely black). Image number i has i masks, and every bounding box covers the whole image.
The engine module comes from vision/references/detection in the pytorch/vision repository on GitHub.

import torch
from torchvision.models.detection import maskrcnn_resnet50_fpn
from torch.utils.data import DataLoader
from engine import train_one_epoch

class WeirdDataset(torch.utils.data.Dataset):
    def __len__(self):
        return 50

    def __getitem__(self, idx):
        # a completely black single-channel image
        img = torch.zeros(1, 1024, 1024)

        # image number idx gets idx all-zero masks, each with a
        # box covering the whole image and a foreground label
        target = {}
        bbox = torch.tensor([1, 1, 1024, 1024], dtype=torch.float)
        target['masks'] = torch.zeros(idx, 1024, 1024, dtype=torch.uint8)
        target['boxes'] = bbox.repeat(idx, 1)
        target['labels'] = torch.ones(idx, dtype=torch.int64)

        return img, target

model = maskrcnn_resnet50_fpn(num_classes=2)
model.train()
train_dataset = WeirdDataset()

train_data_loader = DataLoader(train_dataset, batch_size=2, shuffle=True)
learning_rate = 0.01
optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate)

for epoch in range(5):
    train_one_epoch(model, optimizer, train_data_loader, device='cpu', epoch=epoch, print_freq=1)

If you have two tensors of size [11, 1024, 1024] and [47, 1024, 1024], torch.stack cannot combine them; to merge them along the instance dimension you’ll want torch.cat instead:

a = torch.randn(11, 1024, 1024)
b = torch.randn(47, 1024, 1024)
c = torch.cat((a, b), dim=0)  # concatenate along the instance dimension
print(c.shape)  # torch.Size([58, 1024, 1024])
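
That said, in your training script the stack isn’t something you call yourself: the DataLoader’s default collate function tries to torch.stack the per-image targets of a batch, which fails as soon as two images in the batch have different numbers of instances. The finetuning tutorial you’re copying from sidesteps this by passing a custom collate_fn that batches samples as tuples instead of stacking them. A minimal sketch, assuming the utils module from vision/references/detection sits next to engine:

from torch.utils.data import DataLoader
import utils  # vision/references/detection/utils.py

# utils.collate_fn is essentially tuple(zip(*batch)), so a batch
# becomes (images, targets) tuples and nothing gets stacked
train_data_loader = DataLoader(train_dataset, batch_size=2, shuffle=True,
                               collate_fn=utils.collate_fn)

With that in place the masks, boxes, and labels of each image stay separate, so images with 11 and 47 instances can share a batch.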