I am using the ETCI torchgeo Dataset for an instance segmentation problem.
I created a custom dataset for it following this tutorial: TorchVision Object Detection Finetuning Tutorial — PyTorch Tutorials 2.5.0+cu124 documentation
Here is minimally reproducible code:
from torch.utils.data import DataLoader
from torchgeo import datasets
import os
import torch
from torchvision.ops.boxes import masks_to_boxes
from torchvision import tv_tensors
from torchvision.transforms.v2 import functional as F
class ETCICustomDataset(torch.utils.data.Dataset):
def __init__(self, root:str, split:str, transforms, mask_idx, download:bool=False, checksum:bool=False):
# load all image files, sorting them to
# ensure that they are aligned
self.dataset = datasets.ETCI2021(root=root, split=split,
transforms=transforms, download=download, checksum=checksum)
self.mask_idx = mask_idx
def __getitem__(self, idx):
# load images and masks
img = self.dataset[idx]['image']
mask = self.dataset[idx]['mask'][self.mask_idx]
# instances are encoded as different colors
obj_ids = torch.unique(mask)
# first id is the background, so remove it
obj_ids = obj_ids[1:]
num_objs = len(obj_ids)
# split the color-encoded mask into a set
# of binary masks
masks = (mask == obj_ids[:, None, None]).to(dtype=torch.uint8)
# get bounding box coordinates for each mask
boxes = masks_to_boxes(masks)
# there is only one class
labels = torch.ones((num_objs,), dtype=torch.int64)
image_id = idx
area = (boxes[:, 3] - boxes[:, 1]) * (boxes[:, 2] - boxes[:, 0])
# suppose all instances are not crowd
iscrowd = torch.zeros((num_objs,), dtype=torch.int64)
target = {}
target["boxes"] = tv_tensors.BoundingBoxes(boxes, format="XYXY", canvas_size=F.get_size(img))
target["masks"] = tv_tensors.Mask(masks)
target["labels"] = labels
target["image_id"] = image_id
target["area"] = area
target["iscrowd"] = iscrowd
return img, target
def __len__(self):
return len(self.dataset)
def collate_fn(data):
img, target = data
zipped = zip(img, target)
return list(zipped)
def collate_fn2(data):
# want to stack the images and then combine the dicts
imgs = [x[0] for x in data]
images = torch.stack(imgs)
target_dict ={}
dicts = [x[1] for x in data]
for key in dicts.keys():
target_dict[key] = 0
return (images, )
dataset = ETCICustomDataset(root='data/ETCI', split='train', mask_idx=0,
transforms=None, download=False, checksum=False)
ldr = DataLoader(dataset, 8, False, num_workers=0)
i = 0
for sample in ldr:
print(i)
i+=8
The for loop at the bottom is just to try to find which image/sample was failing but it changes based on the batch size.
Following DataLoader gives:stack expects each tensor to be equal size,due to different image has different objects number and RuntimeError: stack expects each tensor to be equal size, but got [3, 224, 224] at entry 0 and [3, 224, 336] at entry 3
I tried to make custom collate functions but, the default one has the correct behavior it just falls occasionally whereas my custom ones do not create the type/shape of the data I want.
Can anyone help me create a collate fn that works like the default one but doesn’t fail with this bug or perhaps help me understand where this bug comes from?
I hypothesize that it happens when the collate fn tries to stack the bbox in the target dict and the bbox is empty so its shape is [0,4] and then the other one is [1,4] but I am not certain.