Why doesn't my dataset work with batches bigger than 1?

I created a dataset for a very small set of images (408 images; download link is here).
It contains a CSV file with the image file names and labels, and this is the class I made:

import torch
from torchvision import transforms
import matplotlib.pyplot as plt
# we use csv for reading the csv file
import csv
# we use PIL.Image for reading an image
import PIL.Image as Image
import os

class AnimeMTLDataset(torch.utils.data.Dataset):
    def __init__(self, image_folder, csv_file_path, transformations, is_training_set = True) :
        super().__init__()
        
        self.path = csv_file_path 
        self.transforms = transformations
        self.is_training_set = is_training_set
        self.image_folder = image_folder
        self.length = -1

        if self.is_training_set:
            # read the csv file into a dictionary
            with open(csv_file_path, 'r') as csv_file : 
                csv_reader = csv.reader(csv_file)
                # to skip header we simply do 
                next(csv_reader)

                self.dataset = {}
                for i, line in enumerate(csv_reader):
                    self.dataset[i] = line
            self.length = len( self.dataset)
        else:
            self.image_folder = os.path.join(self.image_folder,'test')
            self.length = len(os.listdir(self.image_folder))

    def _format_input(self, input_str, one_hot=False):
        one_hot_tensor = torch.tensor([float(i) for i in input_str])
        if one_hot: 
            return one_hot_tensor 
        if one_hot_tensor.size(0) > 1 : 
            return torch.argmax(one_hot_tensor)
        else:
            return one_hot_tensor[0].int()
        
    # let's create the corresponding labels for each category
    def _parse_labels(self, input_str):
        # white,red,green,black,blue,purple,gold,silver
        colors = self._format_input(input_str[4:11], True)            
        # gender_Female,gender_Male
        genders = self._format_input(input_str[12:13])        
        # region_Asia,region_Egypt, region_Europe, region_Middle East  
        regions = self._format_input(input_str[14:17])        
        # fighting_type_magic, fighting_type_melee, fighting_type_ranged
        fighting_styles = self._format_input(input_str[18:20])          
        # alignment_CE, alignment_CG, alignment_CN, alignment_LE,
        # alignment_LG, alignment_LN, alignment_NE, alignment_NG, alignment_TN
        alignments = self._format_input(input_str[21:])  
        return colors, genders, regions, fighting_styles, alignments


    def __getitem__(self, index):
        if self.is_training_set:
            img_path = self.dataset[index][1]
            labels = self._parse_labels(self.dataset[index])
            # image files must be read as bytes so we use 'rb' instead of simply 'r' 
            # which is used for text files
            with open(os.path.join(self.image_folder, img_path), 'rb') as img_file:
                # since our datasets include png images, we need to make sure
                # we read only 3 channels and not more!
                img = Image.open(img_file).convert('RGB')
                print(img_path)
                # apply the transformations 
                img = self.transforms(img)
                print(img.shape)
                return img, labels
        else:
            for img_path in os.listdir(self.image_folder):
                with open(os.path.join(self.image_folder, img_path), 'rb') as img_file:
                    img = Image.open(img_file).convert('RGB')
                    # apply the transformations 
                    img = self.transforms(img)
                    return img, None

    def __len__(self):
        return self.length

transformations = transforms.Compose([transforms.Resize(224),
                                      transforms.RandomHorizontalFlip(),
                                      transforms.ToTensor(),
                                      transforms.Normalize([0.485, 0.456, 0.406],
                                                           [0.229, 0.224, 0.225])
                                      ])
anime_dataset = AnimeMTLDataset(image_folder = 'mtl_dataset',
                                csv_file_path = r'mtl_dataset\fgo_multiclass_labels.csv',
                                transformations=transformations)
# let's test our dataset class and see if it works OK:
#unnormalize
def unnormalize(img):
    img = img.detach().numpy().transpose(1,2,0)
    return img * [0.229, 0.224, 0.225] + [0.485, 0.456, 0.406] 

#training: 
print('dataset size: {}'.format(len(anime_dataset)))
img, labels = anime_dataset[0]
plt.imshow(unnormalize(img))

This works. But when I try to use torch.utils.data.SubsetRandomSampler() to create a validation set as well, or even a plain simple DataLoader with no sampler, it fails with this error message:

---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
<ipython-input> in <module>
     21 
     22 # test
---> 23 imgs, labels = next(iter(dataloader_train))
     24 print(imgs[0].shape)
     25 plt.imshow(unnormalize(imgs[0]))

~\Anaconda3\lib\site-packages\torch\utils\data\dataloader.py in __next__(self)
    558         if self.num_workers == 0:  # same-process loading
    559             indices = next(self.sample_iter)  # may raise StopIteration
--> 560             batch = self.collate_fn([self.dataset[i] for i in indices])
    561             if self.pin_memory:
    562                 batch = _utils.pin_memory.pin_memory_batch(batch)

~\Anaconda3\lib\site-packages\torch\utils\data\_utils\collate.py in default_collate(batch)
     66     elif isinstance(batch[0], container_abcs.Sequence):
     67         transposed = zip(*batch)
---> 68         return [default_collate(samples) for samples in transposed]
     69 
     70     raise TypeError((error_msg_fmt.format(type(batch[0]))))

~\Anaconda3\lib\site-packages\torch\utils\data\_utils\collate.py in <listcomp>(.0)
     66     elif isinstance(batch[0], container_abcs.Sequence):
     67         transposed = zip(*batch)
---> 68         return [default_collate(samples) for samples in transposed]
     69 
     70     raise TypeError((error_msg_fmt.format(type(batch[0]))))

~\Anaconda3\lib\site-packages\torch\utils\data\_utils\collate.py in default_collate(batch)
     41             storage = batch[0].storage()._new_shared(numel)
     42             out = batch[0].new(storage)
---> 43         return torch.stack(batch, 0, out=out)
     44     elif elem_type.__module__ == 'numpy' and elem_type.__name__ != 'str_' \
     45             and elem_type.__name__ != 'string_':

RuntimeError: invalid argument 0: Sizes of tensors must match except in dimension 0. Got 224 and 316 in dimension 2 at ..\aten\src\TH/generic/THTensor.cpp:711

So the following snippet fails and causes the previous error message:

# let's create a validation and training set
import numpy as np
import torch.utils.data as data

samples_count = len(anime_dataset)
all_samples_indexes = list(range(samples_count))
np.random.shuffle(all_samples_indexes)

val_ratio = 0.2
val_end = int(samples_count * val_ratio)
val_indexes = all_samples_indexes[0:val_end]
train_indexes = all_samples_indexes[val_end:]
assert len(val_indexes) + len(train_indexes) == samples_count , 'the split is not valid' 

sampler_train = data.SubsetRandomSampler(train_indexes)
sampler_val = data.SubsetRandomSampler(val_indexes)

dataloader_train = data.DataLoader(anime_dataset, batch_size = 32, sampler = sampler_train)
dataloader_val = data.DataLoader(anime_dataset, batch_size = 32, sampler = sampler_val)

# test 
imgs, labels = next(iter(dataloader_train))
print(imgs[0].shape)
plt.imshow(unnormalize(imgs[0]))

What is wrong and what am I missing?
Thank you all in advance

Did you pass the transformations to your custom Dataset in the second example?

Yes, this is actually the whole code so far:

import torch
import torch.nn as nn 
import torch.nn.functional as F 
from torch import optim 
from torchvision import datasets, transforms, models
import matplotlib.pyplot as plt 
%matplotlib inline
# we use csv for reading csv file
import csv
# we use PIL.Image for reading an image
import PIL.Image as Image
import os

class AnimeMTLDataset(torch.utils.data.Dataset):
    def __init__(self, image_folder, csv_file_path, transformations, is_training_set = True) :
        super().__init__()
        
        self.path = csv_file_path 
        self.transforms = transformations
        self.is_training_set = is_training_set
        self.image_folder = image_folder
        self.length = -1

        if self.is_training_set:
            # read the csv file into a dictionary
            with open(csv_file_path, 'r') as csv_file : 
                csv_reader = csv.reader(csv_file)
                # to skip header we simply do 
                next(csv_reader)

                self.dataset = {}
                for i, line in enumerate(csv_reader):
                    self.dataset[i] = line
            self.length = len( self.dataset)
        else:
            self.image_folder = os.path.join(self.image_folder,'test')
            self.length = len(os.listdir(self.image_folder))

    def _format_input(self, input_str, one_hot=False):
        one_hot_tensor = torch.tensor([float(i) for i in input_str])
        if one_hot: 
            return one_hot_tensor 
        if one_hot_tensor.size(0) > 1 : 
            return torch.argmax(one_hot_tensor)
        else:
            return one_hot_tensor[0].int()
        
    # let's create the corresponding labels for each category
    def _parse_labels(self, input_str):
        # white,red,green,black,blue,purple,gold,silver
        colors = self._format_input(input_str[4:11], True)            
        # gender_Female,gender_Male
        genders = self._format_input(input_str[12:13])        
        # region_Asia,region_Egypt, region_Europe, region_Middle East  
        regions = self._format_input(input_str[14:17])        
        # fighting_type_magic, fighting_type_melee, fighting_type_ranged
        fighting_styles = self._format_input(input_str[18:20])          
        # alignment_CE, alignment_CG, alignment_CN, alignment_LE,
        # alignment_LG, alignment_LN, alignment_NE, alignment_NG, alignment_TN
        alignments = self._format_input(input_str[21:])  
        return colors, genders, regions, fighting_styles, alignments


    def __getitem__(self, index):
        if self.is_training_set:
            img_path = self.dataset[index][1]
            labels = self._parse_labels(self.dataset[index])
            # image files must be read as bytes so we use 'rb' instead of simply 'r' 
            # which is used for text files
            with open(os.path.join(self.image_folder, img_path), 'rb') as img_file:
                # since our datasets include png images, we need to make sure
                # we read only 3 channels and not more!
                img = Image.open(img_file).convert('RGB')
                # apply the transformations 
                img = self.transforms(img)
                return img, labels
        else:
            for img_path in os.listdir(self.image_folder):
                with open(os.path.join(self.image_folder, img_path), 'rb') as img_file:
                    img = Image.open(img_file).convert('RGB')
                    # apply the transformations 
                    img = self.transforms(img)
                    return img, None

    def __len__(self):
        return self.length

transformations = transforms.Compose([transforms.Resize(224),
                                      transforms.RandomHorizontalFlip(),
                                      transforms.ToTensor(),
                                      transforms.Normalize([0.485, 0.456, 0.406],
                                                           [0.229, 0.224, 0.225])
                                      ])
anime_dataset = AnimeMTLDataset(image_folder = 'mtl_dataset',
                                csv_file_path = r'mtl_dataset\fgo_multiclass_labels.csv',
                                transformations=transformations)
# let's test our dataset class and see if it works OK:
#unnormalize
def unnormalize(img):
    img = img.detach().numpy().transpose(1,2,0)
    return img * [0.229, 0.224, 0.225] + [0.485, 0.456, 0.406] 

#training: 
print('dataset size: {}'.format(len(anime_dataset)))
img, labels = anime_dataset[0]
plt.imshow(unnormalize(img))
#%%
anime_dataset_test = AnimeMTLDataset(image_folder = 'mtl_dataset',
                                csv_file_path = r'mtl_dataset\fgo_multiclass_labels.csv',
                                transformations=transformations, 
                                is_training_set =False)

print('Test dataset test : ')
print('dataset size: {}'.format(len(anime_dataset_test)))
img, _ = anime_dataset_test[0]
plt.imshow(unnormalize(img))
#%%
# now let's create a dataloader and carry on!
# let's create a validation and training set
import numpy as np
import torch.utils.data as data

samples_count = len(anime_dataset)
all_samples_indexes = list(range(samples_count))
np.random.shuffle(all_samples_indexes)

val_ratio = 0.2
val_end = int(samples_count * val_ratio)
val_indexes = all_samples_indexes[0:val_end]
train_indexes = all_samples_indexes[val_end:]
assert len(val_indexes) + len(train_indexes) == samples_count , 'the split is not valid' 

sampler_train = data.SubsetRandomSampler(train_indexes)
sampler_val = data.SubsetRandomSampler(val_indexes)

dataloader_train = data.DataLoader(anime_dataset, batch_size = 32, sampler = sampler_train)
dataloader_val = data.DataLoader(anime_dataset, batch_size = 32, sampler = sampler_val)

dataloader_train2 = data.DataLoader(anime_dataset, batch_size = 32)

# test 
imgs, labels = next(iter(dataloader_train2))
print(imgs[0].shape)
plt.imshow(unnormalize(imgs[0]))

Could you specify the size in Resize as a tuple?

transforms.Resize((224, 224))

A single value works differently if your images are not square.
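
For reference, your Compose would then look like this (only the Resize line changes, everything else stays as in your snippet):

transformations = transforms.Compose([transforms.Resize((224, 224)),  # force both edges to 224
                                      transforms.RandomHorizontalFlip(),
                                      transforms.ToTensor(),
                                      transforms.Normalize([0.485, 0.456, 0.406],
                                                           [0.229, 0.224, 0.225])
                                      ])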


Thank you very much! That solved the issue!
But would you mind explaining a bit about what went wrong in the first case? And what is my other option besides resizing everything into a square?

If you pass a single value, only the smaller edge will be matched to this value and the other one will be resized accordingly to keep the aspect ratio.
From the docs:

size (sequence or int) – Desired output size. If size is a sequence like (h, w), output size will be matched to this. If size is an int, smaller edge of the image will be matched to this number. i.e, if height > width, then image will be rescaled to (size * height / width, size)
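
To illustrate the difference, here is a small sketch with a dummy non-square image (300x600 is just a made-up size):

from PIL import Image
from torchvision import transforms

img = Image.new('RGB', (300, 600))  # dummy image, width=300, height=600

# single int: the smaller edge (300) is matched to 224, the other edge keeps the ratio
print(transforms.Resize(224)(img).size)         # (224, 448)

# tuple: both edges are forced to 224, regardless of the original ratio
print(transforms.Resize((224, 224))(img).size)  # (224, 224)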


Thank you very much, but I meant: how is that relevant to the error I get? Why would it simply crash?
It's not like I'm convolving or running any kind of operation except a couple of seemingly harmless transformations! This is strange to me.

The default collate function calls torch.stack on the image tensors of a batch, which fails if they don't all have the same shape.
Basically, this is crashing:

x = [torch.randn(3, 224, 224), torch.randn(3, 224, 300)]
torch.stack(x)
> RuntimeError: invalid argument 0: Sizes of tensors must match except in dimension 0. Got 224 and 300 in dimension 3 at /opt/conda/conda-bld/pytorch_1565272271120/work/aten/src/TH/generic/THTensor.cpp:689

while this will work (as all image tensors now have the same size):

x = [torch.randn(3, 224, 224), torch.randn(3, 224, 224)]
torch.stack(x) # shape = [2, 3, 224, 224]
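
Regarding your other question about alternatives to the square resize: one option would be to keep the aspect-ratio-preserving Resize and pass a custom collate_fn to the DataLoader, which pads the images of each batch to a common size before stacking them. This is only a rough sketch (pad_collate is a made-up helper name, and the labels are returned uncollated here):

import torch
import torch.nn.functional as F

def pad_collate(batch):
    # batch is a list of (img, labels) tuples where the images may differ in H and W
    imgs, labels = zip(*batch)
    max_h = max(img.shape[1] for img in imgs)
    max_w = max(img.shape[2] for img in imgs)
    # zero-pad every image on the right/bottom up to the largest size in the batch
    padded = [F.pad(img, (0, max_w - img.shape[2], 0, max_h - img.shape[1])) for img in imgs]
    return torch.stack(padded), labels

dataloader_train = data.DataLoader(anime_dataset, batch_size=32,
                                   sampler=sampler_train, collate_fn=pad_collate)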

Thank you very very much. I really appreciate your kind reply.
Have a wonderful day, sir :)
