I created a custom Dataset class for loading 3d float arrays as input and .png segmentation map for targets:
class ADE20K_SIFT(data.Dataset):
"""ADE20K
input is a 3d matrix and target is a .png image
Arguments:
root (string): filepath to ADE20K root folder.
image_set (string): imageset to use (eg: 'training', 'validation', 'testing').
transform (callable, optional): transformation to perform on the
input image
target_transform (callable, optional): transformation to perform on the
target image
dataset_name (string, optional): which dataset to load
(default: 'ADEChallengeData2016')
"""
def __init__(self, root, image_set, transform=None, target_transform=None,
dataset_name='ADEChallengeData2016'):
self.root = root
self.image_set = image_set
self.transform = transform
self.target_transform = target_transform
if image_set == 'training_sift':
image_name = 'train'
anno_folder = 'training_re'
elif image_set == 'validation_sift':
image_name = 'val'
anno_folder = 'validation_re'
else:
raise ValueError('image_set should be either of "training_sift", "validation_sift".')
self._annopath = os.path.join(
self.root, dataset_name, 'annotations', anno_folder, 'ADE_'+image_name+'_re_{:08d}.png')
self._imgpath = os.path.join(
self.root, dataset_name, 'images', image_set, 'ADE_'+image_name+'_{:08d}.sift')
self._imgsetpath = os.path.join(
self.root, dataset_name, 'objectInfo150.txt')
self.dataset_dir = os.path.join(self.root, dataset_name, 'annotations', anno_folder)
with open(self._imgsetpath) as f:
self.class_desc = [line.split('\t')[-1].strip('\n') for line in f.readlines()]
def __getitem__(self, index):
try:
img_id = index
# print('index', index)
target = np.array(Image.open(self._annopath.format(img_id)))
# print (target.shape)
features = np.zeros((49, 49, 130))
if (os.stat(self._imgpath.format(img_id)).st_size > 0): # file is not empty
data = np.array(pd.read_csv(self._imgpath.format(img_id), sep=' ', header=None))
for i in range(data.shape[0]):
features[int(data[i][0]-1)][int(data[i][1]-1)] = data[i][2:]
if self.transform is not None:
features = self.transform(features)
if self.target_transform is not None:
target = self.target_transform(target)
except: # if any error, return zero arrays of correct dimensions
index = -1
features = np.zeros((49, 49, 130))
target = np.zeros((196, 196))
if self.transform is not None:
features = self.transform(features)
if self.target_transform is not None:
target = self.target_transform(target)
return features, target
def __len__(self):
return len(os.listdir(self.dataset_dir))
And loaded the dataset using the inbuilt loader.
# Data augmentation and normalization for training
# Just normalization for validation
input_transforms = transforms.Compose([
transforms.Lambda(lambda img: torch.from_numpy(np.transpose(img, (2, 0, 1))))
])
target_transforms = transforms.Compose([
transforms.Lambda(lambda img: torch.from_numpy(np.array(img)))
])
root = '/home/shivam/Downloads/ADE20K/'
image_set = 'validation_sift'
dsets = ade20k_dataset.ADE20K_SIFT(root, image_set, transform=input_transforms,
target_transform=target_transforms,
dataset_name='ADEChallengeData2016')
dset_loaders = torch.utils.data.DataLoader(dsets, batch_size=2, shuffle=True, num_workers=1)
Now when I’m trying to retrieve (input, target) pair using:
for i, data in enumerate(dset_loaders, 0):
Pytorch gives the following error after fetching some inputs. (Note: It works smoothly for batchsize = 1)
---------------------------------------------------------------------------
TypeError Traceback (most recent call last)
<ipython-input-32-30856e4f68c8> in <module>()
27 # print(ind, inputs.size(), outputs.size())
28
---> 29 for i, data in enumerate(dset_loaders, 0):
30 print(i, data[0].size())
31
/home/shivam/anaconda3/lib/python3.6/site-packages/torch/utils/data/dataloader.py in __next__(self)
172 self.reorder_dict[idx] = batch
173 continue
--> 174 return self._process_next_batch(batch)
175
176 next = __next__ # Python 2 compatibility
/home/shivam/anaconda3/lib/python3.6/site-packages/torch/utils/data/dataloader.py in _process_next_batch(self, batch)
196 self._put_indices()
197 if isinstance(batch, ExceptionWrapper):
--> 198 raise batch.exc_type(batch.exc_msg)
199 return batch
200
TypeError: Traceback (most recent call last):
File "/home/shivam/anaconda3/lib/python3.6/site-packages/torch/utils/data/dataloader.py", line 34, in _worker_loop
samples = collate_fn([dataset[i] for i in batch_indices])
File "/home/shivam/anaconda3/lib/python3.6/site-packages/torch/utils/data/dataloader.py", line 79, in default_collate
return [default_collate(samples) for samples in transposed]
File "/home/shivam/anaconda3/lib/python3.6/site-packages/torch/utils/data/dataloader.py", line 79, in <listcomp>
return [default_collate(samples) for samples in transposed]
File "/home/shivam/anaconda3/lib/python3.6/site-packages/torch/utils/data/dataloader.py", line 66, in default_collate
return torch.stack(batch, 0)
File "/home/shivam/anaconda3/lib/python3.6/site-packages/torch/functional.py", line 56, in stack
return torch.cat(list(t.unsqueeze(dim) for t in sequence), dim)
TypeError: cat received an invalid combination of arguments - got (list, int), but expected one of:
* (sequence[torch.DoubleTensor] tensors)
* (sequence[torch.DoubleTensor] tensors, int dim)
I’m unable to understand this. Also what is the use of collate_fn, and should I write a custom collate_fn to solve this problem?