Torchmeta dataloader for Pascal5i

I recently started working with Torchmeta, a PyTorch framework for meta-learning. I am using Torchmeta to create the dataloaders for the Pascal5i dataset (taken from here), which is a standard dataset for semantic segmentation. The code works fine with every other dataset I tested (Omniglot, MiniImageNet, and CIFARFS), but fails with Pascal5i.
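For comparison, here is a sketch of the equivalent Omniglot pipeline using the shortcuts in torchmeta.datasets.helpers (the ways/shots/batch values below are placeholders, not my actual config values); this runs without any issue:

from torchmeta.datasets.helpers import omniglot
from torchmeta.utils.data import BatchMetaDataLoader

# The helper wraps Omniglot with the usual transforms and class splitting,
# so the resulting dataset can be batched directly by BatchMetaDataLoader.
trainset = omniglot("data", ways=5, shots=5, test_shots=15, meta_train=True, download=True)
trainloader = BatchMetaDataLoader(trainset, batch_size=16, shuffle=True, num_workers=0)
batch = next(iter(trainloader))  # works as expected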

The code and error are below:

from torchmeta.datasets import Pascal5i
from torchmeta.utils.data import BatchMetaDataLoader
import config

def load_trainloader():
    trainset = Pascal5i("data", num_classes_per_task=config.n, meta_train=True, download=True)
    trainloader = BatchMetaDataLoader(trainset, batch_size=config.batch_size, shuffle=True, num_workers=0)

    return trainset, trainloader

_, trainloader = load_trainloader()
batch = next(iter(trainloader))

error:

--------------------------------------------------------------------------
TypeError                                Traceback (most recent call last)
<ipython-input-8-e4dc503de88f> in <module>
----> 1 batch = next(iter(trainloader))

~/venv/torch-py3/lib/python3.6/site-packages/torch/utils/data/dataloader.py in __next__(self)
    343 
    344     def __next__(self):
--> 345         data = self._next_data()
    346         self._num_yielded += 1
    347         if self._dataset_kind == _DatasetKind.Iterable and \

~/venv/torch-py3/lib/python3.6/site-packages/torch/utils/data/dataloader.py in _next_data(self)
    383     def _next_data(self):
    384         index = self._next_index()  # may raise StopIteration
--> 385         data = self._dataset_fetcher.fetch(index)  # may raise StopIteration
    386         if self._pin_memory:
    387             data = _utils.pin_memory.pin_memory(data)

~/venv/torch-py3/lib/python3.6/site-packages/torch/utils/data/_utils/fetch.py in fetch(self, possibly_batched_index)
     42     def fetch(self, possibly_batched_index):
     43         if self.auto_collation:
---> 44             data = [self.dataset[idx] for idx in possibly_batched_index]
     45         else:
     46             data = self.dataset[possibly_batched_index]

~/venv/torch-py3/lib/python3.6/site-packages/torch/utils/data/_utils/fetch.py in <listcomp>(.0)
     42     def fetch(self, possibly_batched_index):
     43         if self.auto_collation:
---> 44             data = [self.dataset[idx] for idx in possibly_batched_index]
     45         else:
     46             data = self.dataset[possibly_batched_index]

~/venv/torch-py3/lib/python3.6/site-packages/torchmeta/utils/data/dataset.py in __getitem__(self, index)
    274                 self.num_classes_per_task - 1, index))
    275         assert len(index) == self.num_classes_per_task
--> 276         datasets = [self.dataset[i] for i in index]
    277         # Use deepcopy on `Categorical` target transforms, to avoid any side
    278         # effect across tasks.

~/venv/torch-py3/lib/python3.6/site-packages/torchmeta/utils/data/dataset.py in <listcomp>(.0)
    274                 self.num_classes_per_task - 1, index))
    275         assert len(index) == self.num_classes_per_task
--> 276         datasets = [self.dataset[i] for i in index]
    277         # Use deepcopy on `Categorical` target transforms, to avoid any side
    278         # effect across tasks.

~/venv/torch-py3/lib/python3.6/site-packages/torchmeta/datasets/pascal5i.py in __getitem__(self, index)
    146 
    147         return PascalDataset((data, masks), class_id, transform=transform,
--> 148             target_transform=target_transform)
    149 
    150     @property

~/venv/torch-py3/lib/python3.6/site-packages/torchmeta/datasets/pascal5i.py in __init__(self, data, class_id, transform, target_transform)
    245     def __init__(self, data, class_id, transform=None, target_transform=None):
    246         super(PascalDataset, self).__init__(transform=transform,
--> 247             target_transform=target_transform)
    248         self.data, self.masks = data
    249         self.class_id = class_id

TypeError: __init__() missing 1 required positional argument: 'index'

The line of code that raises the error calls the __init__ method of its parent class (torchmeta.utils.data.Dataset), which is defined here as torchmeta.utils.data.task.Dataset.
Looking at the __init__ method of that class in this line of code, the index argument is required but never passed, which is why the error is raised.
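As a toy illustration (this is not torchmeta's actual code, just the same pattern as in your traceback), the TypeError appears whenever a subclass forgets to forward a required positional argument to its parent:

class ParentDataset:
    # stands in for torchmeta.utils.data.task.Dataset, whose __init__ requires `index`
    def __init__(self, index, transform=None, target_transform=None):
        self.index = index
        self.transform = transform
        self.target_transform = target_transform

class ChildDataset(ParentDataset):
    # stands in for PascalDataset, which does not pass `index` to super().__init__
    def __init__(self, data, class_id, transform=None, target_transform=None):
        super().__init__(transform=transform, target_transform=target_transform)
        self.data, self.masks = data
        self.class_id = class_id

ChildDataset(data=([], []), class_id=0)
# raises: TypeError: __init__() missing 1 required positional argument: 'index'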

I'm not familiar with the code base, but it looks like a bug in the dataset implementation, so I would recommend creating an issue in their repository.
