next(iter(DataLoader)) throws an error

I have images 128x128 and the corresponding labels are multi-element vectors of 128 elements.
I want to use DataLoader with a custom map-style dataset, which at the moment look like this:

# custom dataset
class MyDataset(Dataset):
    def __init__(self, images, labels=None, transforms=None):
        self.X = images
        self.y = labels
        self.transforms = transforms
         
    def __len__(self):
        return (len(self.X))
    
    def __getitem__(self, i):
        data = self.X.iloc[i, :]
        data = np.asarray(data).astype(np.float).reshape(1,128, 128)
        
        if self.transforms:
            data = self.transforms(data)
            
        if self.y is not None:
            return (data, self.y[i,:])
        else:
            return data

Now I have 20 images and 20 labels.
images.shape = (20, 16384)
labels.shape = (20, 128).

The third line among the ones below gives an error.

train_data = MyDataset(images, labels, None)
trainLoader = DataLoader(train_data, batch_size=len(train_data), num_workers=0)
data = next(iter(trainLoader))

And here is the error:

TypeError                                 Traceback (most recent call last)
<ipython-input-90-d9957599c9b0> in <module>
     trainLoader = DataLoader(train_data, batch_size=len(train_data), num_workers=0)
--->data = next(iter(trainLoader))

~\anaconda3\lib\site-packages\torch\utils\data\dataloader.py in __next__(self)
    343 
    344     def __next__(self):
--> 345         data = self._next_data()
    346         self._num_yielded += 1
    347         if self._dataset_kind == _DatasetKind.Iterable and \

~\anaconda3\lib\site-packages\torch\utils\data\dataloader.py in _next_data(self)
    383     def _next_data(self):
    384         index = self._next_index()  # may raise StopIteration
--> 385         data = self._dataset_fetcher.fetch(index)  # may raise StopIteration
    386         if self._pin_memory:
    387             data = _utils.pin_memory.pin_memory(data)

~\anaconda3\lib\site-packages\torch\utils\data\_utils\fetch.py in fetch(self, possibly_batched_index)
     42     def fetch(self, possibly_batched_index):
     43         if self.auto_collation:
---> 44             data = [self.dataset[idx] for idx in possibly_batched_index]
     45         else:
     46             data = self.dataset[possibly_batched_index]

~\anaconda3\lib\site-packages\torch\utils\data\_utils\fetch.py in <listcomp>(.0)
     42     def fetch(self, possibly_batched_index):
     43         if self.auto_collation:
---> 44             data = [self.dataset[idx] for idx in possibly_batched_index]
     45         else:
     46             data = self.dataset[possibly_batched_index]

<ipython-input-82-d83f76ddf8f5> in __getitem__(self, i)
     17 
     18         if self.y is not None:
---> 19             return (data, self.y[i,:])
     20         else:
     21             return data

~\anaconda3\lib\site-packages\pandas\core\frame.py in __getitem__(self, key)
   2798             if self.columns.nlevels > 1:
   2799                 return self._getitem_multilevel(key)
-> 2800             indexer = self.columns.get_loc(key)
   2801             if is_integer(indexer):
   2802                 indexer = [indexer]

~\anaconda3\lib\site-packages\pandas\core\indexes\base.py in get_loc(self, key, method, tolerance)
   2644                 )
   2645             try:
-> 2646                 return self._engine.get_loc(key)
   2647             except KeyError:
   2648                 return self._engine.get_loc(self._maybe_cast_indexer(key))

pandas\_libs\index.pyx in pandas._libs.index.IndexEngine.get_loc()

pandas\_libs\index.pyx in pandas._libs.index.IndexEngine.get_loc()

TypeError: '(0, slice(None, None, None))' is an invalid key

Could you please point out what I am doing wrong?

I think the problem comes from the fact that y is a pandas dataframe and so you can’t just index directly, I think you need to do y.iloc[i,:] or something similar.

Thanks heaps for looking into the code! It was indeed about figuring out the right format for the labels y.