[solved] Error while loading dataset with custom Dataset class with batchsize > 1

I created a custom Dataset class for loading 3d float arrays as input and .png segmentation map for targets:

class ADE20K_SIFT(data.Dataset):
input is a 3d matrix and target is a .png image
    root (string): filepath to ADE20K root folder.
    image_set (string): imageset to use (eg: 'training', 'validation', 'testing').
    transform (callable, optional): transformation to perform on the
        input image
    target_transform (callable, optional): transformation to perform on the
        target image
    dataset_name (string, optional): which dataset to load
        (default: 'ADEChallengeData2016')

def __init__(self, root, image_set, transform=None, target_transform=None,
    self.root = root
    self.image_set = image_set
    self.transform = transform
    self.target_transform = target_transform

    if image_set == 'training_sift':
        image_name = 'train'
        anno_folder = 'training_re'
    elif image_set == 'validation_sift':
        image_name = 'val'            
        anno_folder = 'validation_re'
        raise ValueError('image_set should be either of "training_sift", "validation_sift".')

    self._annopath = os.path.join(
        self.root, dataset_name, 'annotations', anno_folder, 'ADE_'+image_name+'_re_{:08d}.png')
    self._imgpath = os.path.join(
        self.root, dataset_name, 'images', image_set, 'ADE_'+image_name+'_{:08d}.sift')
    self._imgsetpath = os.path.join(
        self.root, dataset_name, 'objectInfo150.txt')
    self.dataset_dir = os.path.join(self.root, dataset_name, 'annotations', anno_folder)
    with open(self._imgsetpath) as f:
        self.class_desc = [line.split('\t')[-1].strip('\n') for line in f.readlines()]
def __getitem__(self, index):

        img_id = index
        # print('index', index)
        target = np.array(Image.open(self._annopath.format(img_id)))
        # print (target.shape)

        features = np.zeros((49, 49, 130))
        if (os.stat(self._imgpath.format(img_id)).st_size > 0): # file is not empty
            data = np.array(pd.read_csv(self._imgpath.format(img_id), sep=' ', header=None))
            for i in range(data.shape[0]):
                features[int(data[i][0]-1)][int(data[i][1]-1)] = data[i][2:]

        if self.transform is not None:
            features = self.transform(features)

        if self.target_transform is not None:
            target = self.target_transform(target)
    except: # if any error, return zero arrays of correct dimensions
        index = -1
        features = np.zeros((49, 49, 130))
        target = np.zeros((196, 196))
        if self.transform is not None:
            features = self.transform(features)

        if self.target_transform is not None:
            target = self.target_transform(target)

    return features, target

def __len__(self):
    return len(os.listdir(self.dataset_dir))

And loaded the dataset using the inbuilt loader.

# Data augmentation and normalization for training
# Just normalization for validation
input_transforms = transforms.Compose([
    transforms.Lambda(lambda img: torch.from_numpy(np.transpose(img, (2, 0, 1))))
target_transforms = transforms.Compose([
    transforms.Lambda(lambda img: torch.from_numpy(np.array(img)))

root = '/home/shivam/Downloads/ADE20K/'
image_set = 'validation_sift'
dsets = ade20k_dataset.ADE20K_SIFT(root, image_set, transform=input_transforms, 
dset_loaders = torch.utils.data.DataLoader(dsets, batch_size=2, shuffle=True, num_workers=1)

Now when I’m trying to retrieve (input, target) pair using:
for i, data in enumerate(dset_loaders, 0):

Pytorch gives the following error after fetching some inputs. (Note: It works smoothly for batchsize = 1)

TypeError                                 Traceback (most recent call last)
<ipython-input-32-30856e4f68c8> in <module>()
     27 #     print(ind, inputs.size(), outputs.size())
---> 29 for i, data in enumerate(dset_loaders, 0):
     30     print(i, data[0].size())

/home/shivam/anaconda3/lib/python3.6/site-packages/torch/utils/data/dataloader.py in __next__(self)
    172                 self.reorder_dict[idx] = batch
    173                 continue
--> 174             return self._process_next_batch(batch)
    176     next = __next__  # Python 2 compatibility

/home/shivam/anaconda3/lib/python3.6/site-packages/torch/utils/data/dataloader.py in _process_next_batch(self, batch)
    196         self._put_indices()
    197         if isinstance(batch, ExceptionWrapper):
--> 198             raise batch.exc_type(batch.exc_msg)
    199         return batch

TypeError: Traceback (most recent call last):
  File "/home/shivam/anaconda3/lib/python3.6/site-packages/torch/utils/data/dataloader.py", line 34, in _worker_loop
    samples = collate_fn([dataset[i] for i in batch_indices])
  File "/home/shivam/anaconda3/lib/python3.6/site-packages/torch/utils/data/dataloader.py", line 79, in default_collate
    return [default_collate(samples) for samples in transposed]
  File "/home/shivam/anaconda3/lib/python3.6/site-packages/torch/utils/data/dataloader.py", line 79, in <listcomp>
    return [default_collate(samples) for samples in transposed]
  File "/home/shivam/anaconda3/lib/python3.6/site-packages/torch/utils/data/dataloader.py", line 66, in default_collate
    return torch.stack(batch, 0)
  File "/home/shivam/anaconda3/lib/python3.6/site-packages/torch/functional.py", line 56, in stack
    return torch.cat(list(t.unsqueeze(dim) for t in sequence), dim)
TypeError: cat received an invalid combination of arguments - got (list, int), but expected one of:
 * (sequence[torch.DoubleTensor] tensors)
 * (sequence[torch.DoubleTensor] tensors, int dim)

I’m unable to understand this. Also what is the use of collate_fn, and should I write a custom collate_fn to solve this problem?

Ok I think I solved it. In the except: clause of the getitem() function, I replaced
target = np.zeros((196, 196))
target = np.array(np.zeros((196, 196)), dtype='uint')

This solved the error. Though I’m not very sure how, maybe because earlier torch.cat() was trying to concatenate doubleTensor (from except:) with ByteTensor.