A problem with torch.stack()

My DataLoader produces an error when I iterate over it.


The code for the Dataset I pass to the DataLoader is below:

import os

import cv2
import pandas as pd
import skimage.io
import torch
from torch.utils.data import Dataset

class PandaDataset(Dataset):
    def __init__(self, path, train, transform):
        self.path = path
        self.names = list(pd.read_csv(train).image_id)
        self.labels = list(pd.read_csv(train).isup_grade)
        self.transform = transform

    def __len__(self):
        return len(self.names)

    def __getitem__(self, idx):
        label = self.labels[idx]
        name = self.names[idx]
        # read the lowest-resolution level of the .tiff and convert BGR -> RGB
        img = skimage.io.MultiImage(os.path.join(self.path, name + '.tiff'))[-1]
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)

        list_of_images = []
        tiles = tile(img)
        for img in tiles:
            if x:  # x is a flag defined outside this snippet
                augmented = get_transforms(data='valid')(image=img)
                img = augmented['image']
            list_of_images.append(img)

        new_tiles = torch.stack(list_of_images, dim=0)

        return new_tiles, torch.tensor(label)
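
For reference, this is roughly how I construct and iterate it (a sketch: train_images and 'train.csv' are placeholders, not my actual paths):

from torch.utils.data import DataLoader

dataset = PandaDataset(path=train_images, train='train.csv',
                       transform=get_transforms(data='train'))
loader = DataLoader(dataset, batch_size=4)

# the error is raised while iterating over the loader
for tiles, labels in loader:
    print(tiles.shape, labels.shape)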

When I try to debug it with the code below, outside of the DataLoader, it works fine:

img = skimage.io.MultiImage(os.path.join(train_images, '0005f7aaab2800f6170c399693a96917' + '.tiff'))[-1]
img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)

list_of_images = []
tiles = tile(img)
for img in tiles:
    if x:
        augmented = get_transforms(data='train')(image=img)
        img = augmented['image']
    list_of_images.append(img)

new_tiles = torch.stack(list_of_images, dim=0)

The tile function simply crops an image into 12 tiles and appends them to a numpy array, roughly like the sketch below (simplified, not my exact code):
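
import numpy as np

# not the actual implementation -- just a minimal sketch of a tile() that
# pads an image and crops it into 12 fixed-size tiles in a numpy array
def tile(img, tile_size=128, n_tiles=12):
    h, w, c = img.shape
    # pad so the image divides evenly into tile_size x tile_size patches
    pad_h = (tile_size - h % tile_size) % tile_size
    pad_w = (tile_size - w % tile_size) % tile_size
    img = np.pad(img, [[0, pad_h], [0, pad_w], [0, 0]], constant_values=255)
    # cut the padded image into tiles and keep the first n_tiles of them
    tiles = (img.reshape(img.shape[0] // tile_size, tile_size,
                         img.shape[1] // tile_size, tile_size, c)
                .transpose(0, 2, 1, 3, 4)
                .reshape(-1, tile_size, tile_size, c))
    return tiles[:n_tiles]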
Thanks.

Could you print the shapes of all images inside list_of_images before calling torch.stack?
It seems that some dimension is creating a shape mismatch. Are you dealing with differently shaped images?
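
Something like this would do (printing the type as well, since torch.stack only accepts tensors):

for i, t in enumerate(list_of_images):
    print(i, type(t), t.shape)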

Sure.

All the images have the same dimensions and, as I said before, when I try to debug it outside of the DataLoader it works fine.

Since all images have the same shape, the torch.stack call should work:

list_of_images = [torch.randn(3, 128, 128) for _ in range(10)]
x = torch.stack(list_of_images, 0)
print(x.shape)
> torch.Size([10, 3, 128, 128])

Also, a Dataset with equal shapes works fine:

import torch
from torch.utils.data import Dataset, DataLoader

class PandaDataset(Dataset):
    def __init__(self):
        pass

    def __len__(self):
        return 20

    def __getitem__(self, idx):
        list_of_images = []
        tiles = torch.randn(10, 3, 128, 128)
        for img in tiles:
            list_of_images.append(img)
        new_tiles = torch.stack(list_of_images, dim=0)

        return new_tiles, torch.tensor(0)

dataset = PandaDataset()
loader = DataLoader(dataset, batch_size=5)

for data, target in loader:
    print(data.shape)

> torch.Size([5, 10, 3, 128, 128])
torch.Size([5, 10, 3, 128, 128])
torch.Size([5, 10, 3, 128, 128])
torch.Size([5, 10, 3, 128, 128])
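
If the shapes really are all equal, it might also be worth checking the type of each element: torch.stack only accepts tensors, so a single numpy array left in the list (e.g. if the transform branch is skipped) raises a TypeError even when every shape matches. That's just a guess, since the actual traceback wasn't posted:

import numpy as np
import torch

tensors = [torch.randn(3, 128, 128) for _ in range(3)]
tensors.append(np.random.randn(3, 128, 128))  # one numpy array sneaks in

try:
    torch.stack(tensors, dim=0)
except TypeError as e:
    # "expected Tensor as element 3 in argument 0, but got numpy.ndarray"
    print(e)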

Thanks, somehow I was able to fix it. :sweat_smile: