I have images 128x128 and the corresponding labels are multi-element vectors of 128 elements.
I want to use DataLoader with a custom map-style dataset, which at the moment look like this:
# custom dataset
class MyDataset(Dataset):
def __init__(self, images, labels=None, transforms=None):
self.X = images
self.y = labels
self.transforms = transforms
def __len__(self):
return (len(self.X))
def __getitem__(self, i):
data = self.X.iloc[i, :]
data = np.asarray(data).astype(np.float).reshape(1,128, 128)
if self.transforms:
data = self.transforms(data)
if self.y is not None:
return (data, self.y[i,:])
else:
return data
Now I have 20 images and 20 labels.
images.shape = (20, 16384)
labels.shape = (20, 128).
The third line among the ones below gives an error.
train_data = MyDataset(images, labels, None)
trainLoader = DataLoader(train_data, batch_size=len(train_data), num_workers=0)
data = next(iter(trainLoader))
And here is the error:
TypeError Traceback (most recent call last)
<ipython-input-90-d9957599c9b0> in <module>
trainLoader = DataLoader(train_data, batch_size=len(train_data), num_workers=0)
--->data = next(iter(trainLoader))
~\anaconda3\lib\site-packages\torch\utils\data\dataloader.py in __next__(self)
343
344 def __next__(self):
--> 345 data = self._next_data()
346 self._num_yielded += 1
347 if self._dataset_kind == _DatasetKind.Iterable and \
~\anaconda3\lib\site-packages\torch\utils\data\dataloader.py in _next_data(self)
383 def _next_data(self):
384 index = self._next_index() # may raise StopIteration
--> 385 data = self._dataset_fetcher.fetch(index) # may raise StopIteration
386 if self._pin_memory:
387 data = _utils.pin_memory.pin_memory(data)
~\anaconda3\lib\site-packages\torch\utils\data\_utils\fetch.py in fetch(self, possibly_batched_index)
42 def fetch(self, possibly_batched_index):
43 if self.auto_collation:
---> 44 data = [self.dataset[idx] for idx in possibly_batched_index]
45 else:
46 data = self.dataset[possibly_batched_index]
~\anaconda3\lib\site-packages\torch\utils\data\_utils\fetch.py in <listcomp>(.0)
42 def fetch(self, possibly_batched_index):
43 if self.auto_collation:
---> 44 data = [self.dataset[idx] for idx in possibly_batched_index]
45 else:
46 data = self.dataset[possibly_batched_index]
<ipython-input-82-d83f76ddf8f5> in __getitem__(self, i)
17
18 if self.y is not None:
---> 19 return (data, self.y[i,:])
20 else:
21 return data
~\anaconda3\lib\site-packages\pandas\core\frame.py in __getitem__(self, key)
2798 if self.columns.nlevels > 1:
2799 return self._getitem_multilevel(key)
-> 2800 indexer = self.columns.get_loc(key)
2801 if is_integer(indexer):
2802 indexer = [indexer]
~\anaconda3\lib\site-packages\pandas\core\indexes\base.py in get_loc(self, key, method, tolerance)
2644 )
2645 try:
-> 2646 return self._engine.get_loc(key)
2647 except KeyError:
2648 return self._engine.get_loc(self._maybe_cast_indexer(key))
pandas\_libs\index.pyx in pandas._libs.index.IndexEngine.get_loc()
pandas\_libs\index.pyx in pandas._libs.index.IndexEngine.get_loc()
TypeError: '(0, slice(None, None, None))' is an invalid key
Could you please point out what I am doing wrong?