I created a dataset for a very small set of images (408 images download link is here).
it contains a csv file that has image file names and labels. and this is the class I made :
# we use csv for reading csv file
import csv
# we use PIL.Image for reading an image
import PIL.Image as Image
import os
class AnimeMTLDataset(torch.utils.data.Dataset):
def __init__(self, image_folder, csv_file_path, transformations, is_training_set = True) :
super().__init__()
self.path = csv_file_path
self.transforms = transformations
self.is_training_set = is_training_set
self.image_folder = image_folder
self.length = -1
if self.is_training_set:
# read the csv file into a dictionary
with open(csv_file_path, 'r') as csv_file :
csv_reader = csv.reader(csv_file)
# to skip header we simply do
next(csv_reader)
self.dataset = {}
for i, line in enumerate(csv_reader):
self.dataset[i] = line
self.length = len( self.dataset)
else:
self.image_folder = os.path.join(self.image_folder,'test')
self.length = len(os.listdir(self.image_folder))
def _format_input(self, input_str, one_hot=False):
one_hot_tensor = torch.tensor([float(i) for i in input_str])
if one_hot:
return one_hot_tensor
if one_hot_tensor.size(0) > 1 :
return torch.argmax(one_hot_tensor)
else:
return one_hot_tensor[0].int()
# lets create the corsponding labels for each category
def _parse_labels(self, input_str):
# white,red,green,black,blue,purple,gold,silver
colors = self._format_input(input_str[4:11], True)
# gender_Female,gender_Male
genders = self._format_input(input_str[12:13])
# region_Asia,region_Egypt, region_Europe, region_Middle East
regions = self._format_input(input_str[14:17])
# fighting_type_magic, fighting_type_melee, fighting_type_ranged
fighting_styles = self._format_input(input_str[18:20])
# alignment_CE, alignment_CG, alignment_CN, alignment_LE,
# alignment_LG, alignment_LN, alignment_NE, alignment_NG, alignment_TN
alignments = self._format_input(input_str[21:])
return colors, genders, regions, fighting_styles, alignments
def __getitem__(self, index):
if self.is_training_set:
img_path = self.dataset[index][1]
labels = self._parse_labels(self.dataset[index])
# image files must be read as bytes so we use 'rb' instead of simply 'r'
# which is used for text files
with open(os.path.join(self.image_folder, img_path), 'rb') as img_file:
# since our datasets include png images, we need to make sure
# we read only 3 channels and not more!
img = Image.open(img_file).convert('RGB')
print(img_path)
# apply the transformations
img = self.transforms(img)
print(img.shape)
return img, labels
else:
for img_path in os.listdir(self.image_folder):
with open(os.path.join(self.image_folder, img_path), 'rb') as img_file:
img = Image.open(img_file).convert('RGB')
# apply the transformations
img = self.transforms(img)
return img, None
def __len__(self):
return self.length
transformations = transforms.Compose([transforms.Resize(224),
transforms.RandomHorizontalFlip(),
transforms.ToTensor(),
transforms.Normalize([0.485, 0.456, 0.406],
[0.229, 0.224, 0.225])
])
anime_dataset = AnimeMTLDataset(image_folder = 'mtl_dataset',
csv_file_path = r'mtl_dataset\fgo_multiclass_labels.csv',
transformations=transformations)
# lets test our dataset class and see if it works ok:
#unnormalize
def unnormalize(img):
img = img.detach().numpy().transpose(1,2,0)
return img * [0.229, 0.224, 0.225] + [0.485, 0.456, 0.406]
#training:
print('dataset size: {}'.format(len(anime_dataset)))
img, labels = anime_dataset[0]
plt.imshow(unnormalize(img))
this works. but when I try to use torch.utils.data.SubsetRandomSampler()
to create a validation set as well, or even a plain simple dataloader with no sampler, it fails with the error message :
---------------------------------------------------------------------------
RuntimeError Traceback (most recent call last)
in
21
22 # test
---> 23 imgs, labels = next(iter(dataloader_train))
24 print(imgs[0].shape)
25 plt.imshow(unnormalize(imgs[0]))
~\Anaconda3\lib\site-packages\torch\utils\data\dataloader.py in __next__(self)
558 if self.num_workers == 0: # same-process loading
559 indices = next(self.sample_iter) # may raise StopIteration
--> 560 batch = self.collate_fn([self.dataset[i] for i in indices])
561 if self.pin_memory:
562 batch = _utils.pin_memory.pin_memory_batch(batch)
~\Anaconda3\lib\site-packages\torch\utils\data\_utils\collate.py in default_collate(batch)
66 elif isinstance(batch[0], container_abcs.Sequence):
67 transposed = zip(*batch)
---> 68 return [default_collate(samples) for samples in transposed]
69
70 raise TypeError((error_msg_fmt.format(type(batch[0]))))
~\Anaconda3\lib\site-packages\torch\utils\data\_utils\collate.py in (.0)
66 elif isinstance(batch[0], container_abcs.Sequence):
67 transposed = zip(*batch)
---> 68 return [default_collate(samples) for samples in transposed]
69
70 raise TypeError((error_msg_fmt.format(type(batch[0]))))
~\Anaconda3\lib\site-packages\torch\utils\data\_utils\collate.py in default_collate(batch)
41 storage = batch[0].storage()._new_shared(numel)
42 out = batch[0].new(storage)
---> 43 return torch.stack(batch, 0, out=out)
44 elif elem_type.__module__ == 'numpy' and elem_type.__name__ != 'str_' \
45 and elem_type.__name__ != 'string_':
RuntimeError: invalid argument 0: Sizes of tensors must match except in dimension 0. Got 224 and 316 in dimension 2 at ..\aten\src\TH/generic/THTensor.cpp:711
So the following snippet fails and causes the previous error message :
# lets create a validation and training set
import numpy as np
import torch.utils.data as data
samples_count = len(anime_dataset)
all_samples_indexes = list(range(samples_count))
np.random.shuffle(all_samples_indexes)
val_ratio = 0.2
val_end = int(samples_count * 0.2)
val_indexes = all_samples_indexes[0:val_end]
train_indexes = all_samples_indexes[val_end:]
assert len(val_indexes) + len(train_indexes) == samples_count , 'the split is not valid'
sampler_train = data.SubsetRandomSampler(train_indexes)
sampler_val = data.SubsetRandomSampler(val_indexes)
dataloader_train = data.DataLoader(anime_dataset, batch_size = 32, sampler = sampler_train)
dataloader_val = data.DataLoader(anime_dataset, batch_size = 32, sampler = sampler_val)
# test
imgs, labels = next(iter(dataloader_train))
print(imgs[0].shape)
plt.imshow(unnormalize(imgs[0]))
What is wrong and what am I missing?
Thank you all in advance